import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable
Attention computation formula
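The scaled dot-product attention implemented below computes, in the notation of the original Transformer paper:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$

The two tensors printed next are the running example for the `matmul` step. The cell that created them is not shown; judging by the shapes they were presumably produced by something like `x = torch.randn(2, 3, 1, 3)` and `y = torch.randn(2, 3, 1, 3)`.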
tensor([[[[-1.6685, -1.7979, 0.0747]],
[[ 1.1604, 1.1415, 0.4631]],
[[ 1.6218, -1.3112, -0.6065]]],
[[[ 0.2836, -0.8159, -0.4028]],
[[-0.0721, -0.3244, 0.2214]],
[[-0.9558, 0.5414, -0.4869]]]])
tensor([[[[ 0.5522, -3.4192, 0.7006]],
[[-1.5651, -1.0705, 1.7866]],
[[-2.1893, -0.2521, 0.2480]]],
[[[ 0.4621, 1.0492, 0.5085]],
[[-0.3847, -1.9930, 1.6604]],
[[-1.0364, -0.3537, 1.5496]]]])
torch.matmul(x, y.transpose(-1, -2))
tensor([[[[ 5.2783]],
[[-2.2107]],
[[-3.3705]]],
[[[-0.9298]],
[[ 1.0419]],
[[ 0.0446]]]])
torch.Size([2, 3, 3, 1])
tensor([[ 0.0000, -1.8719, 0.5233, -0.0000],
[-0.0000, 0.6212, 0.2304, -0.1491],
[-1.5584, -2.3030, 0.0000, -0.8582]])
s(torch.FloatTensor([1, -np.inf, 3]))
<ipython-input-33-50788e72e9da>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
s(torch.FloatTensor([1,-np.inf,3]))
tensor([0.1192, 0.0000, 0.8808])
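This is the trick the attention module relies on: positions filled with `-np.inf` get exactly zero weight after softmax. The `s` above is presumably an `nn.Softmax` instance created without a `dim`; passing `dim=-1` explicitly gives the same result without the deprecation warning:

s = nn.Softmax(dim=-1)
s(torch.FloatTensor([1., -np.inf, 3.]))   # tensor([0.1192, 0.0000, 0.8808])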
class ScaledDotProductAttention(nn.Module):
    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        # raw attention scores: q @ k^T
        attention = torch.matmul(q, k.transpose(-2, -1))
        if scale:
            attention = attention * scale
        if attn_mask is not None:
            # masked positions get -inf so that softmax assigns them zero weight
            attention = attention.masked_fill_(attn_mask, -np.inf)
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        # weighted sum of the values
        context = torch.matmul(attention, v)
        return context
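A quick smoke test of the module above; the shapes are illustrative choices of mine, not from the original notebook:

attn = ScaledDotProductAttention(attention_dropout=0.1)
q = torch.randn(2, 8, 5, 64)              # (batch, heads, len_q, dim_per_head)
k = torch.randn(2, 8, 5, 64)
v = torch.randn(2, 8, 5, 64)
mask = torch.zeros(2, 8, 5, 5, dtype=torch.bool)
mask[..., -1] = True                      # hide the last key position everywhere
out = attn(q, k, v, scale=64 ** -0.5, attn_mask=mask)
print(out.size())                         # torch.Size([2, 8, 5, 64])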
Multi-head attention mechanism
Linear(in_features=10, out_features=10, bias=True)
tensor([[[ 0.4640, 0.5466, -0.6880, 0.1568],
[ 0.8788, 0.9843, -0.4244, -1.5735],
[ 0.1039, 1.2114, 0.7816, -0.8735]],
[[ 1.1619, -2.5654, 0.5679, -1.1354],
[-0.9004, 0.5074, 1.4977, -0.5807],
[-1.3787, 0.7510, -1.1061, 1.2569]]])
x.unsqueeze(2).repeat(1, 1, 10, 1).size()
torch.Size([2, 3, 10, 4])
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.dim_per_head = d_model // num_heads
        self.num_heads = num_heads
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_q = nn.Linear(d_model, d_model)
        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, keys, values, queries, attn_mask=None):
        residual = queries
        batch_size = keys.size(0)
        keys = self.linear_k(keys)
        values = self.linear_v(values)
        queries = self.linear_q(queries)
        # split into heads: (batch, num_heads, seq_len, dim_per_head)
        keys = keys.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        values = values.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        queries = queries.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        if attn_mask is not None:
            # broadcast the mask over all heads
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
        scale = keys.size(-1) ** -0.5
        context = self.dot_product_attention(queries, keys, values, scale, attn_mask)
        # concatenate the heads back into d_model
        context = context.transpose(1, 2).contiguous() \
            .view(batch_size, -1, self.num_heads * self.dim_per_head)
        # residual connection followed by layer normalization
        return self.norm(residual + self.linear_final(context))
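Again a minimal sanity check; batch size and sequence length are assumptions for illustration:

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 5, 512)                # (batch, seq_len, d_model)
out = mha(x, x, x)                        # self-attention: keys = values = queries
print(out.size())                         # torch.Size([2, 5, 512])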
Positional encoding

torch.arange(0, 10).unsqueeze(1).size()
torch.Size([10, 1])
div_term = torch.exp(torch.arange(0., 512, 2) * -(math.log(10000.0) / 512))
torch.Size([256])
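`div_term` is simply $10000^{-2i/d_{model}}$, written through `exp`/`log` for numerical convenience, so the buffer built below follows the sinusoidal encoding of the original paper:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$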
tensor([[ 9.0318e-01, 0.0000e+00, 1.2352e-01, 0.0000e+00, -3.0140e-01,
0.0000e+00, -3.6465e-01, 0.0000e+00, -5.0365e-01, 0.0000e+00],
[-4.5975e-01, 0.0000e+00, 8.5064e-01, 0.0000e+00, -2.6547e+00,
0.0000e+00, 7.4937e-01, 0.0000e+00, -4.1507e-01, 0.0000e+00],
[-1.4702e+00, 0.0000e+00, 4.7715e-01, 0.0000e+00, 8.0542e-01,
0.0000e+00, -4.0687e-01, 0.0000e+00, -7.3654e-01, 0.0000e+00],
[ 1.2496e+00, 0.0000e+00, 1.0493e+00, 0.0000e+00, 1.4115e+00,
0.0000e+00, -4.0402e-01, 0.0000e+00, 1.9959e-01, 0.0000e+00],
[ 4.1005e-01, 0.0000e+00, -1.3749e+00, 0.0000e+00, -9.4356e-02,
0.0000e+00, -2.5279e-01, 0.0000e+00, 1.3641e+00, 0.0000e+00],
[ 3.0355e-01, 0.0000e+00, -7.0061e-01, 0.0000e+00, -6.3308e-01,
0.0000e+00, 7.0820e-02, 0.0000e+00, -6.3141e-02, 0.0000e+00],
[-1.7276e+00, 0.0000e+00, 7.1022e-01, 0.0000e+00, -3.7692e-01,
0.0000e+00, 5.7131e-01, 0.0000e+00, -1.0790e+00, 0.0000e+00],
[-1.9643e+00, 0.0000e+00, -8.7474e-01, 0.0000e+00, -1.2753e+00,
0.0000e+00, 2.8921e-01, 0.0000e+00, -1.4253e+00, 0.0000e+00],
[ 8.4792e-01, 0.0000e+00, 2.9655e-02, 0.0000e+00, -9.0477e-02,
0.0000e+00, 3.1047e-01, 0.0000e+00, 1.8603e+00, 0.0000e+00],
[-5.7733e-01, 0.0000e+00, -2.1318e-01, 0.0000e+00, -2.9424e-01,
0.0000e+00, 5.5969e-01, 0.0000e+00, 5.9077e-01, 0.0000e+00],
[-9.6322e-01, 0.0000e+00, 8.8474e-01, 0.0000e+00, 2.2378e-01,
0.0000e+00, -6.0010e-01, 0.0000e+00, -3.6576e-01, 0.0000e+00],
[ 8.8694e-01, 0.0000e+00, 2.8291e-02, 0.0000e+00, -6.5218e-01,
0.0000e+00, -3.9719e-01, 0.0000e+00, -8.0203e-01, 0.0000e+00],
[ 4.1978e-01, 0.0000e+00, -2.4290e-01, 0.0000e+00, 7.7798e-02,
0.0000e+00, -9.2004e-01, 0.0000e+00, 5.3866e-01, 0.0000e+00],
[-1.0515e+00, 0.0000e+00, -1.0967e+00, 0.0000e+00, -1.0951e+00,
0.0000e+00, 2.9280e-01, 0.0000e+00, -9.3913e-01, 0.0000e+00],
[ 8.6279e-01, 0.0000e+00, 4.4137e-01, 0.0000e+00, 2.5958e-01,
0.0000e+00, 7.3830e-01, 0.0000e+00, 7.2514e-01, 0.0000e+00],
[ 1.5696e+00, 0.0000e+00, -6.6977e-01, 0.0000e+00, -1.4154e+00,
0.0000e+00, 1.1696e+00, 0.0000e+00, 2.2280e-01, 0.0000e+00],
[-1.2376e+00, 0.0000e+00, -1.3173e-01, 0.0000e+00, 1.9464e-01,
0.0000e+00, 2.0106e-01, 0.0000e+00, -1.9465e-01, 0.0000e+00],
[-8.8660e-01, 0.0000e+00, -1.7934e-01, 0.0000e+00, 1.1574e+00,
0.0000e+00, 4.0144e-01, 0.0000e+00, -1.7495e-03, 0.0000e+00],
[ 6.2252e-01, 0.0000e+00, -3.1496e-01, 0.0000e+00, 6.6546e-01,
0.0000e+00, -1.8034e-01, 0.0000e+00, -7.8079e-01, 0.0000e+00],
[ 7.3892e-01, 0.0000e+00, 1.0642e+00, 0.0000e+00, -1.8440e-01,
0.0000e+00, -1.8549e+00, 0.0000e+00, -1.6177e+00, 0.0000e+00]])
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len, dropout=0.0):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # precompute the (max_seq_len, d_model) sinusoidal table
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0., max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # register as a buffer: saved with the model but not a trainable parameter
        self.register_buffer("pe", pe)

    def forward(self, x):
        # add the encoding for the first x.size(1) positions
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)
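A short usage sketch (again with made-up shapes):

pos_enc = PositionalEncoding(d_model=512, max_seq_len=100)
emb = torch.randn(2, 10, 512)             # (batch, seq_len, d_model) token embeddings
out = pos_enc(emb)                        # adds pe[:, :10, :] and applies dropout
print(out.size())                         # torch.Size([2, 10, 512])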
Feed-forward network + layer normalization

class PositionalWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Linear(d_model, ffn_dim)
        self.w2 = nn.Linear(ffn_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # position-wise FFN: Linear -> ReLU -> Linear
        output = self.w2(F.relu(self.w1(x)))
        # residual connection followed by layer normalization
        return self.norm(x + self.dropout(output))
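As with the other blocks, the feed-forward layer keeps the (batch, seq_len, d_model) shape, e.g.:

ffn = PositionalWiseFeedForward(d_model=512, ffn_dim=2048)
print(ffn(torch.randn(2, 5, 512)).size())  # torch.Size([2, 5, 512])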
class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(d_model, ffn_dim, dropout)

    def forward(self, x, attn_mask=None):
        # self-attention followed by the position-wise feed-forward block
        context = self.attention(x, x, x, attn_mask)
        output = self.feed_forward(context)
        return output
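One encoder layer also maps (batch, seq_len, d_model) to the same shape; for example (shapes assumed):

layer = EncoderLayer(d_model=512, num_heads=8, ffn_dim=2048)
x = torch.randn(2, 5, 512)
print(layer(x).size())                    # torch.Size([2, 5, 512])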
class Encoder(nn.Module):
    def __init__(self, vocab_size, max_seq_len, num_layers=6, d_model=512,
                 num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Encoder, self).__init__()
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, ffn_dim, dropout) for _ in range(num_layers)])
        self.pos_embedding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seq_embedding):
        # token embedding + positional encoding
        embedding = seq_embedding(x)
        output = self.pos_embedding(embedding)
        # padding_mask is defined further below
        self_attention_mask = padding_mask(x, x)
        for encoder in self.encoder_layers:
            output = encoder(output, self_attention_mask)
        return self.norm(output)
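With `padding_mask` (defined further below) in scope, the encoder can be exercised end to end; the vocabulary size and the toy batch here are my own choices:

vocab_size, max_seq_len = 1000, 50
encoder = Encoder(vocab_size, max_seq_len, num_layers=2)
embedding = nn.Embedding(vocab_size, 512, padding_idx=0)
src = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 0, 0, 0]])
print(encoder(src, embedding).size())     # torch.Size([2, 5, 512])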
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()
        # note: the same attention module is reused for both the masked self-attention
        # and the encoder-decoder attention (the original paper uses separate weights)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(d_model, ffn_dim, dropout)

    def forward(self, dec_inputs, enc_outputs, self_attn_mask=None, context_attn_mask=None):
        # masked self-attention over the decoder inputs
        dec_output = self.attention(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)
        # encoder-decoder attention: keys/values come from the encoder, queries from the decoder
        dec_output = self.attention(enc_outputs, enc_outputs, dec_output, context_attn_mask)
        dec_output = self.feed_forward(dec_output)
        return dec_output


class Decoder(nn.Module):
    def __init__(self, vocab_size, max_seq_len, device, num_layers=6, d_model=512,
                 num_heads=8, ffn_dim=2048, dropout=0.0):
        super(Decoder, self).__init__()
        self.device = device
        self.num_layers = num_layers
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, ffn_dim, dropout) for _ in range(num_layers)])
        self.seq_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_embedding = PositionalEncoding(d_model, max_seq_len)
        self.linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, inputs, enc_output, seq_embedding, context_attn_mask=None):
        # the embedding passed in as an argument is used (self.seq_embedding stays unused)
        embedding = seq_embedding(inputs)
        # PositionalEncoding already adds the embedding to pe internally
        output = self.pos_embedding(embedding)
        # combine padding mask and causal mask: a position is hidden if either mask hides it
        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs).to(self.device)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)
        for decoder in self.decoder_layers:
            output = decoder(output, enc_output, self_attn_mask, context_attn_mask)
        # project back to vocabulary logits
        output = self.linear(output)
        return output
class Transformer(nn.Module):
    def __init__(self, vocab_size, max_len, device, num_layers=6, stack_layers=6,
                 d_model=512, num_heads=8, ffn_dim=2048, dropout=0.2):
        super(Transformer, self).__init__()
        self.device = device
        self.encoder = Encoder(vocab_size, max_len, num_layers, d_model, num_heads, ffn_dim, dropout)
        self.decoder = Decoder(vocab_size, max_len, device, num_layers, d_model, num_heads, ffn_dim, dropout)
        # encoder and decoder share one embedding table
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, src_seq, dec_tgt, dec_in):
        # note: stack_layers, dec_in and self.linear are defined but not used in forward
        # padding mask for the decoder's encoder-decoder attention
        context_attn_mask_dec = padding_mask(dec_tgt, src_seq)
        en_output = self.encoder(src_seq, self.embedding)
        dec_output = self.decoder(dec_tgt, en_output, self.embedding, context_attn_mask_dec)
        return dec_output
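Putting it together, here is a minimal smoke test. It assumes the `padding_mask` and `sequence_mask` helpers defined below are already in scope, and the toy vocabulary and batches are invented for illustration:

device = torch.device("cpu")
model = Transformer(vocab_size=1000, max_len=50, device=device, num_layers=2)
src = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 0, 0, 0]])
tgt = torch.tensor([[1, 5, 6, 0, 0],
                    [1, 8, 0, 0, 0]])
out = model(src, tgt, tgt)                # the third argument (dec_in) is unused
print(out.size())                         # torch.Size([2, 5, 1000])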
def padding_mask(seq_k, seq_q):
    # mask key positions that are padding (token id 0); result shape: (batch, len_q, len_k)
    len_q = seq_q.size(1)
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)
    return pad_mask
inputs = torch.tensor([[1, 2, 3, 0, 0, 0, 0, 0],
                       [3, 4, 0, 0, 0, 0, 0, 0],
                       [3, 0, 0, 0, 0, 0, 0, 0],
                       [4, 5, 6, 7, 0, 0, 0, 0]])
padding_mask(inputs, inputs)
tensor([[[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, True]],
[[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True],
[False, False, True, True, True, True, True, True]],
[[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, True]],
[[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True]]])
def sequence_mask(seq):
    # upper-triangular (causal) mask that hides future positions; result shape: (batch, len, len)
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
tensor([[[0, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
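For completeness, the decoder combines the two masks above exactly like this: a position is masked if either the padding mask or the causal mask hides it (same toy `inputs` as before):

self_attn_mask = torch.gt(padding_mask(inputs, inputs) + sequence_mask(inputs), 0)
print(self_attn_mask[0])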