『NLP Classic Projects』07: Machine Translation with seq2seq
Machine Translation with seq2seq
Machine translation is the process of using a computer to convert one natural language (the source language) into another natural language (the target language).
Here, we automatically produce a target-language translation from a source-language input. This is a classic sequence-to-sequence (seq2seq) modeling scenario, and the Encoder-Decoder framework is the canonical approach to seq2seq problems: it can map a source sequence of arbitrary length to a target sequence of arbitrary length. The encoding stage compresses the entire source sequence into a vector; the decoding stage then decodes the full target sequence by maximizing the probability of the predicted sequence. This mimics how humans perform translation: first parse the source language and understand its meaning, then write out the target-language sentence according to that meaning. For more details on the principles and mathematics of machine translation, we recommend the machine translation example on the PaddlePaddle website.
Figure 1: Encoder-Decoder diagram
Here, the Encoder is an LSTM and the Decoder is an LSTM with an attention mechanism.
Figure 2: Encoder-Decoder with attention
We train the model with source-language sentences as the Encoder's input and target-language sentences as the Decoder's input.
Running the example model in this directory requires PaddlePaddle 2.0-rc1 or later. If your installed PaddlePaddle version is lower than this, please update it following the instructions in the installation documentation.
In [1]
!pip install --upgrade paddlenlp==2.0.0b4 -i https://pypi.tuna.tsinghua.edu.cn/simple
import io
import numpy as np
from functools import partial
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.nn.initializer as I
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.metrics import Perplexity
from paddlenlp.datasets import IWSLT15
from paddlenlp.metrics import BLEU
from prepareinput import prepare_train_input, prepare_infer_input
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting paddlenlp==2.0.0b4
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9c/de/9ca615db516f438bae269b457f320ac0cbfa6c90242e80e29da8e2f5491c/paddlenlp-2.0.0b4-py3-none-any.whl (164kB)
|████████████████████████████████| 174kB 9.7MB/s eta 0:00:01
Collecting seqeval (from paddlenlp==2.0.0b4)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43kB)
|████████████████████████████████| 51kB 33.4MB/s eta 0:00:01
Building wheels for collected packages: seqeval
Building wheel for seqeval (setup.py) ... done
Created wheel for seqeval: filename=seqeval-1.2.2-cp37-none-any.whl size=16171 sha256=2d0a4be4752f7c4ec38fd16be4ce7341dd4140d9d896aec34116eadc8c2df6a8
Stored in directory: /home/aistudio/.cache/pip/wheels/53/6a/9e/7e62c49241dcbafda4e9b1efaf3899193814e726387def0033
Successfully built seqeval
Installing collected packages: seqeval, paddlenlp
Successfully installed paddlenlp-2.0.0b4 seqeval-1.2.2
Data
Dataset
This tutorial uses the English-to-Vietnamese portion of the IWSLT'15 English-Vietnamese dataset as the training corpus. IWSLT15 is a built-in dataset of paddlenlp: train.en is the source side and train.vi the target side of the training set; tst2012.en and tst2012.vi serve as the dev set, and tst2013.en and tst2013.vi as the test set. Predefined vocab.en and vocab.vi files provide the source and target vocabularies (the 50,000 most frequent tokens).
Building the DataLoader
The create_data_loader function below builds the DataLoader objects needed for training, validation, and prediction; a DataLoader yields batches of data. A brief note on the built-in paddlenlp functions it calls:
get_vocab: loads the tokens from the vocabulary files
get_default_transform_func(): the default transform that converts text into numeric ids
get_datasets: returns the train/dev/test splits of the current dataset
SamplerHelper: produces an iterable sampler that a DataLoader can consume directly
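The shuffle → sort-within-buffer → batch chain used below groups sentences of similar length so that each batch needs little padding. A minimal pure-Python sketch of that bucketing idea (toy lengths, no paddlenlp):

```python
import random

def bucketed_batches(lengths, batch_size, buffer_size, seed=0):
    """Shuffle globally, sort within a sliding buffer by length, then cut into batches."""
    rng = random.Random(seed)
    idx = list(range(len(lengths)))
    rng.shuffle(idx)                      # global shuffle for randomness
    batches = []
    for start in range(0, len(idx), buffer_size):
        # sort only inside the buffer, so batching stays roughly random overall
        buf = sorted(idx[start:start + buffer_size], key=lambda i: lengths[i])
        for b in range(0, len(buf), batch_size):
            batches.append(buf[b:b + batch_size])
    return batches

lengths = [3, 17, 4, 16, 5, 15, 6, 14]    # toy sentence lengths
batches = bucketed_batches(lengths, batch_size=2, buffer_size=8)
# each batch pairs sentences of near-equal length, so padding waste is small
spread = [max(lengths[i] for i in b) - min(lengths[i] for i in b) for b in batches]
```

With a buffer of 20 batches, as in the real code, sorting stays local enough that consecutive epochs still see differently composed batches.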
In [2]
def create_data_loader(mode):
    batch_size = 128
    max_len = 50
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id
    trans_func_tuple = IWSLT15.get_default_transform_func()
    dataset = IWSLT15.get_datasets(
        mode=[mode],
        transform_func=[trans_func_tuple])
    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])
    if mode in ["train", "dev"]:
        # Drop empty pairs and cap source/target sentences at max_len tokens
        dataset = dataset.filter(
            lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
        batch_sampler = SamplerHelper(dataset).shuffle().sort(
            key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
        data_loader = paddle.io.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            # prepare_train_input pads each sampled batch and returns the actual
            # sequence lengths; see prepareinput.py for details
            collate_fn=partial(
                prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    else:
        batch_sampler = SamplerHelper(dataset).batch(batch_size=batch_size)
        data_loader = paddle.io.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return data_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
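The prepare_train_input helper lives in prepareinput.py and is not shown here; as a rough, hypothetical illustration of what such a collate function does (pad shorter sequences with pad_id and record the true lengths), a numpy sketch:

```python
import numpy as np

def pad_batch(seqs, pad_id):
    """Toy collate: pad variable-length id sequences to the batch max length
    and record the original lengths (needed by the LSTM's sequence_length)."""
    src_len = np.array([len(s) for s in seqs], dtype="int64")
    max_len = int(src_len.max())
    src = np.full((len(seqs), max_len), pad_id, dtype="int64")
    for i, s in enumerate(seqs):
        src[i, :len(s)] = s
    return src, src_len

# pad_id equals eos_id in this tutorial
src, src_len = pad_batch([[5, 6, 7], [8, 9]], pad_id=1)
```

The real helper additionally wraps the target sentence in bos/eos tokens and returns the label mask used by the loss function.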
Model
The figure below shows the structure of the attention-based Seq2Seq model. We define each part of the network in turn, then assemble the full Seq2Seq network.
Figure 3: schematic of the Encoder-Decoder with attention
Defining the Encoder
The Encoder can directly use nn.LSTM from the RNN family of APIs provided in PaddlePaddle 2.0. Passing the tensor sequence_length (the actual lengths before padding) yields encoder_output and encoder_state.
In [3]
class Seq2SeqEncoder(nn.Layer):
    def __init__(self,
                 vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 init_scale=0.1):
        super(Seq2SeqEncoder, self).__init__()
        # Embedding layer that maps token ids to word vectors;
        # vocab_size is the vocabulary size, embed_dim the embedding dimension
        self.embedder = nn.Embedding(
            vocab_size,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)))
        # A (possibly multi-layer) forward LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            direction="forward",
            dropout=dropout_prob if num_layers > 1 else 0.)

    def forward(self, sequence, sequence_length):
        inputs = self.embedder(sequence)
        encoder_output, encoder_state = self.lstm(
            inputs, sequence_length=sequence_length)
        return encoder_output, encoder_state
Defining the Decoder
Defining Seq2SeqDecoderCell
Since the Decoder is an LSTM with attention, we cannot reuse nn.LSTM, so we define Seq2SeqDecoderCell.
In [4]
class Seq2SeqDecoderCell(nn.RNNCellBase):
    def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0.):
        super(Seq2SeqDecoderCell, self).__init__()
        if dropout_prob > 0.0:
            self.dropout = nn.Dropout(dropout_prob)
        else:
            self.dropout = None
        self.lstm_cells = nn.LayerList([
            nn.LSTMCell(
                input_size=input_size + hidden_size if i == 0 else hidden_size,
                hidden_size=hidden_size) for i in range(num_layers)
        ])
        self.attention_layer = AttentionLayer(hidden_size)

    def forward(self,
                step_input,
                states,
                encoder_output,
                encoder_padding_mask=None):
        lstm_states, input_feed = states
        new_lstm_states = []
        # Input feeding: concatenate the previous attention output with the input
        step_input = paddle.concat([step_input, input_feed], 1)
        for i, lstm_cell in enumerate(self.lstm_cells):
            out, new_lstm_state = lstm_cell(step_input, lstm_states[i])
            if self.dropout:
                step_input = self.dropout(out)
            else:
                step_input = out
            new_lstm_states.append(new_lstm_state)
        out = self.attention_layer(step_input, encoder_output,
                                   encoder_padding_mask)
        return out, [new_lstm_states, out]
Defining the AttentionLayer
The AttentionLayer can be defined as follows:
In [5]
class AttentionLayer(nn.Layer):
    def __init__(self, hidden_size, bias=False, init_scale=0.1):
        super(AttentionLayer, self).__init__()
        self.input_proj = nn.Linear(
            hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)
        self.output_proj = nn.Linear(
            hidden_size + hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)

    def forward(self, hidden, encoder_output, encoder_padding_mask):
        encoder_output = self.input_proj(encoder_output)
        attn_scores = paddle.matmul(
            paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
        if encoder_padding_mask is not None:
            attn_scores = paddle.add(attn_scores, encoder_padding_mask)
        attn_scores = F.softmax(attn_scores)
        attn_out = paddle.squeeze(
            paddle.matmul(attn_scores, encoder_output), [1])
        attn_out = paddle.concat([attn_out, hidden], 1)
        attn_out = self.output_proj(attn_out)
        return attn_out
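The masked softmax in AttentionLayer can be illustrated with a small numpy sketch (toy numbers, projections omitted): scores at padded positions receive a large negative additive mask, so their attention weights vanish after softmax.

```python
import numpy as np

def masked_attention(hidden, enc_out, pad_mask):
    """Dot-product attention over one query vector.
    hidden: (H,), enc_out: (T, H), pad_mask: (T,) additive mask, -1e9 at pads."""
    scores = enc_out @ hidden + pad_mask          # (T,) raw scores + mask
    scores = scores - scores.max()                # stabilize the softmax
    weights = np.exp(scores) / np.exp(scores).sum()
    context = weights @ enc_out                   # (H,) weighted sum of states
    return weights, context

enc_out = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
pad_mask = np.array([0.0, 0.0, -1e9])             # last time step is padding
weights, context = masked_attention(np.array([1.0, 1.0]), enc_out, pad_mask)
# the padded position receives (numerically) zero attention weight
```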
Defining the Seq2SeqDecoder
With Seq2SeqDecoderCell in place, we can build the Seq2SeqDecoder.
In [6]
class Seq2SeqDecoder(nn.Layer):
    def __init__(self,
                 vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 init_scale=0.1):
        super(Seq2SeqDecoder, self).__init__()
        self.embedder = nn.Embedding(
            vocab_size,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)))
        self.lstm_attention = nn.RNN(
            Seq2SeqDecoderCell(num_layers, embed_dim, hidden_size, dropout_prob),
            is_reverse=False,
            time_major=False)
        # Fully connected layer projecting hidden states to vocabulary logits
        self.output_layer = nn.Linear(
            hidden_size,
            vocab_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=False)

    def forward(self, trg, decoder_initial_states, encoder_output,
                encoder_padding_mask):
        inputs = self.embedder(trg)
        decoder_output, _ = self.lstm_attention(
            inputs,
            initial_states=decoder_initial_states,
            encoder_output=encoder_output,
            encoder_padding_mask=encoder_padding_mask)
        predict = self.output_layer(decoder_output)
        return predict
Assembling the main network Seq2SeqAttnModel
With the Encoder and Decoder defined, the network can be assembled:
The Encoder takes the source token ids and produces an output Tensor of shape [batch_size, src_length, hidden_size]
The Encoder's final-time-step state is used as the Decoder's initial state
The Decoder produces a Tensor of shape [batch_size, trg_length, hidden_size], which the output layer projects to vocabulary logits
In [7]
class Seq2SeqAttnModel(nn.Layer):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 eos_id=1,
                 init_scale=0.1):
        super(Seq2SeqAttnModel, self).__init__()
        self.hidden_size = hidden_size
        self.eos_id = eos_id
        self.num_layers = num_layers
        self.INF = 1e9
        self.encoder = Seq2SeqEncoder(src_vocab_size, embed_dim, hidden_size,
                                      num_layers, dropout_prob, init_scale)
        self.decoder = Seq2SeqDecoder(trg_vocab_size, embed_dim, hidden_size,
                                      num_layers, dropout_prob, init_scale)

    def forward(self, src, src_length, trg):
        encoder_output, encoder_final_state = self.encoder(src, src_length)
        # Regroup encoder_final_state into a list of per-layer (h, c) tuples
        encoder_final_states = [
            (encoder_final_state[0][i], encoder_final_state[1][i])
            for i in range(self.num_layers)
        ]
        # Construct decoder initial states with input_feed; the shape is
        # [[(h, c)] * num_layers, input_feed], matching Seq2SeqDecoderCell.states
        decoder_initial_states = [
            encoder_final_states,
            self.decoder.lstm_attention.cell.get_initial_states(
                batch_ref=encoder_output, shape=[self.hidden_size])
        ]
        # Build the attention mask to avoid attending to paddings
        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
        encoder_padding_mask = (src_mask - 1.0) * self.INF
        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
        predict = self.decoder(trg, decoder_initial_states, encoder_output,
                               encoder_padding_mask)
        return predict
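The mask arithmetic (src_mask - 1.0) * INF maps valid positions to 0 and padded positions to -INF, so adding it to the attention scores zeroes out the pads after softmax. A numpy sketch of that construction (pad_id = eos_id = 1, as in this tutorial):

```python
import numpy as np

INF = 1e9
src = np.array([[4, 7, 1, 1]])                 # one sentence: 2 real tokens + 2 pads
src_mask = (src != 1).astype("float32")        # 1.0 at real tokens, 0.0 at pads
encoder_padding_mask = (src_mask - 1.0) * INF  # 0.0 at real tokens, -1e9 at pads
```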
Defining the loss function
We use cross-entropy loss. Because shorter sequences were padded earlier, positions that are merely padding must be ignored when computing the loss over the generated sentence, so the loss function takes a trg_mask argument. Since the paddle.nn.CrossEntropyLoss provided by the PaddlePaddle framework cannot accept trg_mask, we define our own:
In [8]
class CrossEntropyCriterion(nn.Layer):
    def __init__(self):
        super(CrossEntropyCriterion, self).__init__()

    def forward(self, predict, label, trg_mask):
        cost = F.softmax_with_cross_entropy(
            logits=predict, label=label, soft_label=False)
        cost = paddle.squeeze(cost, axis=[2])
        masked_cost = cost * trg_mask
        batch_mean_cost = paddle.mean(masked_cost, axis=[0])
        seq_cost = paddle.sum(batch_mean_cost)
        return seq_cost
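The effect of trg_mask can be checked with a small numpy sketch: per-token negative log-likelihoods are zeroed at padded positions, then averaged over the batch and summed over time, mirroring the criterion above.

```python
import numpy as np

def masked_ce(logits, labels, mask):
    """Mean-over-batch, sum-over-time masked cross entropy.
    logits: (B, T, V), labels: (B, T), mask: (B, T) of 0/1."""
    logits = logits - logits.max(axis=-1, keepdims=True)      # stable log-softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # negative log-likelihood of the gold label at every position
    nll = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return (nll * mask).mean(axis=0).sum()

logits = np.zeros((1, 3, 4))          # uniform distribution over 4 tokens
labels = np.array([[2, 1, 0]])
mask = np.array([[1.0, 1.0, 0.0]])    # last position is padding and is excluded
loss = masked_ce(logits, labels, mask)
```

With the uniform toy distribution each counted token costs log 4, so the loss is 2·log 4; without the mask it would be 3·log 4.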
Running the model
Training
Training with the high-level API requires calling prepare and fit.
In prepare we configure the optimizer, the loss function, and the evaluation metric. The metric is the perplexity API provided by PaddleNLP, paddlenlp.metrics.Perplexity.
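Perplexity is the exponential of the average per-token cross entropy. A quick sketch of that relationship (not the PaddleNLP implementation):

```python
import math

def perplexity(token_nlls):
    """exp of the mean negative log-likelihood per token."""
    return math.exp(sum(token_nlls) / len(token_nlls))

# A model that assigns every token probability 1/4 has perplexity 4
nlls = [math.log(4)] * 10
ppl = perplexity(nlls)   # ~4.0
```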
In [9]
batch_size = 128
num_layers = 2
dropout = 0.2
init_scale = 0.1
hidden_size = 512
max_grad_norm = 5.0
learning_rate = 0.001
max_epoch = 12
max_len = 50
model_path = './attention_models'
infer_output_file = './output_files'
beam_size = 10
log_freq = 100
device = "gpu"

# Define dataloaders
train_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="train")
eval_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="dev")

# embed_dim is set equal to hidden_size here
model = paddle.Model(
    Seq2SeqAttnModel(src_vocab_size, tgt_vocab_size, hidden_size, hidden_size,
                     num_layers, dropout, eos_id))

grad_clip = nn.ClipGradByGlobalNorm(max_grad_norm)
optimizer = paddle.optimizer.Adam(
    learning_rate=learning_rate,
    parameters=model.parameters(),
    grad_clip=grad_clip)

ppl_metric = Perplexity()
model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
100%|██████████| 9967/9967 [00:00<00:00, 40938.29it/s]
If you have VisualDL installed, you can add a callbacks argument to fit to monitor the training process with VisualDL, as follows:
In [10]
model.fit(train_data=train_loader,
          eval_data=eval_loader,
          epochs=max_epoch,
          eval_freq=1,
          save_freq=1,
          save_dir=model_path,
          log_freq=log_freq,
          callbacks=[paddle.callbacks.VisualDL('./log')])
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/12
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
step 100/1041 - loss: 316.9290 - Perplexity: 569.7516 - 157ms/step
step 200/1041 - loss: 283.0392 - Perplexity: 377.6471 - 154ms/step
step 300/1041 - loss: 260.9322 - Perplexity: 267.6343 - 154ms/step
step 400/1041 - loss: 235.1697 - Perplexity: 201.8174 - 154ms/step
step 500/1041 - loss: 229.3707 - Perplexity: 159.4850 - 153ms/step
step 600/1041 - loss: 203.4672 - Perplexity: 129.9806 - 153ms/step
step 700/1041 - loss: 201.8655 - Perplexity: 108.9772 - 153ms/step
step 800/1041 - loss: 196.9280 - Perplexity: 93.9172 - 154ms/step
step 900/1041 - loss: 186.4891 - Perplexity: 82.5902 - 154ms/step
step 1000/1041 - loss: 177.8873 - Perplexity: 73.5232 - 154ms/step
step 1041/1041 - loss: 80.8191 - Perplexity: 70.4385 - 154ms/step
save checkpoint at /home/aistudio/attention_models/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 172.1245 - Perplexity: 24.3200 - 75ms/step
Eval samples: 1553
Epoch 2/12
step 100/1041 - loss: 170.8153 - Perplexity: 20.7067 - 154ms/step
step 200/1041 - loss: 169.8316 - Perplexity: 20.1637 - 155ms/step
step 300/1041 - loss: 164.2285 - Perplexity: 19.6250 - 154ms/step
step 400/1041 - loss: 166.2451 - Perplexity: 19.1300 - 154ms/step
step 500/1041 - loss: 154.8416 - Perplexity: 18.6932 - 155ms/step
step 600/1041 - loss: 161.9329 - Perplexity: 18.2686 - 154ms/step
step 700/1041 - loss: 158.8072 - Perplexity: 17.9527 - 154ms/step
step 800/1041 - loss: 151.1718 - Perplexity: 17.5832 - 154ms/step
step 900/1041 - loss: 150.8292 - Perplexity: 17.2788 - 154ms/step
step 1000/1041 - loss: 151.9545 - Perplexity: 16.9614 - 154ms/step
step 1041/1041 - loss: 64.2150 - Perplexity: 16.8593 - 154ms/step
save checkpoint at /home/aistudio/attention_models/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 163.4615 - Perplexity: 16.4078 - 75ms/step
Eval samples: 1553
Epoch 3/12
step 100/1041 - loss: 144.9384 - Perplexity: 11.9777 - 154ms/step
step 200/1041 - loss: 147.2327 - Perplexity: 11.9371 - 153ms/step
step 300/1041 - loss: 139.2942 - Perplexity: 11.9587 - 154ms/step
step 400/1041 - loss: 140.5380 - Perplexity: 11.9496 - 154ms/step
step 500/1041 - loss: 140.6450 - Perplexity: 11.9257 - 154ms/step
step 600/1041 - loss: 140.4239 - Perplexity: 11.8719 - 153ms/step
step 700/1041 - loss: 141.9416 - Perplexity: 11.8256 - 153ms/step
step 800/1041 - loss: 142.7165 - Perplexity: 11.7741 - 153ms/step
step 900/1041 - loss: 136.2898 - Perplexity: 11.7381 - 153ms/step
step 1000/1041 - loss: 139.1634 - Perplexity: 11.6733 - 153ms/step
step 1041/1041 - loss: 56.9410 - Perplexity: 11.6393 - 153ms/step
save checkpoint at /home/aistudio/attention_models/2
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 134.3902 - Perplexity: 13.4229 - 75ms/step
Eval samples: 1553
Epoch 4/12
step 100/1041 - loss: 133.1941 - Perplexity: 9.3546 - 154ms/step
step 200/1041 - loss: 127.7615 - Perplexity: 9.4610 - 154ms/step
step 300/1041 - loss: 128.3531 - Perplexity: 9.4909 - 154ms/step
step 400/1041 - loss: 131.7594 - Perplexity: 9.5061 - 154ms/step
step 500/1041 - loss: 135.5802 - Perplexity: 9.5315 - 153ms/step
step 600/1041 - loss: 132.5778 - Perplexity: 9.5537 - 153ms/step
step 700/1041 - loss: 127.7056 - Perplexity: 9.5584 - 153ms/step
step 800/1041 - loss: 129.8652 - Perplexity: 9.5510 - 153ms/step
step 900/1041 - loss: 126.0126 - Perplexity: 9.5457 - 153ms/step
step 1000/1041 - loss: 129.1563 - Perplexity: 9.5513 - 153ms/step
step 1041/1041 - loss: 43.0635 - Perplexity: 9.5494 - 153ms/step
save checkpoint at /home/aistudio/attention_models/3
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 138.2625 - Perplexity: 12.5110 - 75ms/step
Eval samples: 1553
Epoch 5/12
step 100/1041 - loss: 119.1514 - Perplexity: 7.9890 - 156ms/step
step 200/1041 - loss: 126.6494 - Perplexity: 8.0800 - 155ms/step
step 300/1041 - loss: 122.2061 - Perplexity: 8.1121 - 154ms/step
step 400/1041 - loss: 117.8465 - Perplexity: 8.1787 - 155ms/step
step 500/1041 - loss: 120.3643 - Perplexity: 8.2348 - 154ms/step
step 600/1041 - loss: 126.8760 - Perplexity: 8.2641 - 154ms/step
step 700/1041 - loss: 125.7731 - Perplexity: 8.3052 - 154ms/step
step 800/1041 - loss: 122.7028 - Perplexity: 8.3227 - 154ms/step
step 900/1041 - loss: 129.4286 - Perplexity: 8.3327 - 154ms/step
step 1000/1041 - loss: 125.5585 - Perplexity: 8.3499 - 153ms/step
step 1041/1041 - loss: 48.3358 - Perplexity: 8.3549 - 153ms/step
save checkpoint at /home/aistudio/attention_models/4
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 147.6611 - Perplexity: 12.0037 - 75ms/step
Eval samples: 1553
Epoch 6/12
step 100/1041 - loss: 120.7922 - Perplexity: 7.2646 - 155ms/step
step 200/1041 - loss: 118.1561 - Perplexity: 7.2689 - 154ms/step
step 300/1041 - loss: 120.9546 - Perplexity: 7.3105 - 153ms/step
step 400/1041 - loss: 118.6025 - Perplexity: 7.3526 - 153ms/step
step 500/1041 - loss: 117.7779 - Perplexity: 7.3928 - 153ms/step
step 600/1041 - loss: 122.9525 - Perplexity: 7.4319 - 153ms/step
step 700/1041 - loss: 117.7925 - Perplexity: 7.4670 - 153ms/step
step 800/1041 - loss: 126.4061 - Perplexity: 7.4887 - 153ms/step
step 900/1041 - loss: 118.9017 - Perplexity: 7.5278 - 153ms/step
step 1000/1041 - loss: 122.6585 - Perplexity: 7.5620 - 154ms/step
step 1041/1041 - loss: 57.3045 - Perplexity: 7.5724 - 154ms/step
save checkpoint at /home/aistudio/attention_models/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 141.4744 - Perplexity: 11.7521 - 75ms/step
Eval samples: 1553
Epoch 7/12
step 100/1041 - loss: 115.3885 - Perplexity: 6.5775 - 154ms/step
step 200/1041 - loss: 115.0385 - Perplexity: 6.6283 - 153ms/step
step 300/1041 - loss: 116.8085 - Perplexity: 6.6999 - 153ms/step
step 400/1041 - loss: 109.3981 - Perplexity: 6.7474 - 153ms/step
step 500/1041 - loss: 118.9770 - Perplexity: 6.7980 - 153ms/step
step 600/1041 - loss: 119.6290 - Perplexity: 6.8498 - 153ms/step
step 700/1041 - loss: 114.9620 - Perplexity: 6.8818 - 153ms/step
step 800/1041 - loss: 114.2539 - Perplexity: 6.9167 - 153ms/step
step 900/1041 - loss: 119.0262 - Perplexity: 6.9484 - 154ms/step
step 1000/1041 - loss: 114.2731 - Perplexity: 6.9769 - 154ms/step
step 1041/1041 - loss: 45.2423 - Perplexity: 6.9787 - 154ms/step
save checkpoint at /home/aistudio/attention_models/6
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 141.1222 - Perplexity: 11.5050 - 75ms/step
Eval samples: 1553
Epoch 8/12
step 100/1041 - loss: 106.5752 - Perplexity: 6.0088 - 155ms/step
step 200/1041 - loss: 112.7031 - Perplexity: 6.1096 - 154ms/step
step 300/1041 - loss: 112.7603 - Perplexity: 6.1900 - 154ms/step
step 400/1041 - loss: 112.4972 - Perplexity: 6.2464 - 154ms/step
step 500/1041 - loss: 110.8458 - Perplexity: 6.2897 - 154ms/step
step 600/1041 - loss: 118.5520 - Perplexity: 6.3431 - 154ms/step
step 700/1041 - loss: 116.1271 - Perplexity: 6.3832 - 154ms/step
step 800/1041 - loss: 115.2826 - Perplexity: 6.4184 - 154ms/step
step 900/1041 - loss: 115.0764 - Perplexity: 6.4395 - 154ms/step
step 1000/1041 - loss: 118.9429 - Perplexity: 6.4657 - 154ms/step
step 1041/1041 - loss: 48.1713 - Perplexity: 6.4823 - 154ms/step
save checkpoint at /home/aistudio/attention_models/7
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 132.2597 - Perplexity: 11.5465 - 76ms/step
Eval samples: 1553
Epoch 9/12
step 100/1041 - loss: 106.0669 - Perplexity: 5.6741 - 156ms/step
step 200/1041 - loss: 108.2129 - Perplexity: 5.7532 - 154ms/step
step 300/1041 - loss: 112.7348 - Perplexity: 5.8025 - 154ms/step
step 400/1041 - loss: 108.3239 - Perplexity: 5.8568 - 154ms/step
step 500/1041 - loss: 115.5850 - Perplexity: 5.9041 - 154ms/step
step 600/1041 - loss: 108.4180 - Perplexity: 5.9378 - 154ms/step
step 700/1041 - loss: 112.5859 - Perplexity: 5.9876 - 153ms/step
step 800/1041 - loss: 113.8619 - Perplexity: 6.0267 - 154ms/step
step 900/1041 - loss: 115.6805 - Perplexity: 6.0513 - 154ms/step
step 1000/1041 - loss: 118.7408 - Perplexity: 6.0714 - 154ms/step
step 1041/1041 - loss: 45.1587 - Perplexity: 6.0828 - 154ms/step
save checkpoint at /home/aistudio/attention_models/8
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 148.6588 - Perplexity: 11.5041 - 75ms/step
Eval samples: 1553
Epoch 10/12
step 100/1041 - loss: 107.6575 - Perplexity: 5.3731 - 153ms/step
step 200/1041 - loss: 105.6288 - Perplexity: 5.4270 - 153ms/step
step 300/1041 - loss: 109.6870 - Perplexity: 5.4661 - 154ms/step
step 400/1041 - loss: 106.5136 - Perplexity: 5.5160 - 153ms/step
step 500/1041 - loss: 113.4923 - Perplexity: 5.5568 - 153ms/step
step 600/1041 - loss: 108.5855 - Perplexity: 5.5904 - 153ms/step
step 700/1041 - loss: 107.1722 - Perplexity: 5.6375 - 153ms/step
step 800/1041 - loss: 115.3743 - Perplexity: 5.6674 - 153ms/step
step 900/1041 - loss: 114.0770 - Perplexity: 5.7021 - 153ms/step
step 1000/1041 - loss: 111.8619 - Perplexity: 5.7246 - 153ms/step
step 1041/1041 - loss: 47.1277 - Perplexity: 5.7353 - 154ms/step
save checkpoint at /home/aistudio/attention_models/9
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 142.5951 - Perplexity: 11.4914 - 76ms/step
Eval samples: 1553
Epoch 11/12
step 100/1041 - loss: 103.6631 - Perplexity: 5.0292 - 153ms/step
step 200/1041 - loss: 97.6034 - Perplexity: 5.0740 - 154ms/step
step 300/1041 - loss: 109.3245 - Perplexity: 5.1391 - 153ms/step
step 400/1041 - loss: 106.0880 - Perplexity: 5.1917 - 154ms/step
step 500/1041 - loss: 105.9739 - Perplexity: 5.2453 - 154ms/step
step 600/1041 - loss: 108.3997 - Perplexity: 5.2823 - 154ms/step
step 700/1041 - loss: 108.6982 - Perplexity: 5.3275 - 154ms/step
step 800/1041 - loss: 106.0240 - Perplexity: 5.3552 - 154ms/step
step 900/1041 - loss: 102.1773 - Perplexity: 5.3887 - 154ms/step
step 1000/1041 - loss: 111.5466 - Perplexity: 5.4194 - 154ms/step
step 1041/1041 - loss: 36.5989 - Perplexity: 5.4327 - 154ms/step
save checkpoint at /home/aistudio/attention_models/10
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 147.4474 - Perplexity: 11.9804 - 75ms/step
Eval samples: 1553
Epoch 12/12
step 100/1041 - loss: 98.2783 - Perplexity: 4.7699 - 154ms/step
step 200/1041 - loss: 102.6894 - Perplexity: 4.8731 - 153ms/step
step 300/1041 - loss: 105.4697 - Perplexity: 4.9088 - 153ms/step
step 400/1041 - loss: 97.5007 - Perplexity: 4.9465 - 153ms/step
step 500/1041 - loss: 108.4352 - Perplexity: 4.9925 - 154ms/step
step 600/1041 - loss: 102.4900 - Perplexity: 5.0358 - 153ms/step
step 700/1041 - loss: 102.6786 - Perplexity: 5.0682 - 153ms/step
step 800/1041 - loss: 104.1088 - Perplexity: 5.0986 - 153ms/step
step 900/1041 - loss: 107.4128 - Perplexity: 5.1334 - 153ms/step
step 1000/1041 - loss: 105.3015 - Perplexity: 5.1618 - 154ms/step
step 1041/1041 - loss: 37.2941 - Perplexity: 5.1706 - 153ms/step
save checkpoint at /home/aistudio/attention_models/11
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 134.3483 - Perplexity: 11.9075 - 76ms/step
Eval samples: 1553
save checkpoint at /home/aistudio/attention_models/final
We can call the state_dict function to retrieve the network parameters.
In [11]
state_dict = model.network.state_dict()
Model prediction
Defining the inference network Seq2SeqAttnInferModel
The inference network inherits from the main network Seq2SeqAttnModel above; we define the subclass Seq2SeqAttnInferModel.
In [12]
class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 bos_id=0,
                 eos_id=1,
                 beam_size=4,
                 max_out_len=256):
        self.bos_id = bos_id
        self.beam_size = beam_size
        self.max_out_len = max_out_len
        self.num_layers = num_layers
        super(Seq2SeqAttnInferModel, self).__init__(
            src_vocab_size, trg_vocab_size, embed_dim, hidden_size, num_layers)

        # Dynamic decoder for inference
        self.beam_search_decoder = nn.BeamSearchDecoder(
            self.decoder.lstm_attention.cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=self.decoder.embedder,
            output_fn=self.decoder.output_layer)

    def forward(self, src, src_length):
        encoder_output, encoder_final_state = self.encoder(src, src_length)
        encoder_final_state = [
            (encoder_final_state[0][i], encoder_final_state[1][i])
            for i in range(self.num_layers)
        ]
        # Initial decoder states
        decoder_initial_states = [
            encoder_final_state,
            self.decoder.lstm_attention.cell.get_initial_states(
                batch_ref=encoder_output, shape=[self.hidden_size])
        ]
        # Build an attention mask to avoid paying attention to paddings
        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
        encoder_padding_mask = (src_mask - 1.0) * self.INF
        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
        # Tile the batch dimension by beam_size
        encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output, self.beam_size)
        encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_padding_mask, self.beam_size)
        # Dynamic decoding with beam search
        seq_output, _ = nn.dynamic_decode(
            decoder=self.beam_search_decoder,
            inits=decoder_initial_states,
            max_step_num=self.max_out_len,
            encoder_output=encoder_output,
            encoder_padding_mask=encoder_padding_mask)
        return seq_output
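The masking and beam tiling in forward can be illustrated with plain numpy. This is a standalone sketch: the toy array values and the BEAM_SIZE, EOS_ID, and INF constants are made up for illustration, and np.repeat is assumed to match tile_beam_merge_with_batch's contiguous batch repetition.

```python
import numpy as np

BEAM_SIZE = 3   # illustration only; this project uses beam_size=10
EOS_ID = 1
INF = 1e9       # stand-in for the model's self.INF constant

# Toy batch of two source sequences padded with the eos id.
src = np.array([[4, 5, 6, 1],
                [7, 8, 1, 1]])

# 0 at real tokens, -INF at padding, so softmax attention ignores padding.
src_mask = (src != EOS_ID).astype("float32")
encoder_padding_mask = (src_mask - 1.0) * INF            # [batch, src_len]
encoder_padding_mask = encoder_padding_mask[:, None, :]  # [batch, 1, src_len]

# tile_beam_merge_with_batch repeats each batch entry beam_size times,
# e.g. [b0, b1] -> [b0, b0, b0, b1, b1, b1]; np.repeat does the same.
tiled = np.repeat(encoder_padding_mask, BEAM_SIZE, axis=0)
print(tiled.shape)  # -> (6, 1, 4)
```

After tiling, every beam hypothesis of a given sample sees the same encoder output and the same padding mask, which is exactly what the dynamic decoder expects.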
Decoding
Next, we use beam search to decode for this task; this project sets beam_size to 10.
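Beam search itself can be sketched in a few lines of plain Python. This toy is independent of the model above: step_fn and the next-token distribution are made up, and production decoders (such as Paddle's BeamSearchDecoder) typically add refinements like length normalization that this sketch omits.

```python
import math

def beam_search(step_fn, bos_id, eos_id, beam_size, max_len):
    """Generic beam search: step_fn(prefix) -> {token_id: probability}."""
    beams = [([bos_id], 0.0)]  # (token sequence, accumulated log-probability)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_id:          # finished hypotheses carry over
                candidates.append((seq, score))
                continue
            for tok, p in step_fn(seq).items():
                candidates.append((seq + [tok], score + math.log(p)))
        # keep only the beam_size highest-scoring hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(seq[-1] == eos_id for seq, _ in beams):
            break
    return beams

# Toy next-token distribution over ids {1: eos, 2, 3}; made up for illustration.
toy_step = lambda prefix: {2: 0.6, 3: 0.1, 1: 0.3}
best_seq, best_score = beam_search(toy_step, bos_id=0, eos_id=1,
                                   beam_size=10, max_len=4)[0]
print(best_seq)  # -> [0, 1]
```

At each step the beam keeps the beam_size best partial hypotheses by accumulated log-probability instead of greedily committing to one token, which is why it usually finds better translations than greedy decoding.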
In [13]
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence: truncate at the first eos token and
    optionally strip the bos/eos markers.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
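For instance, the function above keeps everything up to the first eos and drops the bos/eos markers themselves by default. The ids in this demo are made up; the function body is copied from the cell above so the snippet runs on its own.

```python
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """Copy of the function above: truncate at the first eos, then
    optionally strip the bos/eos markers."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    return [idx for idx in seq[:eos_pos + 1]
            if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)]

# bos=0, eos=1: everything after the first eos is discarded,
# and the markers themselves are stripped by default.
print(post_process_seq([0, 5, 7, 1, 9, 1], bos_idx=0, eos_idx=1))   # -> [5, 7]
print(post_process_seq([0, 5, 7, 1, 9, 1], 0, 1, output_eos=True))  # -> [5, 7, 1]
```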
In [14]
test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="test")
_, vocab = IWSLT15.get_vocab()
trg_idx2word = vocab.idx_to_token

model = paddle.Model(
    Seq2SeqAttnInferModel(
        src_vocab_size,
        tgt_vocab_size,
        hidden_size,
        hidden_size,
        num_layers,
        dropout,
        bos_id=bos_id,
        eos_id=eos_id,
        beam_size=beam_size,
        max_out_len=256))

model.prepare()
We can load the model parameters state_dict saved above into the new inference network with the set_state_dict function.
In [15]
model.network.set_state_dict(state_dict)
In [16]
cand_list = []
with io.open(infer_output_file, 'w', encoding='utf-8') as f:
    for data in test_loader():
        with paddle.no_grad():
            finished_seq = model.predict_batch(inputs=data)[0]
        finished_seq = finished_seq[:, :, np.newaxis] if len(
            finished_seq.shape) == 2 else finished_seq
        finished_seq = np.transpose(finished_seq, [0, 2, 1])
        for ins in finished_seq:
            for beam_idx, beam in enumerate(ins):
                id_list = post_process_seq(beam, bos_id, eos_id)
                word_list = [trg_idx2word[id] for id in id_list]
                sequence = " ".join(word_list) + "\n"
                f.write(sequence)
                cand_list.append(word_list)
                # Keep only the top-scoring beam for each sample
                break

test_ds = IWSLT15.get_datasets(["test"])
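The shape handling in the loop above can be checked with plain numpy: the decoder output is assumed to come out as [batch, step, beam], and the transpose turns it into [batch, beam, step] so each row reads off as one token sequence. The token ids here are made up.

```python
import numpy as np

# Toy decoder output: batch of 2 samples, 4 decoding steps, beam width 3.
finished_seq = np.arange(24).reshape(2, 4, 3)

# A rank-2 result (no beam axis) would first get one appended:
if finished_seq.ndim == 2:
    finished_seq = finished_seq[:, :, np.newaxis]

# [batch, step, beam] -> [batch, beam, step]: each row is one hypothesis.
finished_seq = np.transpose(finished_seq, [0, 2, 1])
print(finished_seq.shape)  # -> (2, 3, 4)
print(finished_seq[0, 0])  # top-beam token ids for the first sample
```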
After generating the translations, call paddlenlp.metrics.BLEU to compute the BLEU score of the results.
In [17]
bleu = BLEU()
for i, data in enumerate(test_ds):
    ref = data[1].split()
    bleu.add_inst(cand_list[i], [ref])
print("BLEU score is %s." % bleu.score())
BLEU score is 0.24077765548678734.
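The ingredients of the BLEU score can be sketched in pure Python. This is not paddlenlp's implementation, just a simplified single-sentence illustration of modified (clipped) n-gram precisions combined with a brevity penalty; the example sentences are made up, and the reference-length rule is simplified to the shortest reference.

```python
import math
from collections import Counter

def bleu(candidate, references, max_n=4):
    """Simplified BLEU for one sentence: geometric mean of modified
    n-gram precisions, scaled by a brevity penalty."""
    precisions = []
    for n in range(1, max_n + 1):
        cand_ngrams = Counter(tuple(candidate[i:i + n])
                              for i in range(len(candidate) - n + 1))
        max_ref = Counter()
        for ref in references:
            ref_ngrams = Counter(tuple(ref[i:i + n])
                                 for i in range(len(ref) - n + 1))
            for g, c in ref_ngrams.items():
                max_ref[g] = max(max_ref[g], c)
        # clip each candidate n-gram count by its maximum reference count
        overlap = sum(min(c, max_ref[g]) for g, c in cand_ngrams.items())
        precisions.append(overlap / max(sum(cand_ngrams.values()), 1))
    if min(precisions) == 0:
        return 0.0
    ref_len = min(len(r) for r in references)  # simplified length choice
    bp = 1.0 if len(candidate) > ref_len else math.exp(1 - ref_len / len(candidate))
    return bp * math.exp(sum(math.log(p) for p in precisions) / max_n)

cand = "the cat sat on the mat".split()
ref = "the cat is on the mat".split()
print(round(bleu(cand, [ref], max_n=2), 4))  # -> 0.7071
```

Real corpus BLEU (as computed by paddlenlp.metrics.BLEU over add_inst calls) aggregates n-gram counts over all sentence pairs before dividing, so it is not simply an average of per-sentence scores.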
More PaddleNLP tutorials
Sentence-level sentiment analysis with the seq2vec module
Improving sentiment analysis with the pretrained model ERNIE
Sentiment analysis on a custom dataset
Improving sentiment analysis with pretrained word embeddings
Express-waybill information extraction with a BiGRU-CRF model
Improving waybill information extraction with the pretrained model ERNIE
Intelligent poetry generation with the pretrained model ERNIE-GEN
COVID-19 case-count forecasting with a TCN network
Reading comprehension with pretrained models
Couplet generation with a seq2seq model built on PaddleNLP