논문 읽기/NLP

[논문 읽기] PyTorch 코드로 살펴보는 Transformer(2017)

AI 꿈나무 2021. 6. 28. 15:25
반응형

 안녕하세요, 오늘 읽은 논문은 transformer을 제안하는 Attention is All You Need 입니다.

 

 논문에서는 self-attention 으로 구성된 encoder와 decoder을 사용합니다. self-attention 작동 방법이 이해가 되질 않아서 2~3일 동안 공부했네요..ㅎㅎ 저는 이해가 잘 안될때는 코드와 함께 보는 편입니다. 코드를 보면서 어떤 과정으로 데이터가 처리되는지 확인하면 이해가 더 잘되는 것 같아요. 자연어처리 모델은 익숙하지가 않아서 아래 깃허브를 참고하면서 공부했어요. 구현 코드와 논문 설명이 함께 되어 있어서, 큰 도움이 되었습니다 ㅎㅎ

 

 

bentrevett/pytorch-seq2seq

Tutorials on implementing a few sequence-to-sequence (seq2seq) models with PyTorch and TorchText. - bentrevett/pytorch-seq2seq

github.com

 

 transformer가 어떻게 이루어져 있고 어떤 방식으로 데이터를 처리하는지 살펴보도록 하겠습니다.

 

출처: https://arxiv.org/pdf/2101.01169.pdf

 

transformer 전체 구조

 

 위 구조는 transformer의 전체 구조입니다. 보시면 ConvS2S 또는 LSTM 모듈을 사용하지 않고 attention mechanisms을 최대한 활용합니다.

 

transformer 장점

 transformer 장점은 input sequence elements간에 long depecdency를 포착할 수 있고, parallel preocessing이 가능합니다. 또한 similar preocessing block을 사용하므로 very large capacity network 또는 big dataset으로 쉽게 확장할 수 있습니다. transformer에서 사용하는 self-attention은 sequence 요소들 사이에 관계(long-term information or dependency)를 학습할 수 있습니다. 

 

Encoder

 

 encoder은 input/source sentenc를 context vector로 encoding 합니다.

 

 transformer에서 encoder은 전체 source sentence, X=(x1, ... , xn)을 입력 받아 context vector Z = (z1, ... , zn)을 출력합니다. 만약 입력 sentence가 5개의 토큰으로 이루어져 있다면, context vector도 5개의 벡터를 출력합니다. 그리고 각 Context vector은 모든 입력 sentence 정보를 갖고 있습니다. hidden state를 사용하는 이전 모델은 입력 sentence를 순서대로 입력받아 한번에 하나씩 hiiden state를 계산했지만, transformer의 encoder은 한번에 입력 sentence를 입력받아 한번에 context vector를 계산합니다.

 

(1) Encoder 작동 순서

 

 데이터가 어떻게 처리되는지 살펴보겠습니다.

 

(1) token들이 embedding layer에 전달되어 emb_dim으로 차원이 변경됩니다. 그리고 sqrt(hid_dim)을 곱하여 scale 합니다. scale을 하는 이유는 embedding내에 variance를 감소시킬 수 있고, 모델이 안정적으로 학습하도록 합니다.

 

(2) 각 token의 position 정보가 있는 seqeunce를 positional embedding layer에 전달합니다.

 

(3) 두 embedding을 element-wise sum으로 결합합니다. 따라서 결합된 벡터는 모든 토큰 정보와 토큰의 순서 정보를 포함합니다. 이제 결합된 embedding을 multi-head attention 모듈에 전달할 것입니다.

 

 

(4) 결합된 embedding을 query, key, value로써 사용합니다. 즉, 동일한 embedding을 3개의 fc layer에 전달하여 query, key, value를 생성합니다.  

 

(5) query, key, value의 차원 hid_dim을 n_head로 나누어 head_dim 차원을 갖는 n_head 개의 query, key, value로 분할합니다.

 

 

(6) 쿼리와 키를 행렬곱 한후에 sqrt(head_dim)으로 나눕니다. 나누는 이유는 행렬 곱이 너무 커져서 sigmoid를 통과하여 gradient가 작아지는 현상을 예방합니다. 행렬곱 결과를 energy라고 합니다. 이 energy를 mask로 filter 합니다. mask는 source sentence에서 pad인 부분은 0, pad가 아닌 부분은 1로 채워져 있습니다. mask의 역할은 pad 정보를 무시하기 위함입니다.

 

(7) 쿼리와 키를 행렬곱 하고 sqrt(head_dim)으로 나눈 값을 softmax에 전달합니다. 그리고 value와 행렬 곱 연산을 적용합니다.

 

 

(8) n_head로 분할된 것을 하나로 concat 합니다. 그리고 fc layer로 전달합니다.

 

(9) fc layer 출력값에 dropout을 적용하고, 입력 embedding과 residual connection으로 연결한 후에 layer normalization을 적용합니다. 결과값을 PositionwiseFeedforwardLayer로 전달합니다. 

 

(10) Positionwise Feed Forward layer는 2개의 fc layer로 구성되어 있으며, hid_dim -> pf dim -> hid_dim 순서로 차원이 변경됩니다. pf dim은 hid dim보다 큰 값을 사용합니다.

 

(11) positionwise feed forward layer의 결과값에 dropout을 적용하고, 입력값과 residual connection으로 연결합니다. 그리고 layer norm을 적용합니다. 이 출력값을 decoder에서 활용합니다. 

 

(2) Encoder PyTorch Code

 Encoder 파이토치 코드를 살펴보겠습니다. 코드는 https://github.com/bentrevett/pytorch-seq2seq 를 참고했습니다.

 

 encoder은 embedding, encoder layer로 이루어져있습니다.

class Encoder(nn.Module):
    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        
        #pos = [batch size, src len]
        
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        #src = [batch size, src len, hid dim]
            
        return src

 

encoder layer는 MultiHeadAttention과 Position wise Feed forward layer로 이루어져 있습니다.

class EncoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout, 
                 device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 
                                                                     pf_dim, 
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len] 
                
        #self attention
        _src, _ = self.self_attention(src, src, src, src_mask)
        
        #dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        #positionwise feedforward
        _src = self.positionwise_feedforward(src)
        
        #dropout, residual and layer norm
        src = self.ff_layer_norm(src + self.dropout(_src))
        
        #src = [batch size, src len, hid dim]
        
        return src

 

Multi head attention layer 입니다.

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
    def forward(self, query, key, value, mask = None):
        
        batch_size = query.shape[0]
        
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]
                
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]
                
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
                
        #attention = [batch size, n heads, query len, key len]
                
        x = torch.matmul(self.dropout(attention), V)
        
        #x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        
        #x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        
        #x = [batch size, query len, hid dim]
        
        return x, attention

 

 position-wise feedforward layer 입니다.

class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [batch size, seq len, hid dim]
        
        x = self.dropout(torch.relu(self.fc_1(x)))
        
        #x = [batch size, seq len, pf dim]
        
        x = self.fc_2(x)
        
        #x = [batch size, seq len, hid dim]
        
        return x

 

Decoder

 decoder 의 목적은 context vector을 입력 받아 decode 하여 target sentence를 출력하는 것입니다.

 

 

 decoder은 encoder와 비슷하지만 두 개의 multi-head attention을 갖는 것이 차이점 입니다.

 

 첫 번째 multi head attention은 decoder representation(target sequence, target mask)을 사용하여 self-attention을 계산합니다. 작동 방식은 encoder와 동일하며 target mask는 target token을 순서대로 사용하도록하며, pad인 부분을 무시하도록 구현되어 있습니다.

 

 

 pad mask와 결합되기 전의 target mask입니다. 첫 번째 행은 첫 번째 token만 입력으로 전달됩니다. 이처럼 순서대로 token 정보를 활용하도록 구현되어있습니다. decoder가 target sequence의 이전 정보만을 활용해서 다음 예측값을 출력하도록 하는 역할을 합니다.

 

 

 pad가 적용된 target mask 입니다.

 

 두 번째 multi-head attention layer는 decoder representation을 쿼리로 사용하고, encoder representation을 키와 벨류로 사용합니다. 또한 target mask와 source mask도 함께 사용합니다. 작동 방법은 encoder와 동일하게 작동되며 그 과정에서 drop out, residual connection, layer norm이 사용됩니다.

 

 두 개의 multi-head attention layer를 거친 후에, position wise feed forward layer에 전달하고, 출력 값을 fc layer로 전달하여 예측값을 계산합니다.

 

(1) decoder pytorch code

 pytorch 코드와 함께 살펴보겠습니다.

class Decoder(nn.Module):
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
                
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
                            
        #pos = [batch size, trg len]
            
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
                
        #trg = [batch size, trg len, hid dim]
        
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        
        #output = [batch size, trg len, output dim]
            
        return output, attention

 

class DecoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device):
        super().__init__()
        
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 
                                                                     pf_dim, 
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        #self attention
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        
        #dropout, residual connection and layer norm
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
            
        #trg = [batch size, trg len, hid dim]
            
        #encoder attention
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        
        #dropout, residual connection and layer norm
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
                    
        #trg = [batch size, trg len, hid dim]
        
        #positionwise feedforward
        _trg = self.positionwise_feedforward(trg)
        
        #dropout, residual and layer norm
        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        
        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return trg, attention

 

Seq2Seq

 이제 encoder와 decoder을 seq2seq로 연결해주면 transformer가 완성됩니다.

class Seq2Seq(nn.Module):
    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        
        #src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
    def make_trg_mask(self, trg):
        
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        
        #trg_mask = [batch size, 1, trg len, trg len]
        
        return trg_mask

    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
                
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src, src_mask)
        
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return output, attention

 

참고자료

[1] https://arxiv.org/abs/1706.03762

[2] https://github.com/bentrevett/pytorch-seq2seq

[3] https://arxiv.org/pdf/2101.01169.pdf

반응형