Machine Translation with seq2seq
Machine translation is the process of using a computer to convert one natural language (the source language) into another natural language (the target language).

Here, given a source-language input, we automatically produce a target-language translation. This is a classic sequence-to-sequence (seq2seq) modeling scenario, and the Encoder-Decoder framework is the standard approach to it: it can map a source sequence of arbitrary length to a target sequence of arbitrary length. The encoding stage compresses the entire source sequence into a vector; the decoding stage generates the target sequence by maximizing the probability of the predicted sequence. This mimics how a human translates: first parse the source sentence and understand its meaning, then write the target-language sentence from that meaning. For more on the theory and mathematical formulation of machine translation, see the machine translation example on the PaddlePaddle website.



Figure 1: encoder-decoder schematic

Here the Encoder is an LSTM, and the Decoder is an LSTM with an attention mechanism.



Figure 2: encoder-decoder with attention mechanism

We train the model with the source-language sentence as the Encoder's input and the target-language sentence as the Decoder's input.

Running the example model in this directory requires PaddlePaddle 2.0-rc1 or later. If your installed PaddlePaddle version is lower than this, please update it following the instructions in the installation documentation.

In [1]
!pip install --upgrade paddlenlp==2.0.0b4 -i https://pypi.tuna.tsinghua.edu.cn/simple

import io
import numpy as np
from functools import partial

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.nn.initializer as I
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.metrics import Perplexity
from paddlenlp.datasets import IWSLT15
from paddlenlp.metrics import BLEU

from prepareinput import prepare_train_input, prepare_infer_input
Successfully installed paddlenlp-2.0.0b4 seqeval-1.2.2
Data
Dataset
This tutorial uses the English-to-Vietnamese portion of the IWSLT'15 English-Vietnamese dataset as the training corpus. IWSLT15 is a built-in dataset of paddlenlp: train.en is the source side and train.vi the target side of the training set; tst2012.en and tst2012.vi form the dev set; tst2013.en and tst2013.vi form the test set. Predefined vocab.en and vocab.vi files serve as the source and target vocabularies (the 50,000 most frequent words).
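The vocabulary files ship with the dataset, but to illustrate what a frequency-cutoff vocabulary like vocab.en looks like, here is a plain-Python sketch. The build_vocab helper and its special tokens are illustrative assumptions, not the dataset's actual preprocessing:

```python
from collections import Counter

def build_vocab(sentences, max_size=50000, specials=("<unk>", "<s>", "</s>")):
    """Count token frequencies and keep the most frequent words,
    roughly mirroring how files like vocab.en are produced.
    (Hypothetical helper for illustration only.)"""
    counter = Counter(tok for sent in sentences for tok in sent.split())
    # Special tokens come first so their ids stay stable.
    words = list(specials) + [w for w, _ in counter.most_common(max_size)]
    return {w: i for i, w in enumerate(words)}

corpus = ["the cat sat", "the dog sat on the mat"]
vocab = build_vocab(corpus, max_size=3)
# The 3 specials get ids 0-2; "the" (the most frequent word) gets id 3.
```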

Building the DataLoader
The create_data_loader function below creates the DataLoader objects needed for training, validation, and prediction; a DataLoader yields data batch by batch. Brief notes on the paddlenlp built-ins it calls:

get_vocab: loads the words from the vocabulary files
get_default_transform_func(): the default transformation from text to numeric ids
get_datasets: returns the train, dev, and test splits of the current dataset
SamplerHelper: produces an iterable sampler for direct use by a DataLoader
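prepare_train_input lives in the external prepareinput.py, so as a rough sketch of what such a collate function typically does (the function below is a hypothetical stand-in, not the actual implementation), it pads each sampled batch to its longest sequence and records the true lengths:

```python
import numpy as np

def pad_batch(seqs, pad_id):
    """Pad variable-length id sequences to the batch maximum and
    record each sequence's true length -- a sketch of the padding
    step a collate function like prepare_train_input performs."""
    lengths = np.array([len(s) for s in seqs], dtype="int64")
    padded = np.full((len(seqs), lengths.max()), pad_id, dtype="int64")
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
    return padded, lengths

batch, lengths = pad_batch([[4, 5, 6], [7, 8]], pad_id=1)
# batch -> [[4, 5, 6], [7, 8, 1]], lengths -> [3, 2]
```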
In [2]
def create_data_loader(mode):
    batch_size = 128
    max_len = 50
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    trans_func_tuple = IWSLT15.get_default_transform_func()
    dataset = IWSLT15.get_datasets(
        mode=[mode],
        transform_func=[trans_func_tuple])

    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    if mode in ["train", "dev"]:
        # Drop empty sentence pairs and truncate source/target
        # sentences in the train/dev sets to max_len
        dataset = dataset.filter(
            lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
        batch_sampler = SamplerHelper(dataset).shuffle().sort(
            key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
        data_loader = paddle.io.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            # prepare_train_input pads each sampled batch and also returns
            # the true sequence lengths; see prepareinput.py for details
            collate_fn=partial(
                prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    else:
        batch_sampler = SamplerHelper(dataset).batch(batch_size=batch_size)
        data_loader = paddle.io.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    return data_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
Model
The figure below shows the structure of the seq2seq model with attention. We define each part of the network separately and then assemble the full Seq2Seq network.



Figure 3: schematic of how an encoder-decoder with attention works

Defining the Encoder
The Encoder can directly use nn.LSTM from the RNN family of APIs in PaddlePaddle 2.0. Given the pre-padding actual lengths in the tensor sequence_length, it returns encoder_output and encoder_state.

In [3]
class Seq2SeqEncoder(nn.Layer):
    def __init__(self,
                 vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 init_scale=0.1):
        super(Seq2SeqEncoder, self).__init__()
        # Embedding layer projecting word ids to vectors; vocab_size is the
        # vocabulary size and embed_dim the embedding dimension
        self.embedder = nn.Embedding(
            vocab_size,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)))
        # define the LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            direction="forward",
            dropout=dropout_prob if num_layers > 1 else 0.)

    def forward(self, sequence, sequence_length):
        inputs = self.embedder(sequence)
        encoder_output, encoder_state = self.lstm(
            inputs, sequence_length=sequence_length)

        return encoder_output, encoder_state
Defining the Decoder
Defining Seq2SeqDecoderCell
Since the Decoder is an LSTM with attention, we cannot reuse nn.LSTM directly, so we define a Seq2SeqDecoderCell.

In [4]
class Seq2SeqDecoderCell(nn.RNNCellBase):
    def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0.):
        super(Seq2SeqDecoderCell, self).__init__()
        if dropout_prob > 0.0:
            self.dropout = nn.Dropout(dropout_prob)
        else:
            self.dropout = None

        self.lstm_cells = nn.LayerList([
            nn.LSTMCell(
                input_size=input_size + hidden_size if i == 0 else hidden_size,
                hidden_size=hidden_size) for i in range(num_layers)
        ])

        self.attention_layer = AttentionLayer(hidden_size)

    def forward(self,
                step_input,
                states,
                encoder_output,
                encoder_padding_mask=None):
        lstm_states, input_feed = states
        new_lstm_states = []
        step_input = paddle.concat([step_input, input_feed], 1)
        for i, lstm_cell in enumerate(self.lstm_cells):
            out, new_lstm_state = lstm_cell(step_input, lstm_states[i])
            if self.dropout:
                step_input = self.dropout(out)
            else:
                step_input = out

            new_lstm_states.append(new_lstm_state)
        out = self.attention_layer(step_input, encoder_output,
                                   encoder_padding_mask)
        return out, [new_lstm_states, out]
Defining AttentionLayer
The AttentionLayer it uses can be defined as follows:

In [5]
class AttentionLayer(nn.Layer):
    def __init__(self, hidden_size, bias=False, init_scale=0.1):
        super(AttentionLayer, self).__init__()
        self.input_proj = nn.Linear(
            hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)
        self.output_proj = nn.Linear(
            hidden_size + hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)

    def forward(self, hidden, encoder_output, encoder_padding_mask):
        encoder_output = self.input_proj(encoder_output)
        attn_scores = paddle.matmul(
            paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)

        if encoder_padding_mask is not None:
            attn_scores = paddle.add(attn_scores, encoder_padding_mask)

        attn_scores = F.softmax(attn_scores)
        attn_out = paddle.squeeze(
            paddle.matmul(attn_scores, encoder_output), [1])
        attn_out = paddle.concat([attn_out, hidden], 1)
        attn_out = self.output_proj(attn_out)
        return attn_out
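To make the score computation above concrete, here is a NumPy sketch of the same dot-product (Luong-style) attention for a single decoder step; the two Linear projections are omitted for brevity:

```python
import numpy as np

def luong_attention(hidden, encoder_output, padding_mask=None):
    """Dot-product attention for one decoder step: scores between the
    decoder state and encoder states, optional additive mask, softmax,
    then a weighted context vector (projections omitted)."""
    # hidden: [hidden_size], encoder_output: [src_len, hidden_size]
    scores = encoder_output @ hidden            # [src_len]
    if padding_mask is not None:
        scores = scores + padding_mask          # large negative at pads
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()           # softmax
    context = weights @ encoder_output          # [hidden_size]
    return weights, context

enc = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
h = np.array([1.0, 0.0])
mask = np.array([0.0, 0.0, -1e9])               # mask the 3rd (pad) step
w, ctx = luong_attention(h, enc, mask)
# The masked position receives ~0 attention weight.
```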
Defining Seq2SeqDecoder
With Seq2SeqDecoderCell in place, we can build the Seq2SeqDecoder:

In [6]
class Seq2SeqDecoder(nn.Layer):
    def __init__(self,
                 vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 init_scale=0.1):
        super(Seq2SeqDecoder, self).__init__()
        self.embedder = nn.Embedding(
            vocab_size,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)))
        self.lstm_attention = nn.RNN(Seq2SeqDecoderCell(
            num_layers, embed_dim, hidden_size, dropout_prob),
                                     is_reverse=False,
                                     time_major=False)
        # fully connected output layer
        self.output_layer = nn.Linear(
            hidden_size,
            vocab_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=False)

    def forward(self, trg, decoder_initial_states, encoder_output,
                encoder_padding_mask):
        inputs = self.embedder(trg)
        decoder_output, _ = self.lstm_attention(
            inputs,
            initial_states=decoder_initial_states,
            encoder_output=encoder_output,
            encoder_padding_mask=encoder_padding_mask)
        predict = self.output_layer(decoder_output)

        return predict
Building the main network: Seq2SeqAttnModel
With the Encoder and Decoder defined, the full network can be assembled:

After embedding, the Encoder's input and its output are Tensors of shape [batch_size, src_length, hidden_size] (here embed_dim equals hidden_size)
The state output of the Encoder's last time step is used as the Decoder's initial state
After embedding, the Decoder's input and its output are Tensors of shape [batch_size, trg_length, hidden_size]
In [7]
class Seq2SeqAttnModel(nn.Layer):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 eos_id=1,
                 init_scale=0.1):
        super(Seq2SeqAttnModel, self).__init__()
        self.hidden_size = hidden_size
        self.eos_id = eos_id
        self.num_layers = num_layers
        self.INF = 1e9
        self.encoder = Seq2SeqEncoder(src_vocab_size, embed_dim, hidden_size,
                                      num_layers, dropout_prob, init_scale)
        self.decoder = Seq2SeqDecoder(trg_vocab_size, embed_dim, hidden_size,
                                      num_layers, dropout_prob, init_scale)
    def forward(self, src, src_length, trg):
        encoder_output, encoder_final_state = self.encoder(src, src_length)
        # Transfer shape of encoder_final_states to [num_layers, 2, batch_size, hidden_size]
        encoder_final_states = [
            (encoder_final_state[0][i], encoder_final_state[1][i])
            for i in range(self.num_layers)
        ]
        # Construct decoder initial states: use input_feed and the shape is
        # [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states
        decoder_initial_states = [
            encoder_final_states,
            self.decoder.lstm_attention.cell.get_initial_states(
                batch_ref=encoder_output, shape=[self.hidden_size])
        ]
        # Build the attention mask so that padding positions receive no attention
        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())
        encoder_padding_mask = (src_mask - 1.0) * self.INF
        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])

        predict = self.decoder(trg, decoder_initial_states, encoder_output,
                               encoder_padding_mask)
        return predict
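The mask construction in forward() can be checked with a tiny NumPy example: since eos_id doubles as pad_id, padded positions get a large negative bias that softmax later turns into near-zero attention weights:

```python
import numpy as np

# Sketch of the padding-mask construction used in forward():
# positions equal to eos_id (reused as the pad id) become a large
# negative value to be added onto the attention scores.
eos_id, INF = 1, 1e9
src = np.array([[5, 7, 9, 1, 1]])               # last two steps are padding
src_mask = (src != eos_id).astype("float32")    # [[1, 1, 1, 0, 0]]
encoder_padding_mask = (src_mask - 1.0) * INF   # [[0, 0, 0, -1e9, -1e9]]
```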
Defining the loss function
We use cross-entropy loss. Because shorter sequences were padded earlier, positions that are merely padding must be ignored when computing the loss of a generated sentence, so a trg_mask argument is introduced. Since the paddle.nn.CrossEntropyLoss provided by the PaddlePaddle framework does not accept a trg_mask argument, we define the criterion ourselves:

In [8]
class CrossEntropyCriterion(nn.Layer):
    def __init__(self):
        super(CrossEntropyCriterion, self).__init__()

    def forward(self, predict, label, trg_mask):
        cost = F.softmax_with_cross_entropy(
            logits=predict, label=label, soft_label=False)
        cost = paddle.squeeze(cost, axis=[2])
        masked_cost = cost * trg_mask
        batch_mean_cost = paddle.mean(masked_cost, axis=[0])
        seq_cost = paddle.sum(batch_mean_cost)

        return seq_cost
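The masking logic of CrossEntropyCriterion can be mirrored in NumPy. This is a sketch of the same computation, not the Paddle implementation: per-token cross entropy, padded positions zeroed by trg_mask, mean over the batch axis, then a sum over time steps:

```python
import numpy as np

def masked_ce(logits, labels, trg_mask):
    """NumPy sketch of CrossEntropyCriterion: token-level cross
    entropy, masked, averaged over the batch, summed over time."""
    # logits: [batch, time, vocab], labels: [batch, time]
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    masked = nll * trg_mask                     # zero out padding positions
    return masked.mean(axis=0).sum()

logits = np.zeros((2, 3, 4))                    # uniform over 4 classes
labels = np.zeros((2, 3), dtype="int64")
mask = np.array([[1, 1, 0], [1, 0, 0]], dtype="float64")
loss = masked_ce(logits, labels, mask)
# Each unmasked token costs ln(4); masked positions contribute 0.
```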
Execution
Training
Training with the high-level API requires calling prepare and fit.

In prepare, we configure the optimizer, the loss function, and the evaluation metric. The metric is PaddleNLP's perplexity API, paddlenlp.metrics.Perplexity.
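As a reminder of what this metric reports, perplexity is the exponential of the average per-token cross entropy, so lower is better. A minimal sketch of the quantity (not the paddlenlp implementation):

```python
import numpy as np

# Perplexity = exp(mean per-token cross entropy): a model that is
# less surprised by the reference tokens gets a lower value.
token_nll = np.array([1.2, 0.8, 1.0])   # per-token cross entropy (nats)
ppl = np.exp(token_nll.mean())
```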

In [9]
batch_size = 128
num_layers = 2
dropout = 0.2
init_scale = 0.1
hidden_size = 512
max_grad_norm = 5.0
learning_rate = 0.001
max_epoch = 12
max_len = 50
model_path = './attention_models'
infer_output_file = './output_files'
beam_size = 10
log_freq = 100
device = "gpu"

# Define dataloader
train_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="train")
eval_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="dev")

model = paddle.Model(
  Seq2SeqAttnModel(src_vocab_size, tgt_vocab_size, hidden_size, hidden_size, num_layers, dropout, eos_id))

grad_clip = nn.ClipGradByGlobalNorm(max_grad_norm)
optimizer = paddle.optimizer.Adam(
    learning_rate=learning_rate,
    parameters=model.parameters(),
    grad_clip=grad_clip)

ppl_metric = Perplexity()
model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
If you have VisualDL installed, you can add a callbacks argument to fit to monitor the training process with VisualDL, as follows:

In [10]
model.fit(train_data=train_loader,
          eval_data=eval_loader,
          epochs=max_epoch,
          eval_freq=1,
          save_freq=1,
          save_dir=model_path,
          log_freq=log_freq,
          callbacks=[paddle.callbacks.VisualDL('./log')])
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/12
step  100/1041 - loss: 316.9290 - Perplexity: 569.7516 - 157ms/step
step  200/1041 - loss: 283.0392 - Perplexity: 377.6471 - 154ms/step
step  300/1041 - loss: 260.9322 - Perplexity: 267.6343 - 154ms/step
step  400/1041 - loss: 235.1697 - Perplexity: 201.8174 - 154ms/step
step  500/1041 - loss: 229.3707 - Perplexity: 159.4850 - 153ms/step
step  600/1041 - loss: 203.4672 - Perplexity: 129.9806 - 153ms/step
step  700/1041 - loss: 201.8655 - Perplexity: 108.9772 - 153ms/step
step  800/1041 - loss: 196.9280 - Perplexity: 93.9172 - 154ms/step
step  900/1041 - loss: 186.4891 - Perplexity: 82.5902 - 154ms/step
step 1000/1041 - loss: 177.8873 - Perplexity: 73.5232 - 154ms/step
step 1041/1041 - loss: 80.8191 - Perplexity: 70.4385 - 154ms/step
save checkpoint at /home/aistudio/attention_models/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 172.1245 - Perplexity: 24.3200 - 75ms/step
Eval samples: 1553
Epoch 2/12
step  100/1041 - loss: 170.8153 - Perplexity: 20.7067 - 154ms/step
step  200/1041 - loss: 169.8316 - Perplexity: 20.1637 - 155ms/step
step  300/1041 - loss: 164.2285 - Perplexity: 19.6250 - 154ms/step
step  400/1041 - loss: 166.2451 - Perplexity: 19.1300 - 154ms/step
step  500/1041 - loss: 154.8416 - Perplexity: 18.6932 - 155ms/step
step  600/1041 - loss: 161.9329 - Perplexity: 18.2686 - 154ms/step
step  700/1041 - loss: 158.8072 - Perplexity: 17.9527 - 154ms/step
step  800/1041 - loss: 151.1718 - Perplexity: 17.5832 - 154ms/step
step  900/1041 - loss: 150.8292 - Perplexity: 17.2788 - 154ms/step
step 1000/1041 - loss: 151.9545 - Perplexity: 16.9614 - 154ms/step
step 1041/1041 - loss: 64.2150 - Perplexity: 16.8593 - 154ms/step
save checkpoint at /home/aistudio/attention_models/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 163.4615 - Perplexity: 16.4078 - 75ms/step
Eval samples: 1553
Epoch 3/12
step  100/1041 - loss: 144.9384 - Perplexity: 11.9777 - 154ms/step
step  200/1041 - loss: 147.2327 - Perplexity: 11.9371 - 153ms/step
step  300/1041 - loss: 139.2942 - Perplexity: 11.9587 - 154ms/step
step  400/1041 - loss: 140.5380 - Perplexity: 11.9496 - 154ms/step
step  500/1041 - loss: 140.6450 - Perplexity: 11.9257 - 154ms/step
step  600/1041 - loss: 140.4239 - Perplexity: 11.8719 - 153ms/step
step  700/1041 - loss: 141.9416 - Perplexity: 11.8256 - 153ms/step
step  800/1041 - loss: 142.7165 - Perplexity: 11.7741 - 153ms/step
step  900/1041 - loss: 136.2898 - Perplexity: 11.7381 - 153ms/step
step 1000/1041 - loss: 139.1634 - Perplexity: 11.6733 - 153ms/step
step 1041/1041 - loss: 56.9410 - Perplexity: 11.6393 - 153ms/step
save checkpoint at /home/aistudio/attention_models/2
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 134.3902 - Perplexity: 13.4229 - 75ms/step
Eval samples: 1553
Epoch 4/12
step  100/1041 - loss: 133.1941 - Perplexity: 9.3546 - 154ms/step
step  200/1041 - loss: 127.7615 - Perplexity: 9.4610 - 154ms/step
step  300/1041 - loss: 128.3531 - Perplexity: 9.4909 - 154ms/step
step  400/1041 - loss: 131.7594 - Perplexity: 9.5061 - 154ms/step
step  500/1041 - loss: 135.5802 - Perplexity: 9.5315 - 153ms/step
step  600/1041 - loss: 132.5778 - Perplexity: 9.5537 - 153ms/step
step  700/1041 - loss: 127.7056 - Perplexity: 9.5584 - 153ms/step
step  800/1041 - loss: 129.8652 - Perplexity: 9.5510 - 153ms/step
step  900/1041 - loss: 126.0126 - Perplexity: 9.5457 - 153ms/step
step 1000/1041 - loss: 129.1563 - Perplexity: 9.5513 - 153ms/step
step 1041/1041 - loss: 43.0635 - Perplexity: 9.5494 - 153ms/step
save checkpoint at /home/aistudio/attention_models/3
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 138.2625 - Perplexity: 12.5110 - 75ms/step
Eval samples: 1553
Epoch 5/12
step  100/1041 - loss: 119.1514 - Perplexity: 7.9890 - 156ms/step
step  200/1041 - loss: 126.6494 - Perplexity: 8.0800 - 155ms/step
step  300/1041 - loss: 122.2061 - Perplexity: 8.1121 - 154ms/step
step  400/1041 - loss: 117.8465 - Perplexity: 8.1787 - 155ms/step
step  500/1041 - loss: 120.3643 - Perplexity: 8.2348 - 154ms/step
step  600/1041 - loss: 126.8760 - Perplexity: 8.2641 - 154ms/step
step  700/1041 - loss: 125.7731 - Perplexity: 8.3052 - 154ms/step
step  800/1041 - loss: 122.7028 - Perplexity: 8.3227 - 154ms/step
step  900/1041 - loss: 129.4286 - Perplexity: 8.3327 - 154ms/step
step 1000/1041 - loss: 125.5585 - Perplexity: 8.3499 - 153ms/step
step 1041/1041 - loss: 48.3358 - Perplexity: 8.3549 - 153ms/step
save checkpoint at /home/aistudio/attention_models/4
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 147.6611 - Perplexity: 12.0037 - 75ms/step
Eval samples: 1553
Epoch 6/12
step  100/1041 - loss: 120.7922 - Perplexity: 7.2646 - 155ms/step
step  200/1041 - loss: 118.1561 - Perplexity: 7.2689 - 154ms/step
step  300/1041 - loss: 120.9546 - Perplexity: 7.3105 - 153ms/step
step  400/1041 - loss: 118.6025 - Perplexity: 7.3526 - 153ms/step
step  500/1041 - loss: 117.7779 - Perplexity: 7.3928 - 153ms/step
step  600/1041 - loss: 122.9525 - Perplexity: 7.4319 - 153ms/step
step  700/1041 - loss: 117.7925 - Perplexity: 7.4670 - 153ms/step
step  800/1041 - loss: 126.4061 - Perplexity: 7.4887 - 153ms/step
step  900/1041 - loss: 118.9017 - Perplexity: 7.5278 - 153ms/step
step 1000/1041 - loss: 122.6585 - Perplexity: 7.5620 - 154ms/step
step 1041/1041 - loss: 57.3045 - Perplexity: 7.5724 - 154ms/step
save checkpoint at /home/aistudio/attention_models/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 141.4744 - Perplexity: 11.7521 - 75ms/step
Eval samples: 1553
Epoch 7/12
step  100/1041 - loss: 115.3885 - Perplexity: 6.5775 - 154ms/step
step  200/1041 - loss: 115.0385 - Perplexity: 6.6283 - 153ms/step
step  300/1041 - loss: 116.8085 - Perplexity: 6.6999 - 153ms/step
step  400/1041 - loss: 109.3981 - Perplexity: 6.7474 - 153ms/step
step  500/1041 - loss: 118.9770 - Perplexity: 6.7980 - 153ms/step
step  600/1041 - loss: 119.6290 - Perplexity: 6.8498 - 153ms/step
step  700/1041 - loss: 114.9620 - Perplexity: 6.8818 - 153ms/step
step  800/1041 - loss: 114.2539 - Perplexity: 6.9167 - 153ms/step
step  900/1041 - loss: 119.0262 - Perplexity: 6.9484 - 154ms/step
step 1000/1041 - loss: 114.2731 - Perplexity: 6.9769 - 154ms/step
step 1041/1041 - loss: 45.2423 - Perplexity: 6.9787 - 154ms/step
save checkpoint at /home/aistudio/attention_models/6
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 141.1222 - Perplexity: 11.5050 - 75ms/step
Eval samples: 1553
Epoch 8/12
step  100/1041 - loss: 106.5752 - Perplexity: 6.0088 - 155ms/step
step  200/1041 - loss: 112.7031 - Perplexity: 6.1096 - 154ms/step
step  300/1041 - loss: 112.7603 - Perplexity: 6.1900 - 154ms/step
step  400/1041 - loss: 112.4972 - Perplexity: 6.2464 - 154ms/step
step  500/1041 - loss: 110.8458 - Perplexity: 6.2897 - 154ms/step
step  600/1041 - loss: 118.5520 - Perplexity: 6.3431 - 154ms/step
step  700/1041 - loss: 116.1271 - Perplexity: 6.3832 - 154ms/step
step  800/1041 - loss: 115.2826 - Perplexity: 6.4184 - 154ms/step
step  900/1041 - loss: 115.0764 - Perplexity: 6.4395 - 154ms/step
step 1000/1041 - loss: 118.9429 - Perplexity: 6.4657 - 154ms/step
step 1041/1041 - loss: 48.1713 - Perplexity: 6.4823 - 154ms/step
save checkpoint at /home/aistudio/attention_models/7
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 132.2597 - Perplexity: 11.5465 - 76ms/step
Eval samples: 1553
Epoch 9/12
step  100/1041 - loss: 106.0669 - Perplexity: 5.6741 - 156ms/step
step  200/1041 - loss: 108.2129 - Perplexity: 5.7532 - 154ms/step
step  300/1041 - loss: 112.7348 - Perplexity: 5.8025 - 154ms/step
step  400/1041 - loss: 108.3239 - Perplexity: 5.8568 - 154ms/step
step  500/1041 - loss: 115.5850 - Perplexity: 5.9041 - 154ms/step
step  600/1041 - loss: 108.4180 - Perplexity: 5.9378 - 154ms/step
step  700/1041 - loss: 112.5859 - Perplexity: 5.9876 - 153ms/step
step  800/1041 - loss: 113.8619 - Perplexity: 6.0267 - 154ms/step
step  900/1041 - loss: 115.6805 - Perplexity: 6.0513 - 154ms/step
step 1000/1041 - loss: 118.7408 - Perplexity: 6.0714 - 154ms/step
step 1041/1041 - loss: 45.1587 - Perplexity: 6.0828 - 154ms/step
save checkpoint at /home/aistudio/attention_models/8
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 148.6588 - Perplexity: 11.5041 - 75ms/step
Eval samples: 1553
Epoch 10/12
step  100/1041 - loss: 107.6575 - Perplexity: 5.3731 - 153ms/step
step  200/1041 - loss: 105.6288 - Perplexity: 5.4270 - 153ms/step
step  300/1041 - loss: 109.6870 - Perplexity: 5.4661 - 154ms/step
step  400/1041 - loss: 106.5136 - Perplexity: 5.5160 - 153ms/step
step  500/1041 - loss: 113.4923 - Perplexity: 5.5568 - 153ms/step
step  600/1041 - loss: 108.5855 - Perplexity: 5.5904 - 153ms/step
step  700/1041 - loss: 107.1722 - Perplexity: 5.6375 - 153ms/step
step  800/1041 - loss: 115.3743 - Perplexity: 5.6674 - 153ms/step
step  900/1041 - loss: 114.0770 - Perplexity: 5.7021 - 153ms/step
step 1000/1041 - loss: 111.8619 - Perplexity: 5.7246 - 153ms/step
step 1041/1041 - loss: 47.1277 - Perplexity: 5.7353 - 154ms/step
save checkpoint at /home/aistudio/attention_models/9
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 142.5951 - Perplexity: 11.4914 - 76ms/step
Eval samples: 1553
Epoch 11/12
step  100/1041 - loss: 103.6631 - Perplexity: 5.0292 - 153ms/step
step  200/1041 - loss: 97.6034 - Perplexity: 5.0740 - 154ms/step
step  300/1041 - loss: 109.3245 - Perplexity: 5.1391 - 153ms/step
step  400/1041 - loss: 106.0880 - Perplexity: 5.1917 - 154ms/step
step  500/1041 - loss: 105.9739 - Perplexity: 5.2453 - 154ms/step
step  600/1041 - loss: 108.3997 - Perplexity: 5.2823 - 154ms/step
step  700/1041 - loss: 108.6982 - Perplexity: 5.3275 - 154ms/step
step  800/1041 - loss: 106.0240 - Perplexity: 5.3552 - 154ms/step
step  900/1041 - loss: 102.1773 - Perplexity: 5.3887 - 154ms/step
step 1000/1041 - loss: 111.5466 - Perplexity: 5.4194 - 154ms/step
step 1041/1041 - loss: 36.5989 - Perplexity: 5.4327 - 154ms/step
save checkpoint at /home/aistudio/attention_models/10
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 147.4474 - Perplexity: 11.9804 - 75ms/step
Eval samples: 1553
Epoch 12/12
step  100/1041 - loss: 98.2783 - Perplexity: 4.7699 - 154ms/step
step  200/1041 - loss: 102.6894 - Perplexity: 4.8731 - 153ms/step
step  300/1041 - loss: 105.4697 - Perplexity: 4.9088 - 153ms/step
step  400/1041 - loss: 97.5007 - Perplexity: 4.9465 - 153ms/step
step  500/1041 - loss: 108.4352 - Perplexity: 4.9925 - 154ms/step
step  600/1041 - loss: 102.4900 - Perplexity: 5.0358 - 153ms/step
step  700/1041 - loss: 102.6786 - Perplexity: 5.0682 - 153ms/step
step  800/1041 - loss: 104.1088 - Perplexity: 5.0986 - 153ms/step
step  900/1041 - loss: 107.4128 - Perplexity: 5.1334 - 153ms/step
step 1000/1041 - loss: 105.3015 - Perplexity: 5.1618 - 154ms/step
step 1041/1041 - loss: 37.2941 - Perplexity: 5.1706 - 153ms/step
save checkpoint at /home/aistudio/attention_models/11
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 13/13 - loss: 134.3483 - Perplexity: 11.9075 - 76ms/step
Eval samples: 1553
save checkpoint at /home/aistudio/attention_models/final
We can call the state_dict function to retrieve the trained network parameters.

In [11]
state_dict = model.network.state_dict()
Model prediction
Define the inference network Seq2SeqAttnInferModel
The inference network inherits from the main network Seq2SeqAttnModel defined above; we define the subclass Seq2SeqAttnInferModel.

In [12]
class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 embed_dim,
                 hidden_size,
                 num_layers,
                 dropout_prob=0.,
                 bos_id=0,
                 eos_id=1,
                 beam_size=4,
                 max_out_len=256):
        self.bos_id = bos_id
        self.beam_size = beam_size
        self.max_out_len = max_out_len
        self.num_layers = num_layers
        # Pass dropout_prob and eos_id through to the base training network
        super(Seq2SeqAttnInferModel, self).__init__(
            src_vocab_size, trg_vocab_size, embed_dim, hidden_size,
            num_layers, dropout_prob, eos_id)
        # Dynamic decoder for inference
        self.beam_search_decoder = nn.BeamSearchDecoder(
            self.decoder.lstm_attention.cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=self.decoder.embedder,
            output_fn=self.decoder.output_layer)

    def forward(self, src, src_length):
        encoder_output, encoder_final_state = self.encoder(src, src_length)

        encoder_final_state = [
            (encoder_final_state[0][i], encoder_final_state[1][i])
            for i in range(self.num_layers)
        ]

        # Build the decoder's initial states from the encoder's final states
        decoder_initial_states = [
            encoder_final_state,
            self.decoder.lstm_attention.cell.get_initial_states(
                batch_ref=encoder_output, shape=[self.hidden_size])
        ]
        # Build attention mask to avoid paying attention on paddings
        src_mask = (src != self.eos_id).astype(paddle.get_default_dtype())

        encoder_padding_mask = (src_mask - 1.0) * self.INF
        encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])

        # Tile the batch dimension with beam_size
        encoder_output = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_output, self.beam_size)
        encoder_padding_mask = nn.BeamSearchDecoder.tile_beam_merge_with_batch(
            encoder_padding_mask, self.beam_size)
        
        # Dynamic decoding with beam search
        seq_output, _ = nn.dynamic_decode(
            decoder=self.beam_search_decoder,
            inits=decoder_initial_states,
            max_step_num=self.max_out_len,
            encoder_output=encoder_output,
            encoder_padding_mask=encoder_padding_mask)
        return seq_output
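The BeamSearchDecoder above keeps the beam_size highest-scoring partial hypotheses at every decode step and expands each of them over the vocabulary. The idea can be sketched in a few lines of plain Python; this is a toy illustration (the per-step distributions are fixed and independent of the prefix), not the paddle.nn.BeamSearchDecoder implementation:

```python
import numpy as np

def toy_beam_search(step_log_probs, beam_size, eos_id):
    """Toy beam search over a fixed table of per-step log-probabilities.

    step_log_probs: array of shape [num_steps, vocab_size]; for simplicity
    the distribution at each step does not depend on the prefix.
    Returns the highest-scoring token sequence.
    """
    # Each beam is (token_list, cumulative_log_prob)
    beams = [([], 0.0)]
    for log_probs in step_log_probs:
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == eos_id:
                # Finished hypotheses are carried over unchanged
                candidates.append((tokens, score))
                continue
            for tok, lp in enumerate(log_probs):
                candidates.append((tokens + [tok], score + lp))
        # Keep only the beam_size highest-scoring hypotheses
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams[0][0]

# Vocabulary of size 2; eos_id=99 is never produced here
print(toy_beam_search(np.log(np.array([[0.1, 0.9], [0.8, 0.2]])), 2, 99))  # [1, 0]
```

In the real decoder the per-step distribution is produced by the attention LSTM cell conditioned on each beam's state, and dynamic_decode stops once all beams emit end_token or max_step_num is reached.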
Decoding
Next, we choose beam search as the decoding strategy for this task; this project sets beam_size to 10.

In [13]
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
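For example, with bos_id=0 and eos_id=1, a decoded beam is truncated at the first eos token and the special tokens are stripped. The function is repeated here (verbatim from the cell above) so the snippet runs standalone:

```python
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """Truncate at the first <eos> and strip <bos>/<eos> unless requested."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    return [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]

# <bos>=0, content tokens 5 7 9, then <eos>=1 followed by padding that is dropped
print(post_process_seq([0, 5, 7, 9, 1, 1, 1], bos_idx=0, eos_idx=1))  # [5, 7, 9]
```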
In [14]
test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_data_loader(mode="test")
_, vocab = IWSLT15.get_vocab()
trg_idx2word = vocab.idx_to_token

model = paddle.Model(
    Seq2SeqAttnInferModel(
        src_vocab_size,
        tgt_vocab_size,
        hidden_size,
        hidden_size,
        num_layers,
        dropout,
        bos_id=bos_id,
        eos_id=eos_id,
        beam_size=beam_size,
        max_out_len=256))

model.prepare()
We can load the model parameters state_dict saved earlier into the new inference network with the set_state_dict function.

In [15]
model.network.set_state_dict(state_dict)
In [16]
cand_list = []
with io.open(infer_output_file, 'w', encoding='utf-8') as f:
    for data in test_loader():
        with paddle.no_grad():
            finished_seq = model.predict_batch(inputs=data)[0]
        finished_seq = finished_seq[:, :, np.newaxis] if len(
            finished_seq.shape) == 2 else finished_seq
        finished_seq = np.transpose(finished_seq, [0, 2, 1])
        for ins in finished_seq:
            for beam_idx, beam in enumerate(ins):
                id_list = post_process_seq(beam, bos_id, eos_id)
                word_list = [trg_idx2word[id] for id in id_list]
                sequence = " ".join(word_list) + "\n"
                f.write(sequence)
                cand_list.append(word_list)
                break

test_ds = IWSLT15.get_datasets(["test"])
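The loop above first rearranges the decoder output from [batch, time, beam] to [batch, beam, time], so that each row holds one beam's full token sequence, and the break keeps only the top-scoring beam per instance. A small NumPy illustration of that rearrangement (toy token ids, not real decoder output):

```python
import numpy as np

# Toy decoder output: batch of 2, 3 decode steps, beam width 2
finished_seq = np.array([[[4, 6], [5, 7], [1, 1]],
                         [[8, 2], [9, 3], [1, 1]]])   # shape [batch, time, beam]

# Rearrange so each row is one beam's complete token sequence
finished_seq = np.transpose(finished_seq, [0, 2, 1])   # shape [batch, beam, time]

# Keep only the first (best) beam of each instance, as the break above does
top_beams = [ins[0].tolist() for ins in finished_seq]
print(top_beams)  # [[4, 5, 1], [8, 9, 1]]
```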
After generating the translations, call paddlenlp.metrics.BLEU to compute the BLEU score of the results.

In [17]
bleu = BLEU()
for i, data in enumerate(test_ds):
    ref = data[1].split()
    bleu.add_inst(cand_list[i], [ref])
print("BLEU score is %s." % bleu.score())
BLEU score is 0.24077765548678734.
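paddlenlp.metrics.BLEU accumulates candidate/reference pairs via add_inst and returns the corpus-level score via score(). The metric itself can be sketched as clipped n-gram precisions (uniform weights up to 4-grams) combined with a brevity penalty; the following is an illustrative re-implementation, not paddlenlp's code:

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """Multiset of n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_bleu(pairs, max_n=4):
    """pairs: list of (candidate_tokens, reference_tokens), one reference each."""
    matches = [0] * max_n
    totals = [0] * max_n
    cand_len = ref_len = 0
    for cand, ref in pairs:
        cand_len += len(cand)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            cand_ngrams, ref_ngrams = ngrams(cand, n), ngrams(ref, n)
            # Counter intersection gives clipped n-gram match counts
            matches[n - 1] += sum((cand_ngrams & ref_ngrams).values())
            totals[n - 1] += max(0, len(cand) - n + 1)
    if 0 in matches:
        return 0.0  # any zero precision drives the geometric mean to zero
    log_prec = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    bp = 1.0 if cand_len > ref_len else math.exp(1 - ref_len / cand_len)
    return bp * math.exp(log_prec)

sent = "the cat sat on the mat".split()
print(corpus_bleu([(sent, sent)]))  # 1.0 for a perfect match
```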
More PaddleNLP tutorials
Sentence sentiment analysis with the seq2vec module
Improving sentiment analysis with the pretrained model ERNIE
Sentiment analysis on a custom dataset
Improving sentiment analysis with pretrained word embeddings
Waybill information extraction with a BiGRU-CRF model
Improving waybill information extraction with the pretrained model ERNIE
Poetry generation with the pretrained model ERNIE-GEN
COVID-19 case-count forecasting with a TCN network
Reading comprehension with a pretrained model
Couplet generation with a PaddleNLP seq2seq model