• 前言

        Attention机制按学习到的权重聚合LSTM在所有时间步上的输出，用于分类时通常能明显提升LSTM的分类效果。

  • 手动实现Attention（注意力池化）

import torch.nn as nn
import torch


class Attention(nn.Module):
    """Learned attention pooling over the time axis (dim 1) of a sequence.

    Every time step of the input receives a scalar score from two stacked
    linear layers. The scores are softmax-normalised over dim 1 and used as
    weights for a weighted sum of the time steps.

    NOTE(review): there is no nonlinearity between ``qk_weights`` and
    ``dot_product``, so the score is an affine function of the input. It is
    equivalent in expressive power to a single ``nn.Linear(hidden_dim, 1)``,
    and the bias of ``qk_weights`` adds the same constant to every step,
    which cancels in the softmax. Classic additive attention puts a ``tanh``
    here; confirm whether that was intended.

    Args:
        hidden_dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Projection applied to every time step before scoring.
        self.qk_weights = nn.Linear(hidden_dim, hidden_dim)
        # Maps each projected step to one scalar score.
        self.dot_product = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, x):
        """Pool ``x`` over dim 1 using learned, softmax-normalised weights.

        Args:
            x: float tensor, expected shape (batch, seq, hidden_dim). The
               squeeze(2) / softmax(dim=1) pattern below assumes rank 3 with
               the sequence on dim 1.

        Returns:
            A pair ``(pooled, weights)``. ``pooled`` has shape
            (batch, hidden_dim). ``weights`` has shape (batch, seq) and sums
            to 1 along dim 1.
        """
        scores = self.dot_product(self.qk_weights(x)).squeeze(2)
        weights = scores.softmax(dim=1)
        pooled = (weights.unsqueeze(2) * x).sum(dim=1)
        return pooled, weights


class AttentionLSTM(nn.Module):
    """BiLSTM classifier whose time steps are pooled by the manual ``Attention``.

    Pipeline: token ids -> 8-dim embeddings -> 2-layer bidirectional LSTM
    (16 units per direction, so 32 features per step) -> attention pooling
    over time -> two-layer MLP producing logits.

    Args:
        output_dim: number of logits produced per example.
    """

    def __init__(self, output_dim=2):
        super().__init__()
        # 257-row table: token ids must lie in [0, 256].
        self.embedding = nn.Embedding(num_embeddings=257, embedding_dim=8)
        self.sequence = nn.LSTM(
            input_size=8,
            hidden_size=16,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # 32 = 2 directions x 16 hidden units.
        self.attention = Attention(32)
        self.mlp = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, output_dim),
        )

    def forward(self, x):
        """Return logits for a batch of token-id sequences.

        Args:
            x: integer token ids, expected shape (batch, seq), since the
               LSTM is configured with batch_first=True.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        embedded = self.embedding(x)
        states, _ = self.sequence(embedded)
        pooled, _ = self.attention(states)
        return self.mlp(pooled)


if __name__ == '__main__':
    # Smoke test: a batch of two sequences with three token ids each.
    token_ids = torch.tensor([[1, 212, 145], [53, 151, 45]])
    classifier = AttentionLSTM(output_dim=2)
    print(token_ids.shape)
    logits = classifier(token_ids)
    print(logits.shape)
  • 使用PyTorch内置的nn.MultiheadAttention实现Self-Attention

import torch.nn as nn
import torch


class AttentionLSTM(nn.Module):
    """BiLSTM classifier using PyTorch's built-in ``nn.MultiheadAttention``.

    Pipeline: token ids -> 8-dim embeddings -> 2-layer bidirectional LSTM
    (16 units per direction, so 32 features per step) -> single-head
    self-attention across the time steps -> sum over time -> MLP logits.

    Args:
        output_dim: number of logits produced per example.
    """

    def __init__(self, output_dim=2):
        super().__init__()
        # 257-row table: token ids must lie in [0, 256].
        self.embedding = nn.Embedding(embedding_dim=8, num_embeddings=257)
        self.sequence = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True, bidirectional=True)

        # NOTE(review): nn.MultiheadAttention already owns learned q/k/v input
        # projections, so these three layers are redundant. They are kept so
        # the parameter set (and state_dict keys) stays unchanged.
        self.q = nn.Linear(32, 32, bias=False)
        self.k = nn.Linear(32, 32, bias=False)
        self.v = nn.Linear(32, 32, bias=False)
        # The LSTM is batch_first, so the attention layer must be too. With
        # the default batch_first=False, the (batch, seq, feature) tensor is
        # read as (seq, batch, feature), and each time step would attend
        # across the *examples in the batch* instead of across its own
        # sequence.
        self.attention = nn.MultiheadAttention(32, 1, batch_first=True)

        self.mlp = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, output_dim),
        )

    def forward(self, x):
        """Return classification logits for a batch of token-id sequences.

        Args:
            x: integer token ids, expected shape (batch, seq). Ids must be in
               [0, 256] because the embedding table has 257 rows.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        x = self.embedding(x)                  # (batch, seq, 8)
        out, _ = self.sequence(x)              # (batch, seq, 32)
        q = self.q(out)
        k = self.k(out)
        v = self.v(out)
        output, _ = self.attention(q, k, v)    # (batch, seq, 32)
        # Sum-pool over the time axis. Unlike the softmax-weighted pooling,
        # the magnitude of this sum grows with the sequence length.
        output = torch.sum(output, dim=1)      # (batch, 32)
        return self.mlp(output)


if __name__ == '__main__':
    demo = AttentionLSTM(output_dim=2)
    # Two example sequences of three token ids (ids must stay below 257).
    batch = torch.tensor([
        [1, 212, 145],
        [53, 151, 45],
    ])
    print(batch.shape)
    print(demo(batch).shape)

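实际测试下来，nn.MultiheadAttention的收敛速度远低于手动实现的Attention。需要注意：nn.MultiheadAttention默认batch_first=False，要求输入形状为(seq, batch, feature)；而这里的LSTM使用batch_first=True，输出形状为(batch, seq, feature)。若不显式设置batch_first=True，注意力会在批内不同样本之间计算，而不是在同一序列的各时间步之间计算，这很可能是收敛变慢的原因之一。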

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐