
Neural Machine Translation with seq2seq + attention (NLP Practice 3)


This post is a follow-up to the previous one, which did not use attention. Attention essentially mimics how a human translates: when translating a sentence, the meaning of the word currently being translated often depends on its context, so the translator keeps glancing back at the source sentence, because certain parts of it strongly influence the prediction of the current word. That information therefore needs to be fed into the process of predicting the current word.

Suppose the encoder input is [X1, ..., Xj, ..., XTx], i.e. the maximum input length is Tx (shorter sentences are padded up to Tx), and Xj is the embedding of the j-th word. Feeding this into the encoder gives enc_output = [H1, ..., Hj, ..., HTx]. On the decoder side, the states produced so far are [S1, ..., Si-1], and we are predicting the i-th word (the i-th output yi). We first compute a relevance score between Si-1 and each element of enc_output, eij = a(Si-1, Hj); the attention weights αij are then the softmax of the array formed by eij (j = 1...Tx). Taking the weighted average of enc_output with these weights gives the context vector Ci, which together with Si-1 and yi-1 (the (i-1)-th output) is used to compute Si. This is illustrated in the figures below:

[Figures: diagrams of the attention computation (eij, αij, Ci) in the seq2seq decoder]
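
To make the computation above concrete, here is a minimal NumPy sketch of a single attention step. The shapes, the weight names W_q, W_m, v, and the random initialization are illustrative assumptions, not the exact ops used later by TensorFlow's AttentionWrapper:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

Tx, enc_dim, dec_dim, attn_dim = 6, 8, 8, 8
H = np.random.randn(Tx, enc_dim)       # encoder outputs H_1 ... H_Tx
s_prev = np.random.randn(dec_dim)      # previous decoder state S_{i-1}

# Additive (Bahdanau) score: e_ij = v . tanh(W_q S_{i-1} + W_m H_j)
W_q = np.random.randn(dec_dim, attn_dim)
W_m = np.random.randn(enc_dim, attn_dim)
v = np.random.randn(attn_dim)

e = np.tanh(s_prev @ W_q + H @ W_m) @ v    # scores e_ij, shape [Tx]
alpha = softmax(e)                         # attention weights, sum to 1
c = alpha @ H                              # context vector C_i, shape [enc_dim]
print(alpha, c)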

Below is the training and inference code with attention added. One caveat: when restoring the model you may hit a kernel NotFoundError, so the inference code needs to be wrapped in the right tf.variable_scope; the scope names to use have to be determined from the structure of the saved model. cat_net.py is a helper script for inspecting the network structure stored in a checkpoint.

################
# Inspect the tensors stored in a ckpt file.
# Tested with TensorFlow 'v1.3.0-rc2-20-g0787eee'.
################
 
import os
from tensorflow.python import pywrap_tensorflow
 
# Code for the latest ckpt in a model directory:
# checkpoint_path = os.path.join('~/tensorflowTraining/ResNet/model', "model.ckpt")
 
# Code for a specific ckpt; change 9000 to your checkpoint step
checkpoint_path = "./seq2seq_attention_ckpt-9000"
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

Then run the following command:

python cat_net.py > net.log

Then open net.log; its contents look like this:

tensor_name:  nmt_model/trg_emb
[[-0.0108806   0.06109285  0.00103101 ...  0.04327971 -0.0353515
   0.05592098]
 [ 0.05511344  0.10315175 -0.03260459 ...  0.01503692 -0.05173029
   0.00153936]
 [ 0.0991787  -0.06907252  0.21693856 ...  0.0805049  -0.01262149
  -0.01293714]
 ...
 [-0.12422938  0.01872133 -0.08084115 ...  0.03637449 -0.0386718
   0.0702277 ]
 [-0.09422345 -0.0029713  -0.00904827 ... -0.03110654 -0.00099467
  -0.0079839 ]
 [ 0.15127543 -0.10549527  0.00927421 ...  0.00116051  0.11979865
   0.02227078]]
tensor_name:  nmt_model/softmax_bias
[ 9.702319   -1.8773284   4.8174777  ... -0.48294842 -0.35829535
 -0.4816328 ]
tensor_name:  encoder/bidirectional_rnn/fw/basic_lstm_cell/kernel
[[-0.03092953 -0.03392044 -0.02407785 ...  0.02163492 -0.0049458
   0.02973264]
 [ 0.10684907  0.04035901  0.01169399 ...  0.02350369  0.02541667
  -0.0220029 ]
 [ 0.01336335  0.0200959   0.00845157 ... -0.01780637  0.01966156
   0.00902852]
 ...
 [-0.01916422 -0.0131671   0.0082262  ... -0.01099342 -0.00506847
   0.0146405 ]
 [ 0.0169474   0.02184602 -0.01979198 ... -0.00957554 -0.01252236
   0.03171452]
 [-0.03693858  0.01639441 -0.02785428 ... -0.02872299  0.01957132
  -0.02001939]]
tensor_name:  decoder/memory_layer/kernel
[[ 0.04521498 -0.00092734  0.00987301 ... -0.01601705 -0.01625223
  -0.00826636]
 [ 0.03350661 -0.01258853  0.03047631 ... -0.01902125 -0.01759247
   0.01519862]
 [-0.02057176 -0.01262629 -0.00525282 ... -0.03981094  0.03607614
   0.00477269]
 ...
 [-0.0367771  -0.02705046  0.01810684 ... -0.03925494  0.03783213
  -0.01419215]
 [-0.01111888  0.00990444  0.02161855 ...  0.0041062   0.02929579
  -0.00364193]
 [ 0.02131032  0.00671287  0.00193167 ... -0.02134871 -0.00051426
   0.02360947]]
tensor_name:  decoder/rnn/attention_wrapper/multi_rnn_cell/cell_1/basic_lstm_cell/bias
[-1.6907989  -0.80868345 -1.1245108  ... -0.7377462  -1.0939049
 -1.2807418 ]
tensor_name:  nmt_model/src_emb
[[ 0.06317458 -0.05404264 -0.00954251 ... -0.14450565 -0.11939629
  -0.05514779]
 [ 0.00680785  0.04471309 -0.0104601  ... -0.03551793 -0.04758103
   0.01540864]
 [ 0.32627714  0.0827579  -0.11642702 ... -0.03501745 -0.27873012
  -0.04998838]
 ...
 [-0.0220207  -0.03215215 -0.01608298 ... -0.03651857 -0.04046999
  -0.02552509]
 [ 0.00540233  0.03604389  0.06067114 ...  0.05810086  0.03965386
   0.06954922]
 [ 0.02887495 -0.02881782  0.05515011 ...  0.03075846  0.00961011
  -0.02850782]]
tensor_name:  decoder/rnn/attention_wrapper/attention_layer/kernel
[[-0.09316745 -0.07995477 -0.0146741  ...  0.0717198   0.02371014
  -0.05503882]
 [-0.00638354 -0.05642074 -0.12752905 ...  0.07572     0.02780477
   0.02916634]
 [-0.0532836   0.01808308 -0.01555931 ... -0.08836221 -0.05027555
   0.01292556]
 ...
 [-0.03378733  0.01676184 -0.01945874 ...  0.04151832 -0.04257954
  -0.00394057]
 [-0.04521075  0.02617629 -0.01065068 ...  0.06043241  0.02765347
  -0.03455104]
 [-0.02321909 -0.0051408   0.02175523 ...  0.00103944  0.03563083
   0.04527191]]
tensor_name:  decoder/rnn/attention_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/bias
[-0.47364303 -0.43505263 -0.2991495  ... -0.34608215 -0.3425427
 -0.41822633]
tensor_name:  decoder/rnn/attention_wrapper/multi_rnn_cell/cell_1/basic_lstm_cell/kernel
[[-0.06655399 -0.033209    0.00741314 ... -0.03744704  0.16143945
  -0.04238527]
 [-0.09054025 -0.05978451 -0.0919419  ... -0.05676661 -0.03161845
   0.11375111]
 [-0.01762006 -0.01342999  0.00538671 ... -0.07151254  0.00439914
   0.0617904 ]
 ...
 [ 0.01361352 -0.00989851 -0.01075909 ...  0.02791671  0.0204173
   0.03272137]
 [-0.02172133  0.01065003  0.02755076 ...  0.01163509  0.00617506
   0.02474814]
 [-0.02055892 -0.0032329  -0.01226626 ... -0.03111863  0.04921816
  -0.01788351]]
tensor_name:  encoder/bidirectional_rnn/fw/basic_lstm_cell/bias
[-1.0207075  -0.7382192  -0.75269985 ... -0.7253135  -0.83074564
 -0.71001625]
tensor_name:  decoder/rnn/attention_wrapper/bahdanau_attention/attention_v
[ 0.00795415 -0.00872286 -0.02835944 ...  0.02541727 -0.0316006
 -0.01547218]
tensor_name:  decoder/rnn/attention_wrapper/bahdanau_attention/query_layer/kernel
[[-0.02799078  0.00915903 -0.00178415 ... -0.01649223 -0.02163657
   0.01371716]
 [-0.0445041   0.00936891  0.02943462 ... -0.04068676 -0.00589912
  -0.05063123]
 [ 0.01968101  0.03777748  0.01904894 ... -0.04097166  0.05280968
   0.04113906]
 ...
 [ 0.01412237  0.02355416  0.03901715 ...  0.01330961  0.01638247
   0.00222727]
 [ 0.02915935  0.00618351  0.01156276 ...  0.04674264  0.04458835
   0.01011846]
 [-0.00728581  0.04162799 -0.01898116 ... -0.03135163 -0.04987657
   0.03854783]]
tensor_name:  decoder/rnn/attention_wrapper/multi_rnn_cell/cell_0/basic_lstm_cell/kernel
[[-0.0266049   0.06239759  0.03370405 ...  0.00847407  0.02729598
  -0.02040454]
 [-0.04149583 -0.03149587 -0.01089299 ... -0.03426768  0.0172292
  -0.05368057]
 [ 0.01183772  0.09243455 -0.02107698 ... -0.05690235  0.0284145
  -0.0332344 ]
 ...
 [-0.02697257 -0.06419387 -0.04755762 ...  0.09542636 -0.01003412
  -0.04204182]
 [-0.04266602 -0.045127    0.02201566 ... -0.08180676 -0.01398551
  -0.00633448]
 [ 0.01584598  0.01223975  0.03658367 ...  0.02622196 -0.00311522
  -0.00781288]]
tensor_name:  encoder/bidirectional_rnn/bw/basic_lstm_cell/bias
[-0.80520725 -0.88913244 -1.0078353  ... -0.7981011  -0.65148497
 -0.9233699 ]
tensor_name:  encoder/bidirectional_rnn/bw/basic_lstm_cell/kernel
[[-0.05317495  0.00402797 -0.04864402 ...  0.04332062  0.02639003
  -0.00492012]
 [ 0.03982998  0.00540096  0.09128776 ... -0.03405574  0.00860246
   0.01108253]
 [ 0.00027926  0.00077254  0.08196697 ...  0.03171543  0.03697995
   0.00165045]
 ...
 [ 0.01058249  0.00307607  0.00184137 ...  0.00661535  0.01547921
  -0.02362307]
 [ 0.00757162  0.0162105  -0.01197527 ... -0.0082445  -0.00365599
   0.03383213]
 [-0.02791     0.00413945 -0.06630697 ... -0.0176604   0.01094399
  -0.03434239]]
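
If your TensorFlow 1.x build exposes tf.train.list_variables (an assumption; older 1.x releases ship the same helper as tf.contrib.framework.list_variables), a compact listing of just names and shapes can be obtained without dumping every value:

import tensorflow as tf

# Print (name, shape) for every variable in the checkpoint, values omitted.
for name, shape in tf.train.list_variables("./seq2seq_attention_ckpt-9000"):
    print(name, shape)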

From the contents of net.log we can work out which scope names to qualify with. This is needed mainly because we did not explicitly set variable scopes when declaring these nodes, and so did not know which default scopes TensorFlow assigned.
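
As a quick illustration of how a variable scope determines the tensor name stored in a checkpoint (the variable name toy_emb below is hypothetical, not part of the model):

import tensorflow as tf

# A variable created inside tf.variable_scope("nmt_model") is saved under
# the checkpoint name "nmt_model/<variable_name>".
with tf.variable_scope("nmt_model"):
    v = tf.get_variable("toy_emb", [3, 4])
print(v.name)  # "nmt_model/toy_emb:0", stored in the ckpt as "nmt_model/toy_emb"

Below is train_attention.py: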

#coding:utf-8
import tensorflow as tf

MAX_LEN = 50
SOS_ID = 1

SRC_TRAIN_DATA = "../train.tags.en-zh.en.deletehtml.segment.id"
TRG_TRAIN_DATA = "../train.tags.en-zh.zh.deletehtml.segment.id"
CHECKPOINT_PATH = "./seq2seq_attention_ckpt"

HIDDEN_SIZE = 1024
NUM_LAYERS = 2
SRC_VOCAB_SIZE = 10000
TRG_VOCAB_SIZE = 4000
BATCH_SIZE = 100
NUM_EPOCH = 5
KEEP_PROB = 0.8
MAX_GRAD_NORM = 5
SHARE_EMB_AND_SOFTMAX = True

class NMTModel(object):
    def __init__(self):
        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\
         for _ in range(NUM_LAYERS)])
        self.src_embedding = tf.get_variable(
            "src_emb",[SRC_VOCAB_SIZE,HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable(
            "trg_emb",[TRG_VOCAB_SIZE,HIDDEN_SIZE])
        
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable("weight",[HIDDEN_SIZE,TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable("softmax_bias",[TRG_VOCAB_SIZE])

    def forward(self,src_input,src_size,trg_input,trg_label,trg_size):
        batch_size = tf.shape(src_input)[0]
        src_emb = tf.nn.embedding_lookup(self.src_embedding,src_input)
        trg_emb = tf.nn.embedding_lookup(self.trg_embedding,trg_input)

        src_emb = tf.nn.dropout(src_emb,KEEP_PROB)
        trg_emb = tf.nn.dropout(trg_emb,KEEP_PROB)

        with tf.variable_scope("encoder"):
            enc_outputs,enc_state = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw,self.enc_cell_bw,src_emb,src_size,dtype=tf.float32)
            enc_outputs = tf.concat([enc_outputs[0],enc_outputs[1]],-1)

        with tf.variable_scope("decoder"):
            self.attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(HIDDEN_SIZE,enc_outputs,memory_sequence_length=src_size)
            self.attention_cell = tf.contrib.seq2seq.AttentionWrapper(self.dec_cell,self.attention_mechanism,attention_layer_size=HIDDEN_SIZE)

            dec_outputs, _ = tf.nn.dynamic_rnn(
                self.attention_cell,trg_emb,trg_size,dtype=tf.float32)

        output = tf.reshape(dec_outputs,[-1,HIDDEN_SIZE])
        logits = tf.matmul(output,self.softmax_weight) + self.softmax_bias
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.reshape(trg_label,[-1]),logits=logits)

        label_weights = tf.sequence_mask(trg_size,maxlen=tf.shape(trg_label)[1],dtype=tf.float32)
        label_weights = tf.reshape(label_weights,[-1])

        cost = tf.reduce_sum(loss*label_weights)
        cost_per_token = cost / tf.reduce_sum(label_weights)

        trainable_variables = tf.trainable_variables()

        grads = tf.gradients(cost / tf.to_float(batch_size), trainable_variables)
        grads,_ = tf.clip_by_global_norm(grads,MAX_GRAD_NORM)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = optimizer.apply_gradients(zip(grads,trainable_variables))

        return cost_per_token,train_op

def run_epoch(session,cost_op,train_op,saver,step):
    while True:
        try:
            cost,_ = session.run([cost_op,train_op])
            if step % 10 == 0:
                print("steps %d, per token cost is %.3f"%(step,cost))
            if step % 200 == 0:
                saver.save(session,CHECKPOINT_PATH,global_step=step)
            step += 1
        except tf.errors.OutOfRangeError:
            break
    return step


def MakeDataset(file_path):
    dataset = tf.data.TextLineDataset(file_path)
    dataset = dataset.map(lambda string: tf.string_split([string]).values)
    dataset = dataset.map(lambda string: tf.string_to_number(string,tf.int32))
    dataset = dataset.map(lambda x: (x,tf.size(x)))
    return dataset

def MakeSrcTrgDataset(src_path,trg_path,batch_size):
    src_data = MakeDataset(src_path)
    trg_data = MakeDataset(trg_path)

    dataset = tf.data.Dataset.zip((src_data,trg_data))

    def FilterLength(src_tuple,trg_tuple):
        ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)
        src_len_ok = tf.logical_and(tf.greater(src_len,1),tf.less_equal(src_len,MAX_LEN))
        trg_len_ok = tf.logical_and(tf.greater(trg_len,1),tf.less_equal(trg_len,MAX_LEN))
        return tf.logical_and(src_len_ok,trg_len_ok)
    dataset = dataset.filter(FilterLength)

    def MakeTrgInput(src_tuple,trg_tuple):
        ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)
        trg_input = tf.concat([[SOS_ID],trg_label[:-1]],axis=0)
        return ((src_input,src_len),(trg_input,trg_label,trg_len))
    dataset = dataset.map(MakeTrgInput)
    dataset = dataset.shuffle(10000)

    padded_shapes = (
        (tf.TensorShape([None]),
         tf.TensorShape([])),
        (tf.TensorShape([None]),
         tf.TensorShape([None]),
         tf.TensorShape([])))
    batched_dataset = dataset.padded_batch(batch_size,padded_shapes)
    return batched_dataset
                

def main():
    initializer = tf.random_uniform_initializer(-0.05,0.05)
    with tf.variable_scope("nmt_model",reuse=None,initializer=initializer):
        train_model = NMTModel()

    data = MakeSrcTrgDataset(SRC_TRAIN_DATA,TRG_TRAIN_DATA,BATCH_SIZE)
    iterator = data.make_initializable_iterator()
    (src,src_size),(trg_input,trg_label,trg_size) = iterator.get_next()

    cost_op,train_op = train_model.forward(src,src_size,trg_input,trg_label,trg_size)
    saver = tf.train.Saver()
    step = 0

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7,allow_growth=True)
    session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    with session as sess:
        tf.global_variables_initializer().run()
        for i in range(NUM_EPOCH):
            print("In iteration: %d"%(i 1))
            sess.run(iterator.initializer)
            step = run_epoch(sess,cost_op,train_op,saver,step)

if __name__ == '__main__':
    main()
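
Training is then started with a command of the same form as before (the log filename is just an example):

python train_attention.py > train.log 2>&1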

Below is eval_attention.py, which the original book does not provide. Pay particular attention to the modified variable scope:

#coding:utf-8
import tensorflow as tf


CHECKPOINT_PATH = "./seq2seq_attention_ckpt-9800"

HIDDEN_SIZE = 1024
NUM_LAYERS = 2
SRC_VOCAB_SIZE = 10000
TRG_VOCAB_SIZE = 4000
BATCH_SIZE = 100
SHARE_EMB_AND_SOFTMAX = True
SOS_ID = 1
EOS_ID = 2

class NMTModel(object):
    def __init__(self):
        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\
         for _ in range(NUM_LAYERS)])
        self.src_embedding = tf.get_variable(
            "src_emb",[SRC_VOCAB_SIZE,HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable(
            "trg_emb",[TRG_VOCAB_SIZE,HIDDEN_SIZE])
        
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable("weight",[HIDDEN_SIZE,TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable("softmax_bias",[TRG_VOCAB_SIZE])

    def inference(self,src_input):
        src_size = tf.convert_to_tensor([len(src_input)],dtype=tf.int32)
        src_input = tf.convert_to_tensor([src_input],dtype=tf.int32)
        src_emb = tf.nn.embedding_lookup(self.src_embedding,src_input)

        with tf.variable_scope("encoder"):
            enc_outputs,enc_state = tf.nn.bidirectional_dynamic_rnn(
                self.enc_cell_fw,self.enc_cell_bw,src_emb,src_size,dtype=tf.float32)
            enc_outputs = tf.concat([enc_outputs[0],enc_outputs[1]],-1)
        MAX_DEC_LEN = 100

        init_array = tf.TensorArray(dtype=tf.int32,size=0,dynamic_size=True,clear_after_read=False)
        init_array = init_array.write(0,SOS_ID)
        with tf.variable_scope("decoder"):
            self.attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(HIDDEN_SIZE,enc_outputs,memory_sequence_length=src_size)
        with tf.variable_scope("decoder/rnn/attention_wrapper"):
            self.attention_cell = tf.contrib.seq2seq.AttentionWrapper(self.dec_cell,self.attention_mechanism,attention_layer_size=HIDDEN_SIZE)
            state = self.attention_cell.zero_state(batch_size=1, dtype=tf.float32)
            init_loop_var = (state,init_array,0)

            def continue_loop_condition(state,trg_ids,step):
                return tf.reduce_all(tf.logical_and(tf.not_equal(trg_ids.read(step),EOS_ID),tf.less(step,MAX_DEC_LEN-1)))

            def loop_body(state,trg_ids,step):
                trg_input = [trg_ids.read(step)]
                trg_emb = tf.nn.embedding_lookup(self.trg_embedding,trg_input)

                dec_outputs,next_state = self.attention_cell.call(state=state,inputs=trg_emb)
                output = tf.reshape(dec_outputs,[-1,HIDDEN_SIZE])
                logits = (tf.matmul(output,self.softmax_weight) + self.softmax_bias)
                next_id = tf.argmax(logits,axis=1,output_type=tf.int32)

                trg_ids = trg_ids.write(step+1,next_id[0])
                return next_state,trg_ids,step+1

            state,trg_ids,step = tf.while_loop(
                continue_loop_condition,loop_body,init_loop_var)
            return trg_ids.stack()

def main():
    from stanfordcorenlp import StanfordCoreNLP
    nlp = StanfordCoreNLP("../../stanford-corenlp-full-2018-10-05",lang='en')
    with tf.variable_scope("nmt_model",reuse=None):
        model = NMTModel()
    vocab_file = "../train.tags.en-zh.en.deletehtml.vocab"
    sentence = "It doesn't belong to mine!"
    with open(vocab_file,'r') as f:
        data = f.readlines()
        words = [w.strip() for w in data]
    word_to_id = {k:v for (k,v) in zip(words,range(len(words)))}
    wordlist = nlp.word_tokenize(sentence.strip()) + ["<eos>"]
    # print(wordlist)
    idlist = [str(word_to_id[w]) if w in word_to_id else str(word_to_id["<unk>"]) for w in wordlist]
    idlist = [int(i) for i in idlist]
    # print(idlist)

    output_op = model.inference(idlist)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7,allow_growth=True)
    session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    saver = tf.train.Saver()
    saver.restore(session,CHECKPOINT_PATH)

    output = session.run(output_op)

    vocab_file2 = "../train.tags.en-zh.zh.deletehtml.vocab"
    with open(vocab_file2,'r') as f2:
        data2 = f2.readlines()
        words = [w.strip() for w in data2]
    id_to_word = {k:v for (k,v) in zip(range(len(words)),words)}
    print([id_to_word[i] for i in output])
    session.close()

    nlp.close()

if __name__ == '__main__':
    main()

The output is:

['<sos>', '这', '不', '是', '我', '的', '<unk>', '<eos>']
