논문 읽기/NLP

[논문 읽기] PyTorch 코드로 살펴보는 Convolutional Sequence to Sequence Learning(2017)

AI 꿈나무 2021. 6. 24. 23:53
반응형

 안녕하세요, 오늘 읽은 논문은 Convolutional Sequence to Sequence Learning 입니다.

 

 해당 논문은 rnn이 아닌 convolution 연산으로 번역을 수행합니다. seq2seq 구조인 encoder, decoder 구조를 갖고 있으며, 두 구조 모두 conv layer로 이루어져 있습니다. 이미지 task에서 사용하는 conv 연산을 어떻게 문자 task에서 활용할 수 있을 까요??

 

 1d convolution 연산은 filter 크기 만큼 1차원 벡터의 정보를 취합합니다. kernel_size=3인 1d conv 연산을 수행하면 3개의 입력값을 받아 filter 가중치를 거쳐서 1개의 값을 출력합니다. 입력값에 순서대로 conv filter를 적용하면, receptive field 영역만큼의 정보를 계산하여 출력값을 생성합니다. 여기에 하나의 conv layer를 추가한다면 receptive field는 확장되어 더 많은 sequence 정보를 활용할 수 있습니다. 또한 conv 연산에 사용하는 filter의 수 만큼 input sequence로부터 서로 다른 특징을 추출할 수 있습니다.

 

 rnn대신에 conv layer를 사용할때의 이점은 무엇일까요?

 

 conv layer를 쌓을수록 receptive field는 확장됩니다. 이는 입력벡터에서 정보를 사용하는 length를 자유롭게 조절할 수 있습니다. 또한 conv는 이전 time step 정보를 활용하지 않습니다. 이미 receptive field에 이전 정보도 포함되어 있기 때문입니다. 이덕분에 parallel computation을 사용할 수 있습니다.

 

 multi-layer cnn은 계층 representation을 계산합니다. 인접 입력 요소는 낮은 계층에서 상호작용하고, 멀리 떨어진 입력 요소는 높은 계층에서 상호작용합니다. 이 덕분에 long-range 정보도 포착할 수 있습니다.

 

 해당 논문에서는 conv layer를 문자 번역 task에 효과적으로 적용하기 이해서 attention mechanism, GLU, residual connection을 사용합니다. encoder와 decoder을 pytorch 구현 코드와 함께 살펴보면서 이해해보도록 하겠습니다.

 

Encoder

해당 논문에서의 encoder은 source를 입력받아 combined, conved output을 출력하여 decoder로 전달합니다. conved output은 embedding vector가 conv 연산을 거친 값입니다. combiend output은 conved output와 embedding vector 사이에 element-wise sum 연산을 수행한 것입니다.

 

출처: https://github.com/bentrevett/pytorch-seq2seq

 

(1) embedding

 token과 position vector가 각각 embedding layer를 거쳐서 embedding vector를 생성합니다. token은 기존의 NLP에서 사용하는 token이며, position vector는 token의 absolute position 정보를 담은 vector입니다. 첫 번째 token의 position은 0, 두 번째 token의 position은 1과 같이 position이 할당됩니다.

 

 token embedding과 position embedding을 elementwise sum하여 결합합니다. 결합된 embedding은 linear layer에 전달하여 차원 수를 변경한 뒤에 conv layer로 전달합니다. 또한 residual connection을 거쳐서 conved output과 element-wise sum 연산을 수행합니다.

 

(2) Convolutional Block Structure

출처: https://github.com/bentrevett/pytorch-seq2seq

 

 linear layer를 거쳐서 [batch, hid_dim,src_len] 채널을 갖은 입력값을 1d conv layer에 전달합니다.

 

 

1d conv layer 출력값은 [batch, 2*hid_dim, src_len]차원을 갖으며 이를 GLU activation function에 입력합니다. GLU는 일종의 attention 역할을 수행합니다. GLU를 거치면 [batch, hid_dim, scr_len] 차원을 갖게 됩니다.

 

 glu를 거친후에 채널수가 1/2되었습니다. glu는 2*hid_dim 채널을 가진 입력값에서 반은 X1*W1 + b 연산을 수행하고 반은 sigmoid(X2*W2 + c) 연산을 수행합니다. 이 둘을 element-wise product를 수행하여 [batch, hid_dim, src_len] 차원을 갖는 출력값을 생성합니다.

 

 그리고 residual connection에의해 layer의 입력값과 더해집니다. 이 과정을 conv layer 수 만큼 반복하여 출력값을 계산합니다.

 

 conv block 이후에도 residual connection이 존재합니다. residual connection을 거치기 전의 벡터를 conved output이라고 부르며, residual connection을 거친 후의 벡터를 combined output이라고 합니다. 이 둘을 decoder로 전달합니다.

 

(3) encoder PyTorch code

 파이토치 코드를 한번 살펴보겠습니다. 코드는 https://github.com/bentrevett/pytorch-seq2seq 를 참고했습니다. 해당 깃허브 코드는 각 단계별 차원수를 명시해줘서 이해하는데 큰 도움이 됩니다 ㅎㅎ

class Encoder(nn.Module):
    def __init__(self, 
                 input_dim, 
                 emb_dim, 
                 hid_dim, 
                 n_layers, 
                 kernel_size, 
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()
        
        assert kernel_size % 2 == 1, "Kernel size must be odd!"
        
        self.device = device
        
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
        
        self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        
        self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim, 
                                              out_channels = 2 * hid_dim, 
                                              kernel_size = kernel_size, 
                                              padding = (kernel_size - 1) // 2)
                                    for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        
        #src = [batch size, src len]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        #create position tensor
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        
        #pos = [0, 1, 2, 3, ..., src len - 1]
        
        #pos = [batch size, src len]
        
        #embed tokens and positions
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(pos)
        
        #tok_embedded = pos_embedded = [batch size, src len, emb dim]
        
        #combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        
        #embedded = [batch size, src len, emb dim]
        
        #pass embedded through linear layer to convert from emb dim to hid dim
        conv_input = self.embdid(embedded)
        
        #conv_input = [batch size, src len, hid dim]
        
        #permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1) 
        
        #conv_input = [batch size, hid dim, src len]
        
        #begin convolutional blocks...
        
        for i, conv in enumerate(self.convs):
        
            #pass through convolutional layer
            conved = conv(self.dropout(conv_input))

            #conved = [batch size, 2 * hid dim, src len]

            #pass through GLU activation function
            conved = F.glu(conved, dim = 1)

            #conved = [batch size, hid dim, src len]
            
            #apply residual connection
            conved = (conved + conv_input) * self.scale

            #conved = [batch size, hid dim, src len]
            
            #set conv_input to conved for next loop iteration
            conv_input = conved
        
        #...end convolutional blocks
        
        #permute and convert back to emb dim
        conved = self.hid2emb(conved.permute(0, 2, 1))
        
        
        #conved = [batch size, src len, emb dim]
        
        #elementwise sum output (conved) and input (embedded) to be used for attention
        combined = (conved + embedded) * self.scale
        
        #combined = [batch size, src len, emb dim]
        
        return conved, combined

 

Decoder

https://github.com/bentrevett/pytorch-seq2seq

 

 decoder와 encoder의 가장 큰 차이점은 conv block 이후에 residual connection 유무입니다. decoder은 element-wise sum된 embedding이 residual connection으로 전달되지 않지만, conv block내의 residual connection에서 사용됩니다. 구조가 복잡하군요 ㅎㅎ

 

(1) conv block

 decoder 내의 conv block을 살펴보겠습니다.

 

 

 위 그림을 보면 GLU의 출력값에서 encoder 출력값인 conved, combined와 target embedding이 입력됩니다. 이 세개의 vector를 사용하여 attention을 계산합니다.

 

 attention 계산 과정

 (1) decoder conv layer 출려값인 decode_conved를 fc layer에 전달하여 차원을 hid_dim에서 emb_dim으로 변경합니다.

 (2) residual connection으로 target embedding과 decode_conved 를 element-wise sum 합니다. 이 벡터를 combined라고 하겠습니다.

 (3) combiend와 encoder_conved를 matmul 연산을 하여 energy를 계산합니다.

 (4) energy에 softmax연산을 수행해 0~1 값을 갖도록 합니다. 이 벡터를 attention이라고 부르겠습니다.

 (5) attention과 encoder_combined를 matmul해 attended_encoding 벡터를 생성합니다.

 (6) attended_encoding을 fc layer를 거쳐 emb_dim을 hid_dim으로 변경합니다.

 (7) 그리고 residual connection에 의해 decode_conved 벡터와 attended_encoding이 element-wise sum 연산을 수행합니다.

 

(2) decoder pytorch 코드

 decoder pytorch 코드를 살펴보겠습니다.

class Decoder(nn.Module):
    def __init__(self, 
                 output_dim, 
                 emb_dim, 
                 hid_dim, 
                 n_layers, 
                 kernel_size, 
                 dropout, 
                 trg_pad_idx, 
                 device,
                 max_length = 100):
        super().__init__()
        
        self.kernel_size = kernel_size
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
        
        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        
        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
        
        self.fc_out = nn.Linear(emb_dim, output_dim)
        
        self.convs = nn.ModuleList([nn.Conv1d(in_channels = hid_dim, 
                                              out_channels = 2 * hid_dim, 
                                              kernel_size = kernel_size)
                                    for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
      
    def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
        
        #embedded = [batch size, trg len, emb dim]
        #conved = [batch size, hid dim, trg len]
        #encoder_conved = encoder_combined = [batch size, src len, emb dim]
        
        #permute and convert back to emb dim
        conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
        
        #conved_emb = [batch size, trg len, emb dim]
        
        combined = (conved_emb + embedded) * self.scale
        
        #combined = [batch size, trg len, emb dim]
                
        energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
        
        #energy = [batch size, trg len, src len]
        
        attention = F.softmax(energy, dim=2)
        
        #attention = [batch size, trg len, src len]
            
        attended_encoding = torch.matmul(attention, encoder_combined)
        
        #attended_encoding = [batch size, trg len, emd dim]
        
        #convert from emb dim -> hid dim
        attended_encoding = self.attn_emb2hid(attended_encoding)
        
        #attended_encoding = [batch size, trg len, hid dim]
        
        #apply residual connection
        attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
        
        #attended_combined = [batch size, hid dim, trg len]
        
        return attention, attended_combined
        
    def forward(self, trg, encoder_conved, encoder_combined):
        
        #trg = [batch size, trg len]
        #encoder_conved = encoder_combined = [batch size, src len, emb dim]
                
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
            
        #create position tensor
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        
        #pos = [batch size, trg len]
        
        #embed tokens and positions
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)
        
        #tok_embedded = [batch size, trg len, emb dim]
        #pos_embedded = [batch size, trg len, emb dim]
        
        #combine embeddings by elementwise summing
        embedded = self.dropout(tok_embedded + pos_embedded)
        
        #embedded = [batch size, trg len, emb dim]
        
        #pass embedded through linear layer to go through emb dim -> hid dim
        conv_input = self.emb2hid(embedded)
        
        #conv_input = [batch size, trg len, hid dim]
        
        #permute for convolutional layer
        conv_input = conv_input.permute(0, 2, 1) 
        
        #conv_input = [batch size, hid dim, trg len]
        
        batch_size = conv_input.shape[0]
        hid_dim = conv_input.shape[1]
        
        for i, conv in enumerate(self.convs):
        
            #apply dropout
            conv_input = self.dropout(conv_input)
        
            #need to pad so decoder can't "cheat"
            padding = torch.zeros(batch_size, 
                                  hid_dim, 
                                  self.kernel_size - 1).fill_(self.trg_pad_idx).to(self.device)
                
            padded_conv_input = torch.cat((padding, conv_input), dim = 2)
        
            #padded_conv_input = [batch size, hid dim, trg len + kernel size - 1]
        
            #pass through convolutional layer
            conved = conv(padded_conv_input)

            #conved = [batch size, 2 * hid dim, trg len]
            
            #pass through GLU activation function
            conved = F.glu(conved, dim = 1)

            #conved = [batch size, hid dim, trg len]
            
            #calculate attention
            attention, conved = self.calculate_attention(embedded, 
                                                         conved, 
                                                         encoder_conved, 
                                                         encoder_combined)
            
            #attention = [batch size, trg len, src len]
            
            #apply residual connection
            conved = (conved + conv_input) * self.scale
            
            #conved = [batch size, hid dim, trg len]
            
            #set conv_input to conved for next loop iteration
            conv_input = conved
            
        conved = self.hid2emb(conved.permute(0, 2, 1))
         
        #conved = [batch size, trg len, emb dim]
            
        output = self.fc_out(self.dropout(conved))
        
        #output = [batch size, trg len, output dim]
            
        return output, attention

 

Seq2Seq

 위에서 정의한 encoder, decoder class를 사용하여 seq2seq class를 정의합니다.

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len - 1] (<eos> token sliced off the end)
           
        #calculate z^u (encoder_conved) and (z^u + e) (encoder_combined)
        #encoder_conved is output from final encoder conv. block
        #encoder_combined is encoder_conved plus (elementwise) src embedding plus 
        #  positional embeddings 
        encoder_conved, encoder_combined = self.encoder(src)
            
        #encoder_conved = [batch size, src len, emb dim]
        #encoder_combined = [batch size, src len, emb dim]
        
        #calculate predictions of next words
        #output is a batch of predictions for each word in the trg sentence
        #attention a batch of attention scores across the src sentence for 
        #  each word in the trg sentence
        output, attention = self.decoder(trg, encoder_conved, encoder_combined)
        
        #output = [batch size, trg len - 1, output dim]
        #attention = [batch size, trg len - 1, src len]
        
        return output, attention

참고 자료

[1] https://leimao.github.io/blog/Gated-Linear-Units/

[2] https://github.com/bentrevett/pytorch-seq2seq

[3] https://arxiv.org/abs/1705.03122

반응형