논문 읽기/NLP

[논문 읽기] PyTorch 구현 코드로 살펴보는 Attention(2015), Neural Machine Translation by jointly Learning to Align and Translate

AI 꿈나무 2021. 6. 19. 14:48
반응형

 안녕하세요, 오늘 읽은 논문은 Neural Machine Translation by jointly Learning to Align and Translate 입니다. 

 

 해당 논문은 Seq2Seq 구조에서 Attention 매커니즘과 양방향 RNN(bidirectional RNN)을 제안합니다.

 

 

 Seq2Seq 구조는 Encoder와 decoder로 구성됩니다. encoder의 역할은 source sentence를 입력 받아, 고정된 벡터 크기로 반환합니다. 저자는 이 고정된 길이의 벡터가 긴 문장을 번역하는데 문제점으로 작용한다고 합니다. 따라서 decode가 어떤 source sentence에 집중해야 하는지 결정하도록 합니다. decoder를 attention 매커니즘으로 작동하게 함으로써, encoder는 source sentence의 모든 정보를 고정된 길이의 벡터로 encode해야하는 부담감을 덜어줍니다. 또한 다음 target 단어의 생성과 관련있는 정보에만 집중하도록 합니다.

 

 양방향 RNN은 두 개의 RNN을 사용합니다 하나의 RNN은 input sequence를 순서대로 읽고, forward hidden states의 순서를 계산합니다. 역방향 RNN은 역방향으로 sequence를 읽고, backward hidden states를 계산합니다. 이 둘의 출력 hidden state를 concat으로 연결합니다. 따라서 concat된 hidden state는 단어 x 주변의 단어에 집중할 수 있게 됩니다.

 

Encoder

 Encoder은 bidirectional RNN을 사용합니다. 

 

출처: https://github.com/bentrevett/pytorch-seq2seq

 

 2개의 RNN은 정방향, 역방향으로 source sequence를 읽어, 정방향, 역방향 hidden states를 계산합니다.

 

 

 계산된 hidden states는 concat되고, fc layer를 거친 후에, tanh 활성화 함수에 전달되어 context vector를 생성합니다.

 

 

 그리고 이 context vector은 decoder의 initial hidden states로 입력됩니다.

 

 pytorch 구현 코드를 살펴보겠습니다. 코드 출처는 https://github.com/bentrevett/pytorch-seq2seq 입니다.

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        # bidirectional=True로 설정하여 bi-rnn을 구현합니다.
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        
        # 양방향 rnn의 출력값을 concat 한 후에 fc layer에 전달합니다.
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        
        #src = [src len, batch size]
        
        # 입력 x를 임베딩
        embedded = self.dropout(self.embedding(src))
        
        #embedded = [src len, batch size, emb dim]
        
        # rnn의 출력값
        outputs, hidden = self.rnn(embedded)
                
        #outputs = [src len, batch size, hid dim * num directions]
        #hidden = [n layers * num directions, batch size, hid dim]
        
        #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        #outputs are always from the last layer
        
        #hidden [-2, :, : ] is the last of the forwards RNN 
        #hidden [-1, :, : ] is the last of the backwards RNN
        
        #initial decoder hidden is final hidden state of the forwards and backwards 
        #  encoder RNNs fed through a linear layer
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        
        #outputs = [src len, batch size, enc hid dim * 2]
        #hidden = [batch size, dec hid dim]
        
        return outputs, hidden

 

Attention

 

 이제 attention layer를 구현합니다. attention layer는 decoder의 이전 hidden state인 s_(t-1)와 encoder의 hiddenstate H를 입력받아 attention vector at를 출력합니다. at의 길이는 source sentence와 동일하여 각 요소는 0과 1이고, 전체 합은 1입니다. 즉, attention vector at를 생산하기 위해 decoder의 hidden state s_(t-1)와 encode의 hidden state H를 입력 받는 것입니다.

 

 우선 energy를 계산합니다. energy는 linear layer와 tanh 활성화 함수로 계산합니다.

 

 

 이를 source sentence 길이로 바꾸기 위해서 energy를 v 텐서로 곱합니다.

 

 

이 v 텐서의 파라미터는 무작위로 초기화되어있고, 역전파로 학습이 됩니다. 이 v 텐서는 모든 encoder hidden state의 가중 합에 대한 가중치로 생각해볼 수 있습니다. v는 time에 의존적이지 않습니다. decoding에서 매 time step에서 동일한 v를 사용합니다. 이 v는 편향이 없는 linear layer로 구현합니다.

 

 그리고 softmax 함수를 거쳐서 0~1 값을 갖게 합니다.

 

 

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def forward(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        #repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
        
        #energy = [batch size, src len, dec hid dim]

        attention = self.v(energy).squeeze(2)
        
        #attention= [batch size, src len]
        
        return F.softmax(attention, dim=1)

 

Decoder

 

 decoder은 위에서 정의한 attention class를 포함합니다. attention layer는 이전 hidden state $s_{t-1}$와 encoder의 모든 hidden state H를 입력 받아 attention vector $a_t$를 반환합니다.

 

 이 attention vector를 weighted vector $w_t$를 생성하기 위해 사용합니다. attention vector $a_t$와 encoder의 hidden state H를 사용하여 계산합니다.

 

 

 이 weighted vector와 이전 예측값에 임베딩을 적용한 d($y_t$), 이전 hidden state $s_{t-1}$를 GRU에 전달하여 다음 hidden state $s_t$를 계산합니다. 여기서, embedding d($y_t$)와 $w_t$가 concat된 후 GRU에 전달됩니다.

 

 

 예측값은 d($y_t$), $w_t$, $s_t$를 fc layer에 전달하여 계산합니다.

 

 

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        # embedding과 weighted vector가 concat 된 후, 이전 hidden staet와 함께 입력
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        
        # 입력값 d(y_t), w_t, s_t
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, encoder_outputs):
             
        #input = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        
        input = input.unsqueeze(0)
        
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input))
        
        #embedded = [1, batch size, emb dim]
        
        a = self.attention(hidden, encoder_outputs)
                
        #a = [batch size, src len]
        
        a = a.unsqueeze(1)
        
        #a = [batch size, 1, src len]
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        weighted = torch.bmm(a, encoder_outputs)
        
        #weighted = [batch size, 1, enc hid dim * 2]
        
        weighted = weighted.permute(1, 0, 2)
        
        #weighted = [1, batch size, enc hid dim * 2]
        
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        
        #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]
            
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        
        #output = [seq len, batch size, dec hid dim * n directions]
        #hidden = [n layers * n directions, batch size, dec hid dim]
        
        #seq len, n layers and n directions will always be 1 in this decoder, therefore:
        #output = [1, batch size, dec hid dim]
        #hidden = [1, batch size, dec hid dim]
        #this also means that output == hidden
        assert (output == hidden).all()
        
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
        
        #prediction = [batch size, output dim]
        
        return prediction, hidden.squeeze(0)

 

Seq2Seq

 위에서 정의한 encoder, decoder class를 사용하여 Seq2Seq를 정의합니다.

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        
        #src = [src len, batch size]
        #trg = [trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)
                
        #first input to the decoder is the <sos> tokens
        input = trg[0,:]
        
        for t in range(1, trg_len):
            
            #insert input token embedding, previous hidden state and all encoder hidden states
            #receive output tensor (predictions) and new hidden state
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            
            #place predictions in a tensor holding predictions for each token
            outputs[t] = output
            
            #decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            
            #get the highest predicted token from our predictions
            top1 = output.argmax(1) 
            
            #if teacher forcing, use actual next token as next input
            #if not, use predicted token
            input = trg[t] if teacher_force else top1

        return outputs

참고자료

[1] https://github.com/bentrevett/pytorch-seq2seq

[2] https://arxiv.org/pdf/1409.0473.pdf

반응형