博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
循环神经网络(LSTM和GRU)(2)
阅读量:4934 次
发布时间:2019-06-11

本文共 5658 字,大约阅读时间需要 18 分钟。

1、tf.nn.dynamic_rnn()函数

参考:http://www.360doc.com/content/17/0321/10/10408243_638692495.shtml

参考:https://blog.csdn.net/u010089444/article/details/60963053

参考:https://blog.csdn.net/u010223750/article/details/71079036

在用rnn处理长文本时,使用dynamic_rnn()可以跳过padding部分的计算,从而减少计算量。假设有两个文本,一个为10,一个为5,那么需要对第二文本进行0-padding填充,得到shape维度[2,10,dim],其中dim为词向量的维度,

使用dynamic_run的代码如下:

outputs,last_states=tf.nn.dynamic_rnn(cell=cell,dtype=tf.float32,sequence_length=x_lengths,inputs=x)

其中cell为RNN节点,比如tf.contrib.rnn.BasicLSTMCell,x是0-padding之后的数据,x_lengths是每个文本的长度。

tf.nn.dynamic_rnn()返回两个变量,第一个是每个step的输出值,以上面的例子为例,得到的维度则为:[2,10,dim],第二个是最终的状态,是由(c,h)组成的tuple,均为[batch,dim],其中dynamic有个参数sequence_length,用来指定每个example的长度,

以如下为例:

import tensorflow as tfimport numpy as np# 创建输入数据X = np.random.randn(2, 10, 8)       #其中2为batch_size,10为文本最大长度,8为embedding_size# 第二个example长度为6   #randn从标准正态分布返回值,rand从[0,1)返回值X[1,6:] = 0X_lengths = [10, 6]with tf.variable_scope('c', reuse=None) as scope:          #根据调用是否首次来调整reuse为True或False    cell = tf.contrib.rnn.BasicLSTMCell(num_units=64, state_is_tuple=True)    #num_units指的是一个cell中神经元的个数#循环层的cell数目表示:X_split = tf.split(XR, time_step_size, 0)  (X_split中划分出来的arrays数量为循环层的cell个数) #在任意时刻,LSTM cell会产生两个内部状态c和h,当state_is_tuple=True时,该状态的c和h是分开记录的,放在一个二元tuple返回,维度均为[batch,embedding_size];如果state_is_tuple=False时,两个状态就按列连接起来返回,     outputs, last_states = tf.nn.dynamic_rnn(             #outputs返回的维度是[2,10,8],last_states返回的维度是[2,64]            cell=cell,                    #tf.nn.dynamic_rnn用于实现不同迭代传入的batch可以是长度不同的数据,但是同一次迭代一个batch内部所有的数据长度仍然是固定的。也即第一次传入为[batch_size,10],第二次[batch_size,8],第三次[batch_size,11]            dtype=tf.float64,                #然而对于rnn来说,每次输入的则是最大长度,也即[batch_size,max_seq]            sequence_length=X_lengths,       #假若设置第二个example的有效长度为6,当传入这个参数时,tensorflow对6之后的padding就不计算了,其last_states将重复第6步的计算到末尾,而outputs中超过6步的结果会被置零。            inputs=X)    result = tf.contrib.learn.run_n(            {
"outputs": outputs, "last_states": last_states}, n=1, feed_dict=None) print (result[0]) assert result[0]["outputs"].shape == (2, 10, 64) # 第二个example中的outputs超过6步(7-10步)的值应该为0 assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()

 2、tf.contrib.rnn.BasicLSTMCell()函数和tf.contrib.rnn.BasicRNNCell()函数

参考:https://blog.csdn.net/u013082989/article/details/73693392

参考:https://blog.csdn.net/u010089444/article/details/60963053

BasicRNNCell是最基本的RNN cell单元,输入参数包括:

num_units:RNN层神经元的个数

activation:内部状态间的激活函数

reuse:决定是否重用现有域的变量

BasicLSTMCell是最基本的LSTM循环神经网络单元,输入参数包括:

num_units:LSTM cell层中的单元数

forget_bias:forget gates中的偏置

state_is_tuple:返回(c_state,m_state)二元组

activation:状态间转移的激活函数

reuse:是否重用变量

 以MNIST手写体数据集为例:

from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tffrom tensorflow.contrib import rnnimport numpy as npinput_vec_size = lstm_size = 28 # 输入向量的维度time_step_size = 28 # 循环层长度batch_size = 128test_size = 256def init_weights(shape):    return tf.Variable(tf.random_normal(shape, stddev=0.01))def model(X, W, B, lstm_size):    # X, input shape: (batch_size, time_step_size, input_vec_size)    # XT shape: (time_step_size, batch_size, input_vec_size)    XT = tf.transpose(X, [1, 0, 2])  # permute time_step_size and batch_size,[28, 128, 28]    # XR shape: (time_step_size * batch_size, input_vec_size)    XR = tf.reshape(XT, [-1, lstm_size]) # each row has input for each lstm cell (lstm_size=input_vec_size)    # Each array shape: (batch_size, input_vec_size)    X_split = tf.split(XR, time_step_size, 0) # split them to time_step_size (28 arrays),shape = [(128, 28),(128, 28)...]    # Make lstm with lstm_size (each input vector size). num_units=lstm_size; forget_bias=1.0    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)    # Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)    # rnn..static_rnn()的输出对应于每一个timestep,如果只关心最后一步的输出,取outputs[-1]即可    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)  # 时间序列上每个Cell的输出:[... shape=(128, 28)..]    # Linear activation    # Get the last output    return tf.matmul(outputs[-1], W) + B, lstm.state_size # State size to initialize the statmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels# 将每张图用一个28x28的矩阵表示,(55000,28,28,1)trX = trX.reshape(-1, 28, 28) teX = teX.reshape(-1, 28, 28) X = tf.placeholder("float", [None, 28, 28])Y = tf.placeholder("float", [None, 10])# get lstm_size and output 10 labelsW = init_weights([lstm_size, 10])  # 输出层权重矩阵28×10B = init_weights([10])  # 输出层baispy_x, state_size = model(X, W, B, lstm_size)cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)predict_op = tf.argmax(py_x, 1)session_conf = tf.ConfigProto()session_conf.gpu_options.allow_growth = True# Launch the graph in a sessionwith tf.Session(config=session_conf) as sess:    # you need to initialize all variables    tf.global_variables_initializer().run()    for i in range(100):        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX)+1, batch_size)):            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})        test_indices = np.arange(len(teX))  # Get A Test Batch        np.random.shuffle(test_indices)        test_indices = test_indices[0:test_size]        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==                         sess.run(predict_op, feed_dict={X: teX[test_indices]})))

 

转载于:https://www.cnblogs.com/xiaochouk/p/8746302.html

你可能感兴趣的文章
Hive 变量和属性
查看>>
验证邮箱合法性的一些测试样例
查看>>
Python安装第三方库 xlrd 和 xlwt 。处理Excel表格
查看>>
课后作业-阅读任务-阅读提问-3
查看>>
Asp.Net Core 中利用QuartzHostedService 实现 Quartz 注入依赖 (DI)
查看>>
细说sqlserver索引及SQL性能优化原则
查看>>
一般数据库增量数据处理和数据仓库增量数据处理的几种策略
查看>>
离散数学课后作业
查看>>
centos6.5适用的国内yum源:网易、搜狐
查看>>
shell 监控脚本
查看>>
[bzoj3029] 守卫者的挑战 (概率期望dp)
查看>>
[winograd]winograd算法在卷积中的应用
查看>>
视频直播技术(三):低延时直播经验总结
查看>>
微软Office Online服务安装部署(一)
查看>>
Application failed to start because it could not find or load the QT platform plugin “windows”
查看>>
python合并多表或两表数据
查看>>
分享一下伪装刚学的
查看>>
代码优化
查看>>
获取UITableCell高度的两种方法
查看>>
python 多线程
查看>>