手撕实现self-attention和multihead-attention(pytorch版本)
【代码】手撕实现self-attention和multihead-attention(pytorch版本)
·
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced from the same input by three
    separate linear projections, and the output is
    ``softmax(Q @ K^T / sqrt(dim_q)) @ V``.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        dim_q: Output size of the query and key projections; also sets the
            ``1 / sqrt(dim_q)`` scaling factor.
        dim_v: Output size of the value projection, i.e. the feature size
            of the returned tensor.
    """

    def __init__(self, input_dim, dim_q, dim_v):
        super().__init__()
        self.linear_q = nn.Linear(input_dim, dim_q)
        self.linear_k = nn.Linear(input_dim, dim_q)
        self.linear_v = nn.Linear(input_dim, dim_v)
        # Scale by 1/sqrt(d_k) so dot products do not grow with dim_q and
        # push the softmax into its saturated (tiny-gradient) region.
        self.norm_factor = 1 / (dim_q) ** (0.5)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(..., seq_len, dim_v)``.
        """
        # Queries must come from linear_q. The original used linear_k here,
        # which made Q identical to K and left linear_q unused.
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        # (..., seq_len, seq_len) scaled similarity between every query/key pair.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.norm_factor
        # Normalise over the key axis so each query's weights sum to 1.
        attn_weights = torch.softmax(scores, dim=-1)
        att_out = torch.matmul(attn_weights, v)
        return att_out
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        model_dim: Feature size of the query/key/value inputs and of the
            output. Must be divisible by ``head``.
        head: Number of attention heads; each head works on
            ``model_dim // head`` features.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``head``.
    """

    def __init__(self, model_dim=512, head=8):
        super().__init__()
        # Fail fast: otherwise the per-head reshape in forward() fails later
        # with an opaque view() size error.
        if model_dim % head != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by head ({head})"
            )
        self.head_dim = model_dim // head
        self.head_num = head
        self.model_dim = model_dim
        self.linear_q = nn.Linear(model_dim, model_dim)
        self.linear_k = nn.Linear(model_dim, model_dim)
        self.linear_v = nn.Linear(model_dim, model_dim)
        self.linear_o = nn.Linear(model_dim, model_dim)

    def _split_heads(self, x):
        """Reshape ``(batch, seq, model_dim)`` into ``(batch, head_num, seq, head_dim)``."""
        # Use x's own sequence length so keys/values may differ in length
        # from the queries (cross-attention).
        batch_size, seq_len = x.size(0), x.size(1)
        return x.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` with ``head_num`` parallel heads.

        Args:
            q: Query tensor of shape ``(batch, q_len, model_dim)``.
            k: Key tensor of shape ``(batch, kv_len, model_dim)``.
            v: Value tensor of shape ``(batch, kv_len, model_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, head_num, q_len, kv_len)``. Positions where
                ``mask == 0`` (or ``False``) are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, model_dim)``.
        """
        # Each input gets its own projection. The original routed k and v
        # through linear_q, leaving linear_k / linear_v unused.
        q_w = self._split_heads(self.linear_q(q))
        k_w = self._split_heads(self.linear_k(k))
        v_w = self._split_heads(self.linear_v(v))
        # (batch, head_num, q_len, kv_len) scaled dot-product scores.
        att_score = torch.matmul(q_w, k_w.transpose(-2, -1)) * (self.head_dim ** (-0.5))
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf.
            # A fully masked row then yields uniform weights rather than NaN.
            att_score = att_score.masked_fill(
                mask == 0, torch.finfo(att_score.dtype).min
            )
        score = torch.softmax(att_score, dim=-1)
        att_out = torch.matmul(score, v_w)
        # Merge heads back: (batch, q_len, head_num * head_dim).
        batch_size = q.size(0)
        out = att_out.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)
        return self.linear_o(out)
def main():
    """Run MultiHeadAttention once on a random batch and print the input and output shapes."""
    # Construct the module before sampling the input: parameter init draws
    # from the RNG first, then the input tensor is sampled.
    attention = MultiHeadAttention(model_dim=256)
    inputs = torch.randn(16, 10, 256)
    print(inputs.shape)
    # Self-attention: the same tensor serves as query, key and value.
    print(attention(inputs, inputs, inputs).shape)


if __name__ == "__main__":
    main()
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)