In [1]:
import tensorflow as tf
import numpy as np
from show import show_graph
In [2]:
# GPU memory management: limit this process's footprint on the GPU.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'    # expose only the first GPU to this process
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5    # allow at most 50% of GPU memory
config.gpu_options.allow_growth = True      # allocate GPU memory on demand rather than all upfront

变长输入

相关参数设定

In [4]:
# Model/shape hyperparameters shared by all later cells.
n_inputs = 3   # size of the input vector at each time step
n_neurons = 5  # number of recurrent units (also the size of each output/state vector)
n_steps = 2    # maximum time steps per sequence; shorter sequences are zero-padded to this

构造计算图

In [5]:
# Extra placeholder carrying the true (unpadded) length of each sequence in the batch.
seq_length = tf.placeholder(tf.int32, [None])
# Input batch of shape [batch_size, n_steps, n_inputs]; sequences shorter than
# n_steps are zero-padded up to n_steps.
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
In [6]:
# tf.contrib is deprecated in TF 1.x; tf.nn.rnn_cell.BasicRNNCell is the same
# class under its supported name, so this is a drop-in replacement.
basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
# sequence_length tells dynamic_rnn each sequence's true length: outputs past a
# sequence's end are zeroed, and its final state is the state at its last valid step.
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32, sequence_length=seq_length)
In [9]:
# Initializer op for all global variables (the RNN cell's weights and biases).
init = tf.global_variables_initializer()

模拟数据,进行训练

In [7]:
# Hand-built batch of 4 sequences, each with n_steps=2 steps of n_inputs=3 features.
# Instance 1 is really only 1 step long, so its step 1 is a zero-padding vector.
sequences = [
    # step 0      step 1
    [[0, 1, 2], [9, 8, 7]],  # instance 0: full length
    [[3, 4, 5], [0, 0, 0]],  # instance 1: length 1, zero-padded at step 1
    [[6, 7, 8], [6, 5, 4]],  # instance 2: full length
    [[9, 0, 1], [3, 2, 1]],  # instance 3: full length
]
X_batch = np.array(sequences)
# True (unpadded) length of each sequence, fed into the seq_length placeholder.
seq_length_batch = np.array([2, 1, 2, 2])
In [10]:
# Run the graph: initialize variables, then evaluate outputs and states in one
# pass, feeding both the padded batch and the true per-sequence lengths.
with tf.Session(config=config) as sess:
    init.run()
    outputs_val, states_val = sess.run(
            [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
In [11]:
# Instance 1's outputs past its true length (i.e. at step 1) are all zeros.
print(outputs_val)
[[[ 0.66841251  0.56325114 -0.23545252 -0.47455621 -0.69435865]
  [ 1.          0.99976754  0.80224395  0.63957655 -0.99999845]]

 [[ 0.99946588  0.98103112 -0.08152062 -0.20529182 -0.99780762]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.99999928  0.99934387  0.07640119  0.09910619 -0.99998665]
  [ 0.99993366  0.97914189  0.7965588   0.32067806 -0.99989229]]

 [[ 0.99999112  0.85536873  0.97974646  0.99966168 -0.99995744]
  [ 0.81988484  0.02078043  0.72537458  0.30290079 -0.99031717]]]
In [12]:
# Each row of `states` is that instance's output at its LAST VALID step —
# e.g. instance 1's state equals its step-0 output, not the zero-padded step 1.
print(states_val)
[[ 1.          0.99976754  0.80224395  0.63957655 -0.99999845]
 [ 0.99946588  0.98103112 -0.08152062 -0.20529182 -0.99780762]
 [ 0.99993366  0.97914189  0.7965588   0.32067806 -0.99989229]
 [ 0.81988484  0.02078043  0.72537458  0.30290079 -0.99031717]]

变长输出

如果知道输出的长度(比如与输入等长一类的信息),可以直接指定一个 sequence_length 参数;
但大多数情况下是没有这种信息的,这时一般会定义一个特定的标志(End-of-Sequence, EOS),让RNN网络一直产生输出,直到输出EOS或者长度超出限制