TensorFlowstatic_rnn 和dynamic_rnn的区别
tensorflow中提供了rnn接口有两种,一种是静态的rnn,一种是动态的rnn
通常用法:
1、静态接口:static_rnn
主要使用 tf.contrib.rnn
x = tf.placeholder("float", [None, n_steps, n_input])
x1 = tf.unstack(x, n_steps, 1)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
静态 rnn 的意思就是在图中创建一个固定长度(n_steps)的网络。这将导致
缺点:
- 生成过程耗时更长,占内存更多,导出的模型更大;
- 无法传递比最初指定的更长的序列(> n_steps)。
优点:
模型中带有某个序列中间台的信息,便与调试。
2、动态接口:dynamic_rnn
主要使用 tf.nn.dynamic_rnn
x = tf.placeholder("float", [None, n_steps, n_input])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs,_ = tf.nn.dynamic_rnn(lstm_cell ,x,dtype=tf.float32)
outputs = tf.transpose(outputs, [1, 0, 2])
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
动态的tf.nn.dynamic_rnn被执行时,它使用循环来动态构建图形。这意味着
优点:
- 图形创建速度更快,占用内存更少;
- 并且可以提供可变大小的批处理。
缺点:
- 模型中只有最后的状态。
动态rnn的意思是只创建样本中的一个序列RNN,其他序列数据会通过循环进入该RNN运算
区别:
1、输入输出不同:
dynamic_rnn实现的功能就是可以让不同迭代传入的batch可以是长度不同数据,但同一次迭代一个batch内部的所有数据长度仍然是固定的。例如,第一时刻传入的数据shape=[batch_size, 10],第二时刻传入的数据shape=[batch_size, 12],第三时刻传入的数据shape=[batch_size, 8]等等。
但是static_rnn不能这样,它要求每一时刻传入的batch数据的[batch_size, max_seq],在每次迭代过程中都保持不变。
2、训练方式不同:
具体参见参考文献1
多层LSTM的代码实现对比:
1、静态多层RNN
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True)
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类)
batch_size = 128
tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])
gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell,gru])
x1 = tf.unstack(x, n_steps, 1)
outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
training_iters = 100000
display_step = 10
# 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " str(step*batch_size) ", Minibatch Loss= " \
"{:.6f}".format(loss) ", Training Accuracy= " \
"{:.5f}".format(acc))
step = 1
print (" Finished!")
# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", \
sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
2、动态多层RNN
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True)
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类)
batch_size = 128
tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])
gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell,gru])
outputs,states = tf.nn.dynamic_rnn(mcell,x,dtype=tf.float32)#(?, 28, 256)
outputs = tf.transpose(outputs, [1, 0, 2])#(28, ?, 256) 28个时序,取最后一个时序outputs[-1]=(?,256)
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
training_iters = 100000
display_step = 10
# 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " str(step*batch_size) ", Minibatch Loss= " \
"{:.6f}".format(loss) ", Training Accuracy= " \
"{:.5f}".format(acc))
step = 1
print (" Finished!")
# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", \
sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
【参考文献】:
1、https://www.jianshu.com/p/1b1ea45fab47
2、What's the difference between tensorflow dynamic_rnn and rnn?
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhibegeb
系列文章
更多
同类精品
更多
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
微信运动停用后别人还能看到步数吗
PHP中文网 07-22 -
excel打印预览压线压字怎么办
PHP中文网 06-22