Attention-LSTM序列分类模型实现(pytorch)
在LSTM之后加入Attention机制,通常可以提升序列分类的效果。
·
-
前言
在LSTM之后加入Attention机制,通常可以提升序列分类的效果:模型不再只依赖最后一个时间步的隐状态,而是对所有时间步的输出加权求和。
-
手动实现Attention(注意力池化:为每个时间步打分,再对LSTM各时间步的输出加权求和)
import torch.nn as nn
import torch
class Attention(nn.Module):
    """Attention pooling over the time axis of a sequence of hidden states.

    Every time step is scored with a learned projection. The scores are
    normalised with a softmax over the time axis, and the inputs are summed
    with those weights, producing one vector per sequence.

    Args:
        hidden_dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        # NOTE(review): there is no non-linearity between these two layers,
        # so dot_product(qk_weights(x)) collapses to a single linear scoring
        # vector. A constant bias term is also cancelled by the softmax.
        # Inserting e.g. torch.tanh would make the hidden layer meaningful.
        self.qk_weights = nn.Linear(hidden_dim, hidden_dim)
        self.dot_product = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, x, mask=None):
        """Pool ``x`` over its time axis.

        Args:
            x: tensor of shape (batch, seq_len, hidden_dim).
            mask: optional tensor of shape (batch, seq_len), interpreted as
                bool, where True marks valid (non-padding) positions.
                Masked-out positions receive exactly zero weight. Every row
                must contain at least one True, otherwise its softmax is NaN.
                ``None`` (the default) attends to every position.

        Returns:
            A tuple ``(output, similarity)``:
            - ``output`` has shape (batch, hidden_dim).
            - ``similarity`` holds the attention weights, with shape
              (batch, seq_len); each row sums to 1.
        """
        qk_weights = self.qk_weights(x)
        # One scalar score per time step: (batch, seq_len, 1) -> (batch, seq_len)
        qk_dot = self.dot_product(qk_weights).squeeze(-1)
        if mask is not None:
            # -inf scores become exactly 0 after the softmax.
            qk_dot = qk_dot.masked_fill(~mask.bool(), float('-inf'))
        similarity = torch.softmax(qk_dot, dim=1)
        # Weighted sum over time: (batch, seq_len, 1) * (batch, seq_len, hidden_dim)
        output = torch.sum(similarity.unsqueeze(-1) * x, dim=1)
        return output, similarity
class AttentionLSTM(nn.Module):
    """Sequence classifier built from four stages.

    Pipeline: embedding -> bidirectional LSTM -> attention pooling -> MLP.

    All size arguments have defaults equal to the values this model
    originally hard-coded, so ``AttentionLSTM(output_dim=2)`` builds the
    same architecture as before.

    Args:
        output_dim: number of output logits (classes).
        vocab_size: number of embedding rows; token ids must lie in
            [0, vocab_size).
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden size per direction. The bidirectional output
            has 2 * hidden_size features. Also used as the MLP's hidden width.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, output_dim=2, vocab_size=257, embedding_dim=8,
                 hidden_size=16, num_layers=2):
        super(AttentionLSTM, self).__init__()
        # Forward and backward LSTM states are concatenated on the last axis.
        feature_dim = 2 * hidden_size
        self.embedding = nn.Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size)
        self.sequence = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                                num_layers=num_layers, batch_first=True, bidirectional=True)
        self.attention = Attention(feature_dim)
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_dim),
        )

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, output_dim)."""
        x = self.embedding(x)
        # batch_first=True: out is (batch, seq_len, 2 * hidden_size).
        out, _ = self.sequence(x)
        # Attention weights are not needed for classification.
        pooled, _ = self.attention(out)
        return self.mlp(pooled)
if __name__ == '__main__':
    # Smoke test: run a tiny batch through the model and show the shapes.
    classifier = AttentionLSTM(output_dim=2)
    token_batch = torch.tensor([[1, 212, 145], [53, 151, 45]])
    print(token_batch.shape)  # (batch, seq_len)
    logits = classifier(token_batch)
    print(logits.shape)  # expected (batch, output_dim)
-
使用PyTorch内置的nn.MultiheadAttention实现Self-Attention
import torch.nn as nn
import torch
class AttentionLSTM(nn.Module):
    """Sequence classifier using ``nn.MultiheadAttention`` for self-attention.

    Pipeline: embedding -> bidirectional LSTM -> self-attention across time
    steps -> sum over time -> MLP.

    Defaults equal the values this model originally hard-coded.

    Args:
        output_dim: number of output logits (classes).
        vocab_size: number of embedding rows; token ids must lie in
            [0, vocab_size).
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden size per direction. The bidirectional output
            has 2 * hidden_size features. Also used as the MLP's hidden width.
        num_layers: number of stacked LSTM layers.
        num_heads: number of attention heads; must divide 2 * hidden_size.
    """

    def __init__(self, output_dim=2, vocab_size=257, embedding_dim=8,
                 hidden_size=16, num_layers=2, num_heads=1):
        super(AttentionLSTM, self).__init__()
        feature_dim = 2 * hidden_size  # bidirectional: both directions concatenated
        self.embedding = nn.Embedding(embedding_dim=embedding_dim, num_embeddings=vocab_size)
        self.sequence = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                                num_layers=num_layers, batch_first=True, bidirectional=True)
        # NOTE: nn.MultiheadAttention already applies its own learned q/k/v
        # input projections. These bias-free layers compose with those into a
        # single linear map, so they add parameters but no expressive power.
        # They are kept so existing state_dicts still load.
        self.q = nn.Linear(feature_dim, feature_dim, bias=False)
        self.k = nn.Linear(feature_dim, feature_dim, bias=False)
        self.v = nn.Linear(feature_dim, feature_dim, bias=False)
        # batch_first=True is required because the LSTM emits
        # (batch, seq, feature). With the default batch_first=False, attention
        # would be computed across the *batch* axis and mix different samples.
        self.attention = nn.MultiheadAttention(feature_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_dim),
        )

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, output_dim)."""
        x = self.embedding(x)
        out, _ = self.sequence(x)  # (batch, seq_len, feature_dim)
        q = self.q(out)
        k = self.k(out)
        v = self.v(out)
        output, _ = self.attention(q, k, v)  # (batch, seq_len, feature_dim)
        # Sum-pool over time -> (batch, feature_dim)
        output = torch.sum(output, dim=1)
        return self.mlp(output)
if __name__ == '__main__':
    # Quick shape check on a hand-written batch of two 3-token sequences.
    sample_batch = [
        [1, 212, 145],
        [53, 151, 45],
    ]
    net = AttentionLSTM(output_dim=2)
    inputs = torch.tensor(sample_batch)
    print(inputs.shape)  # (batch, seq_len)
    print(net(inputs).shape)  # expected (batch, output_dim)
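实际测试下来,nn.MultiheadAttention版本的收敛速度远低于手动实现的Attention。需要注意:nn.MultiheadAttention默认batch_first=False,期望输入形状为(seq_len, batch, embed_dim);而上面的LSTM设置了batch_first=True,输出形状为(batch, seq_len, hidden)。如果不给nn.MultiheadAttention设置batch_first=True,注意力实际上是在batch维度上计算的,会把不同样本混在一起,这很可能是收敛慢的原因之一。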
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)