논문 읽기/NLP

[논문 읽기] PyTorch 구현 코드로 살펴보는 Seq2Seq(2014), Sequence to Sequence Learning with Neural Networks

AI 꿈나무 2021. 6. 15. 13:40
반응형

 안녕하세요, 오늘 읽은 논문은 Seq2Seq, Sequence to Sequence Learning with Neural Networks 입니다. ㅎㅎ 자연어 처리 분야 논문은 처음 읽어보네요!! Seq2Seq부터 transformer까지 차근차근 읽어나갈 생각입니다.

 

 Seq2Seq를 구현하고, 불러온 데이터셋으로 학습까지 해보는 과정은 아래 포스팅에서 확인하실 수 있습니다.

 

[논문 구현] PyTorch로 Seq2Seq(2014) 구현하고 학습하기

 안녕하세요, 이번 포스팅에서는 PyTorch로 Seq2Seq를 구현하고 학습해보도록 하겠습니다. 작업 환경은 Google Colab에서 진행했습니다.  PyTorch 코드는 아래 깃허브에서 참고했습니다. bentrevett/pytorch-se

deep-learning-study.tistory.com

 

 Seq2Seq는 두 개의 LSTM을 사용합니다. 하나의 LSTM은 input sequence를 고정된 길이의 벡터로 encode 합니다. 다른 LSTM은 encode된 고정된 길이의 vector를 decode하여 target sequence를 생성합니다. 즉, Seq2Seq는 시계열 데이터에 특화되어 있습니다. 영어 문장을 입력받아 독일어로 번역을 할 수 있으며, 이외에도 여러가지 시계열 task에 활용할 수 있습니다. 또한 해당 논문에서는 reversing the order of the words 방법과 teacher forcing 방법을 제안합니다.

 

 Seq2Seq

 Seq2Seq는 왜 LSTM을 사용할까요? DNN도 sequence를 encode 할 수 있습니다. 하지만 decode 과정에서 문장의 길이가 알려지지 않아, 한계점이 존재합니다. decode과정을 입력을 읽고 한번에 하나의 time step을 출력하는 LSTM을 사용하여 한계점을 극복합니다. 또한 LSTM은 long range temporal dependency를 학습하는 능력이 있으므로, LSTM을 선택하는 것은 자연스러운 선택이었다고 말합니다.

 

 

 위 그림은 Seq2Seq입니다. 그림을 보면, 첫 번째 LSTM이 ABC를 순차적으로 입력받아 encode 한 후에, 두 번째 LSTM이 encode된 vector를 입력받아, decode하여 WXYZ을 출력합니다.

 

Seq2Seq PyTorch 코드

 Seq2Seq는 PyTorch로 다음과 같이 구현할 수 있습니다. 코드는 https://github.com/bentrevett/pytorch-seq2seq 에서 참고했습니다.

 

(1) Encoder

# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers

        # embedding: 입력값을 emd_dim 벡터로 변경
        self.embedding = nn.Embedding(input_dim, emb_dim)

        # embedding을 입력받아 hid_dim 크기의 hidden state, cell 출력
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # sre: [src_len, batch_size]

        embedded = self.dropout(self.embedding(src))

        # initial hidden state는 zero tensor
        outputs, (hidden, cell) = self.rnn(embedded)

        # output: [src_len, batch_size, hid dim * n directions]
        # hidden: [n layers * n directions, batch_size, hid dim]
        # cell: [n layers * n directions, batch_size, hid dim]

        return hidden, cell

 

(2) decoder

# decoder
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        # content vector를 입력받아 emb_dim 출력
        self.embedding = nn.Embedding(output_dim, emb_dim)

        # embedding을 입력받아 hid_dim 크기의 hidden state, cell 출력
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # input: [batch_size]
        # hidden: [n layers * n directions, batch_size, hid dim]
        # cell: [n layers * n directions, batch_size, hid dim]

        input = input.unsqueeze(0) # input: [1, batch_size], 첫번째 input은 <SOS>

        embedded = self.dropout(self.embedding(input)) # [1, batch_size, emd dim]

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # output: [seq len, batch_size, hid dim * n directions]
        # hidden: [n layers * n directions, batch size, hid dim]
        # cell: [n layers * n directions, batch size, hid dim]

        prediction = self.fc_out(output.squeeze(0)) # [batch size, output dim]
        
        return prediction, hidden, cell

 

(3) Seq2Seq

# Seq2Seq
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        # encoder와 decoder의 hid_dim이 일치하지 않는 경우 에러메세지
        assert encoder.hid_dim == decoder.hid_dim, \
            'Hidden dimensions of encoder decoder must be equal'
        # encoder와 decoder의 hid_dim이 일치하지 않는 경우 에러메세지
        assert encoder.n_layers == decoder.n_layers, \
            'Encoder and decoder must have equal number of layers'

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src: [src len, batch size]
        # trg: [trg len, batch size]
        
        batch_size = trg.shape[1]
        trg_len = trg.shape[0] # 타겟 토큰 길이 얻기
        trg_vocab_size = self.decoder.output_dim # context vector의 차원

        # decoder의 output을 저장하기 위한 tensor
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        # initial hidden state
        hidden, cell = self.encoder(src)

        # 첫 번째 입력값 <sos> 토큰
        input = trg[0,:]

        for t in range(1,trg_len): # <eos> 제외하고 trg_len-1 만큼 반복
            output, hidden, cell = self.decoder(input, hidden, cell)

            # prediction 저장
            outputs[t] = output

            # teacher forcing을 사용할지, 말지 결정
            teacher_force = random.random() < teacher_forcing_ratio

            # 가장 높은 확률을 갖은 값 얻기
            top1 = output.argmax(1)

            # teacher forcing의 경우에 다음 lstm에 target token 입력
            input = trg[t] if teacher_force else top1

        return outputs

 

Reversing the Source Sentences

 논문에서 입력 문장의 단어의 순서를 뒤집어 성능을 향상시킵니다. 예를 들어, A,B,C,<EOS>를 입력하는 것이 아니라 <EOS>,C,B,A 순서로 입력하는 것입니다. 이 방법으로 test perplexity를 5.8에서 4.7로 감소시킵니다. 왜 입력 문장 순서를 뒤집는 것이 잘 작동할까요?? 논문에서는 처음 몇개의 target word가 그에 해당하는 source word와 가깝게 있으므로 기울기 전달이 효과적으로 되기 때문이라고 설명합니다. 나름 일리가 있는 말인 것 같네요 ㅎㅎ

 

 pytorch 구현 방법은 token을 생성할때, [::-1]로 뒤집으면 됩니다 ㅎㅎ

# tokenizer function 생성
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)][::-1]

Teacher forcing

 teacher forcing은 논문에서는 다루지 않지만, 구현 코드에서는 다루고 있어 내용을 포함했습니다. teacher forcing은 decode에서 다음 입력값을 일정 확률로 prediction 단어가 아닌 target 단어를 입력하는 것입니다. teacher forcing을 사용하면 학습 초기에 안정적입니다. 학습 초기에는 틀린 단어를 예측할 확률이 높은데, 틀린 단어를 다음 lstm cell에 입력하기 보다 이를 교정해서 target 단어를 다음 lstm cell에 입력해준다면, 학습 초기에 안정적인 학습을 기대할 수 있습니다.

 


참고자료

[1] https://github.com/bentrevett/pytorch-seq2seq

[2] https://arxiv.org/abs/1409.3215

반응형