import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.
    The attention weights are ``softmax(Q @ K^T / sqrt(dim_q))`` taken over
    the key axis, and the output is those weights applied to the values.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        dim_q: Output size of the query and key projections (they must match
            so that the dot product is defined).
        dim_v: Output size of the value projection, and therefore the size
            of the last dimension of the result.
    """

    def __init__(self, input_dim, dim_q, dim_v):
        super().__init__()

        self.linear_q = nn.Linear(input_dim, dim_q)
        self.linear_k = nn.Linear(input_dim, dim_q)
        self.linear_v = nn.Linear(input_dim, dim_v)

        # 1 / sqrt(d_k): keeps the dot products from growing with dim_q and
        # saturating the softmax.
        self.norm_factor = 1 / (dim_q) ** (0.5)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``; the
                second-to-last dimension is the axis attended over.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``dim_v``.
        """
        # The query must come from its own projection; projecting it with
        # ``linear_k`` would leave ``linear_q`` untrained and reduce the
        # scores to K @ K^T.
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)

        q_k_mat = torch.matmul(q, k.transpose(-2, -1)) * self.norm_factor

        # Normalise over the key axis so each query's weights sum to 1.
        att_weights = torch.softmax(q_k_mat, dim=-1)

        return torch.matmul(att_weights, v)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each linearly projected, split into
    ``head`` heads of ``model_dim // head`` features, attended per head, then
    concatenated and passed through an output projection.

    Args:
        model_dim: Feature size of the inputs and of the output.
        head: Number of attention heads; must divide ``model_dim`` evenly.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``head``.
    """

    def __init__(self, model_dim=512, head=8):
        super().__init__()

        # Fail early with a clear message instead of an opaque ``view`` error
        # on the first forward pass.
        if model_dim % head != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by head ({head})"
            )

        self.head_dim = model_dim // head
        self.head_num = head
        self.model_dim = model_dim

        self.linear_q = nn.Linear(model_dim, model_dim)
        self.linear_k = nn.Linear(model_dim, model_dim)
        self.linear_v = nn.Linear(model_dim, model_dim)
        self.linear_o = nn.Linear(model_dim, model_dim)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, model_dim) to (batch, head, seq, head_dim)."""
        # ``-1`` lets each input keep its own sequence length, so keys/values
        # may be longer or shorter than the queries.
        return x.view(batch_size, -1, self.head_num, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape (batch, q_len, model_dim).
            k: Key input of shape (batch, k_len, model_dim).
            v: Value input of shape (batch, k_len, model_dim).
            mask: Optional tensor marking which key positions each query may
                attend to. Entries equal to 0 (or ``False``) are blocked.
                It must be broadcastable to (batch, head, q_len, k_len); a
                3-D (batch, q_len, k_len) mask is expanded over the head
                axis automatically.

        Returns:
            Tensor of shape (batch, q_len, model_dim).
        """
        batch_size = q.size(0)

        # Each input goes through its own projection.
        q_w = self._split_heads(self.linear_q(q), batch_size)
        k_w = self._split_heads(self.linear_k(k), batch_size)
        v_w = self._split_heads(self.linear_v(v), batch_size)

        att_score = torch.matmul(q_w, k_w.transpose(-2, -1)) * (self.head_dim ** (-0.5))

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # insert the head axis for broadcasting
            # Use the most negative finite value rather than -inf so that a
            # fully masked row yields uniform weights instead of NaN.
            att_score = att_score.masked_fill(
                mask == 0, torch.finfo(att_score.dtype).min
            )

        score = torch.softmax(att_score, dim=-1)

        att_out = torch.matmul(score, v_w)
        # Merge the heads back: (batch, head, q_len, head_dim) -> (batch, q_len, model_dim).
        out = att_out.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)

        return self.linear_o(out)


def main():
    """Smoke-test MultiHeadAttention and print the input and output shapes."""
    model_dim = 256
    attention = MultiHeadAttention(model_dim=model_dim)

    # Random batch: 16 sequences of 10 tokens each.
    sample = torch.randn(16, 10, model_dim)
    print(sample.shape)

    # Self-attention: the same tensor serves as query, key and value.
    result = attention(sample, sample, sample)
    print(result.shape)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发者及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐