반응형

Transformer는 사실, NLP 분야뿐만 아니라, 다양한 분야에서 많이 사용되기 때문에, 그만큼 구현 소스를 쉽게 찾을 수 있다. 나도, Transformer를 자주 사용하지만, 라이브러리에서 읽어오는 형태로 사용하기 때문에, 그 상세 구조에 대해서는 대략적으로만 알고 있다. 이번 기회에 Transformer를 pytorch로 직접 짜보면서 그 구조를 정확히 이해하고자 한다.

 

Full source : https://github.com/daehwichoi/transformer-pytorch/blob/main/model/transformer.py

 

구현 방향

  • 사실, pytorch로 Transformer를 구현한 사례는 google 검색만 해도 굉장히 많이 나온다. 하지만, original transformer를 직접 구현해보고 싶어서, 논문을 그대로 구현하는데 초점을 맞췄다. 
  • 모델 학습을 위한 layer(Dropout 등)나, dataloader는 task마다 다르고, 구현 목적이 Transformer 모델을 구현하는 것이기 때문에, 모델 구현만 진행했다.

 

참고 자료

  • Transformer 논문 내, 구조 설명 부분

2023.05.08 - [NLP 논문] - Transformer (Attention Is All You Need) - (1) 리뷰

 

Transformer (Attention Is All You Need) - (1) 리뷰

Transformer 배경 설명 Transformer는 Google Brain이 2017년 "Attention is All You Need"라는 논문에서 제안된 딥러닝 모델이다. Transformer는 기존 자연어 처리 분야에서 주로 사용되던 RNN, LSTM 같은 순환 신경망 모

devhwi.tistory.com

 

 

구조 설명

  • Transformer는 크게 Encoder 부분과 Decoder 부분, input&output embedding, postional encoding으로 나뉜다.
  • Encoder 부분은 N개의 Encoder가 연결된 구조로 구성되어 있고, Decoder도 N개의 Decoder가 연결된 구조로 구성되어 있다.
  • Ecoder는 크게, Multi-Head Attention(self-attention)과 redidual 부분(residual add  & layer norm), Feed Forward로 구성되어 있다. 
  • Decoder는 크게, Masked-Multi-Head Attention(self-attention)과 resiedidual 부분(residual add  & layer norm), Multi-Head Attention(encoder-decoder attention), Feed Forward로 구성되어 있다.

 

구현 내용 설명

(순서는 내가 구현한 순서이다.)

1. Multi-Head Attention

[Sacled Dot-Product Attention]

  • Multi-Head Attention의 핵심은 scaled_dot_product_attention이다.
  • scaled_dot_product는 Query, Key, Value가 있을 때, Query와 Key의 Transpose의 Matmul(Dot Product)를 통해, similarity를 계산하고, similarity 기반으로 Value 값을 참조한다.
  • Scaled Dot-product는 Network 여러 부분에서 사용되지만, Decoder 부분에서는 masking 처리를 해야 하는 부분이 있기 때문에, mask부분을 포함해서 함께 구현했다.

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        d_k = k.size()[-1]
        k_transpose = torch.transpose(k, 3, 2)

        output = torch.matmul(q, k_transpose)
        output = output / math.sqrt(d_k)
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(1).unsqueeze(-1), 0)

        output = F.softmax(output, -1)
        output = torch.matmul(output, v)

        return output

[Multi-Head Attention]

  • Multi-Head Attention은 scaled Dot-Product Attention을 query에 해당하는 value 값들을 참조하기 위해 사용하는데, query, key, value를 그대로 사용하는 것이 아니라, 여러 개의 head로 나누고, query, key, value를 linear projection 한 후, 사용한다. 
  • Scaled Dot-Product Attention 이후, 각 head의 value 값을 concat하고, linear layer을 거쳐 output을 낸다.
  • 주의할 점은, sequence의 순서가 중요하기 때문에, contiguous를 사용해서, 순서를 유지한다는 점이다.

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_num=512, head_num=8):
        super().__init__()
        self.head_num = head_num
        self.dim_num = dim_num

        self.query_embed = nn.Linear(dim_num, dim_num)
        self.key_embed = nn.Linear(dim_num, dim_num)
        self.value_embed = nn.Linear(dim_num, dim_num)
        self.output_embed = nn.Linear(dim_num, dim_num)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
    ...
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size()[0]

        # 순서 유지 때문에 view 후 transpose 사용
        q = self.query_embed(q).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)
        k = self.key_embed(k).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)
        v = self.value_embed(v).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)

        output = self.scaled_dot_product_attention(q, k, v, mask)
        batch_num, head_num, seq_num, hidden_num = output.size()
        output = torch.transpose(output, 1, 2).contiguous().view((batch_size, -1, hidden_num * self.head_num))

        return output

 

2. Residual Add & Layer Norm

[Layer Norm]

  • Layer Norm은 dimension layer 방향으로 평균을 빼고, 표준 편차로 나누는 Normalization 기법이다.
  • 이 부분은 nn.LayerNorm을 통해, 구현할 수 있다. 
   def layer_norm(self, input):
        mean = torch.mean(input, dim=-1, keepdim=True)
        std = torch.std(input, dim=-1, keepdim=True)
        output = (input - mean) / std
        return output

[Add & Layer Norm]

  • 이전 층의 output을 layer norm을 통해, normalization 한 후, residual 값을 더해준다.
class AddLayerNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def layer_norm(self, input):
    ...

    def forward(self, input, residual):
        return residual + self.layer_norm(input)

 

3. Feed Forward

  • Feed Forward는 Fully Connected Layer → Relu →  Fully Connected Layer로 구성되어 있다.

class FeedForward(nn.Module):
    def __init__(self, dim_num=512):
        super().__init__()
        self.layer1 = nn.Linear(dim_num, dim_num * 4)
        self.layer2 = nn.Linear(dim_num * 4, dim_num)

    def forward(self, input):
        output = self.layer1(input)
        output = self.layer2(F.relu(output))

        return output

 

4. Encoder

  • Encoder는 Multi-Head Attention → Residual Add & Layer Norm → Feed Forward → Residual Add & Layer Norm 순으로 구성되어 있다. 
  • Encoder는 단순히, 앞서 선언했던, sub layer들을 연결하는 방식으로 구현했다.

class Encoder(nn.Module):
    def __init__(self, dim_num=512):
        super().__init__()
        self.multihead = MultiHeadAttention(dim_num=dim_num)
        self.residual_layer1 = AddLayerNorm()
        self.feed_forward = FeedForward(dim_num=dim_num)
        self.residual_layer2 = AddLayerNorm()

    def forward(self, q, k, v):
        multihead_output = self.multihead(q, k, v)
        residual1_output = self.residual_layer1(multihead_output, q)
        feedforward_output = self.feed_forward(residual1_output)
        output = self.residual_layer2(feedforward_output, residual1_output)

        return output

 

5. Decoder

  • Decoder는 Masked Multi-Head Attention → Residual Add & Layer Norm → Multi-Head Attention → Residual Add & Layer Norm Feed Forward → Residual Add & Layer Norm 순으로 구성되어 있다.
  • Encoder와 마찬가지로, 앞서 구현해놓은 sub-layer를 연결하면 되지만, 중간 Multi-Head Attention은 Query와 Key를 Encoder의 Output을 사용하기 때문에, 이 점을 명시해야 한다.
  • Decoder는 Ecoder와 다르게, masking을 이용하여, mask를 인자로 받는 것도 주의해야 한다.

class Decoder(nn.Module):
    def __init__(self, dim_num=512):
        super().__init__()

        self.masked_multihead = MultiHeadAttention(dim_num=dim_num)
        self.residual_layer1 = AddLayerNorm()
        self.multihead = MultiHeadAttention(dim_num=dim_num)
        self.residual_layer2 = AddLayerNorm()
        self.feed_forward = FeedForward(dim_num=dim_num)
        self.residual_layer3 = AddLayerNorm()

    def forward(self, o_q, o_k, o_v, encoder_output, mask):
        masked_multihead_output = self.masked_multihead(o_q, o_k, o_v, mask)
        residual1_output = self.residual_layer1(masked_multihead_output, o_q)
        multihead_output = self.multihead(encoder_output, encoder_output, residual1_output, mask)
        residual2_output = self.residual_layer2(multihead_output, residual1_output)
        feedforward_output = self.feed_forward(residual2_output)
        output = self.residual_layer3(feedforward_output, residual2_output)

        return output

 

6. Transformer

  • 전체 Transformer는 Input Embedding, Positional Encoding, Output Embedding, N개의 encoder와 N개의 decoder로 구성되어 있다. 

[positional_encoding]

  • positinal encoding은 짝수번째 token과 홀수번째 token이 각기 다른 식을 따른다. 아래 식에서 i는 hidden dimension 방향의 index이고, pos는 positional 방향(몇 번째 seq인지)을 의미한다.
  • positional encoding은 크게, 두 부분에서 사용되는데, Input과 Output의 sequence length 길이가 다를 수 있기 때문에, 이것을 인자로 받는 형태로 구현했다.
  • 마지막에 self.register_buffer는 추후, model parameter 학습 시, psotional encoding이 학습되지 않도록 막아주기 위한 용도이다. 

    def position_encoding(self, position_max_length=100):
        position = torch.arange(0, position_max_length, dtype=torch.float).unsqueeze(1)
        pe = torch.zeros(position_max_length, self.hidden_dim)
        div_term = torch.pow(torch.ones(self.hidden_dim // 2).fill_(10000),
                             torch.arange(0, self.hidden_dim, 2) / torch.tensor(self.hidden_dim, dtype=torch.float32))
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

        return pe

 

[input & output Embedding]

  • Embedding은 nn.Embedding을 통해, 쉽게 구현할 수 있다.
  • Embedding의 첫 번째 인자는 input 데이터의 total word 개수, 두 번째 인자는 hidden dimension의 수이다.
  • total_word_num은 sequence dictionary에 존재하는 unique value의 개수를 의미한다. (전체 단어가 아님을 주의)
  • 사실 편의를 위해, 공통 total_word_num을 사용했는데, 번역과 같은 경우, input의 단어 개수와 output의 단어 개수가 다를 수 있어, task에 따라서는 다른 인자를 받는 게 맞다.
 self.input_data_embed = nn.Embedding(total_word_num, self.hidden_dim)
 self.output_data_embed = nn.Embedding(total_word_num, self.hidden_dim)

 

[Transformer]

  • Transformer의 Encoder 부분은 앞서 구현했던, Encoder를 N번 반복하는 구조로 구현되어 있다. 
  • Encoder 부분에 들어가는 query, key, value는 문장의 embedding 한 값으로 모두 같고, (참조를 위한 query와 key가 비효율적이다.) 전번째 encoder의 결과가 다음 encoder의 query, key, value가 된다.
  • Decoder 부분도 비슷하지만, Encoder의 output이 사용된다는 점, Decoder 단에서는 다음 sequence를 볼 수 없기 때문에, 그 부분을 처리하기 위한 mask가 존재한다는 점이 다르다. 
  • Decoder에 masking으로 0 값을 넣어주었지만, 실제 학습해서는 매우 작은 값을 넣어주는 것이 학습 측면에서 유리하다고 한다.
  • Encoder 부분과 Decoder 부분을 모두 거치면, 목적에 맞는 fully connected layer를 연결하여, output을 낸다. 
class Transformer(nn.Module):
    def __init__(self, encoder_num=6, decoder_num=6, hidden_dim=512, max_encoder_seq_length=100,
                 max_decoder_seq_length=100):
        super().__init__()

        self.encoder_num = encoder_num
        self.hidden_dim = hidden_dim
        self.max_encoder_seq_length = max_encoder_seq_length
        self.max_decoder_seq_length = max_decoder_seq_length

        self.input_data_embed = nn.Embedding(max_seq_length, self.hidden_dim)
        self.Encoders = [Encoder(dim_num=hidden_dim) for _ in range(encoder_num)]

        self.output_data_embed = nn.Embedding(max_seq_length, self.hidden_dim)
        self.Decoders = [Decoder(dim_num=hidden_dim) for _ in range(decoder_num)]

        self.last_linear_layer = nn.Linear(self.hidden_dim, max_seq_length)

    def position_encoding(self, position_max_length=100):
    ...

    def forward(self, input, output, mask):

        input_embed = self.input_data_embed(input)
        input_embed += self.position_encoding(self.max_encoder_seq_length)
        q, k, v = input_embed, input_embed, input_embed

        for encoder in self.Encoders:
            encoder_output = encoder(q, k, v)
            q = encoder_output
            k = encoder_output
            v = encoder_output

        output_embed = self.output_data_embed(output)
        output += self.position_encoding(self.max_decoder_seq_length)
        output_embed = output_embed.masked_fill(mask.unsqueeze(-1), 0)
        d_q, d_k, d_v = output_embed, output_embed, output_embed

        for decoder in self.Decoders:
            decoder_output = decoder(d_q, d_k, d_v, encoder_output, mask)
            d_q = decoder_output
            d_k = decoder_output
            d_v = decoder_output

        output = F.softmax(self.last_linear_layer(decoder_output), dim=-1)
        return output

 

총평

  • 실제 NLP 단어 예측 등, 데이터셋을 넣어보기 위해, dataloader와 학습 등을 연결해 봐야겠다.
  • 특정 task를 풀기 위해, 데이터셋을 처리하기 위한 model을 짜는 것도 좋지만, 가끔은 논문을 그대로 구현해 보는 것도 좋을 것 같다. 특히, 그림과 글만 보고 구현을 하려고 하니, 내가 정확하게 알지 못했던 부분, 특히 머리로 이해하고 넘어간 부분을 완전히 알게 된 것 같아 좋다. 

+ Recent posts