반응형
Transformer는 사실, NLP 분야뿐만 아니라, 다양한 분야에서 많이 사용되기 때문에, 그만큼 구현 소스를 쉽게 찾을 수 있다. 나도, Transformer를 자주 사용하지만, 라이브러리에서 읽어오는 형태로 사용하기 때문에, 그 상세 구조에 대해서는 대략적으로만 알고 있다. 이번 기회에 Transformer를 pytorch로 직접 짜보면서 그 구조를 정확히 이해하고자 한다.
Full source : https://github.com/daehwichoi/transformer-pytorch/blob/main/model/transformer.py
구현 방향
- 사실, pytorch로 Transformer를 구현한 사례는 google 검색만 해도 굉장히 많이 나온다. 하지만, original transformer를 직접 구현해보고 싶어서, 논문을 그대로 구현하는데 초점을 맞췄다.
- 모델 학습을 위한 layer(Dropout 등)나, dataloader는 task마다 다르고, 구현 목적이 Transformer 모델을 구현하는 것이기 때문에, 모델 구현만 진행했다.
참고 자료
- Transformer 논문 내, 구조 설명 부분
2023.05.08 - [NLP 논문] - Transformer (Attention Is All You Need) - (1) 리뷰
구조 설명
- Transformer는 크게 Encoder 부분과 Decoder 부분, input&output embedding, postional encoding으로 나뉜다.
- Encoder 부분은 N개의 Encoder가 연결된 구조로 구성되어 있고, Decoder도 N개의 Decoder가 연결된 구조로 구성되어 있다.
- Ecoder는 크게, Multi-Head Attention(self-attention)과 redidual 부분(residual add & layer norm), Feed Forward로 구성되어 있다.
- Decoder는 크게, Masked-Multi-Head Attention(self-attention)과 resiedidual 부분(residual add & layer norm), Multi-Head Attention(encoder-decoder attention), Feed Forward로 구성되어 있다.
구현 내용 설명
(순서는 내가 구현한 순서이다.)
1. Multi-Head Attention
[Sacled Dot-Product Attention]
- Multi-Head Attention의 핵심은 scaled_dot_product_attention이다.
- scaled_dot_product는 Query, Key, Value가 있을 때, Query와 Key의 Transpose의 Matmul(Dot Product)를 통해, similarity를 계산하고, similarity 기반으로 Value 값을 참조한다.
- Scaled Dot-product는 Network 여러 부분에서 사용되지만, Decoder 부분에서는 masking 처리를 해야 하는 부분이 있기 때문에, mask부분을 포함해서 함께 구현했다.
def scaled_dot_product_attention(self, q, k, v, mask=None):
d_k = k.size()[-1]
k_transpose = torch.transpose(k, 3, 2)
output = torch.matmul(q, k_transpose)
output = output / math.sqrt(d_k)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(1).unsqueeze(-1), 0)
output = F.softmax(output, -1)
output = torch.matmul(output, v)
return output
[Multi-Head Attention]
- Multi-Head Attention은 scaled Dot-Product Attention을 query에 해당하는 value 값들을 참조하기 위해 사용하는데, query, key, value를 그대로 사용하는 것이 아니라, 여러 개의 head로 나누고, query, key, value를 linear projection 한 후, 사용한다.
- Scaled Dot-Product Attention 이후, 각 head의 value 값을 concat하고, linear layer을 거쳐 output을 낸다.
- 주의할 점은, sequence의 순서가 중요하기 때문에, contiguous를 사용해서, 순서를 유지한다는 점이다.
class MultiHeadAttention(nn.Module):
def __init__(self, dim_num=512, head_num=8):
super().__init__()
self.head_num = head_num
self.dim_num = dim_num
self.query_embed = nn.Linear(dim_num, dim_num)
self.key_embed = nn.Linear(dim_num, dim_num)
self.value_embed = nn.Linear(dim_num, dim_num)
self.output_embed = nn.Linear(dim_num, dim_num)
def scaled_dot_product_attention(self, q, k, v, mask=None):
...
def forward(self, q, k, v, mask=None):
batch_size = q.size()[0]
# 순서 유지 때문에 view 후 transpose 사용
q = self.query_embed(q).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)
k = self.key_embed(k).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)
v = self.value_embed(v).view(batch_size, -1, self.head_num, self.dim_num // self.head_num).transpose(1, 2)
output = self.scaled_dot_product_attention(q, k, v, mask)
batch_num, head_num, seq_num, hidden_num = output.size()
output = torch.transpose(output, 1, 2).contiguous().view((batch_size, -1, hidden_num * self.head_num))
return output
2. Residual Add & Layer Norm
[Layer Norm]
- Layer Norm은 dimension layer 방향으로 평균을 빼고, 표준 편차로 나누는 Normalization 기법이다.
- 이 부분은 nn.LayerNorm을 통해, 구현할 수 있다.
def layer_norm(self, input):
mean = torch.mean(input, dim=-1, keepdim=True)
std = torch.std(input, dim=-1, keepdim=True)
output = (input - mean) / std
return output
[Add & Layer Norm]
- 이전 층의 output을 layer norm을 통해, normalization 한 후, residual 값을 더해준다.
class AddLayerNorm(nn.Module):
def __init__(self):
super().__init__()
def layer_norm(self, input):
...
def forward(self, input, residual):
return residual + self.layer_norm(input)
3. Feed Forward
- Feed Forward는 Fully Connected Layer → Relu → Fully Connected Layer로 구성되어 있다.
class FeedForward(nn.Module):
def __init__(self, dim_num=512):
super().__init__()
self.layer1 = nn.Linear(dim_num, dim_num * 4)
self.layer2 = nn.Linear(dim_num * 4, dim_num)
def forward(self, input):
output = self.layer1(input)
output = self.layer2(F.relu(output))
return output
4. Encoder
- Encoder는 Multi-Head Attention → Residual Add & Layer Norm → Feed Forward → Residual Add & Layer Norm 순으로 구성되어 있다.
- Encoder는 단순히, 앞서 선언했던, sub layer들을 연결하는 방식으로 구현했다.
class Encoder(nn.Module):
def __init__(self, dim_num=512):
super().__init__()
self.multihead = MultiHeadAttention(dim_num=dim_num)
self.residual_layer1 = AddLayerNorm()
self.feed_forward = FeedForward(dim_num=dim_num)
self.residual_layer2 = AddLayerNorm()
def forward(self, q, k, v):
multihead_output = self.multihead(q, k, v)
residual1_output = self.residual_layer1(multihead_output, q)
feedforward_output = self.feed_forward(residual1_output)
output = self.residual_layer2(feedforward_output, residual1_output)
return output
5. Decoder
- Decoder는 Masked Multi-Head Attention → Residual Add & Layer Norm → Multi-Head Attention → Residual Add & Layer Norm → Feed Forward → Residual Add & Layer Norm 순으로 구성되어 있다.
- Encoder와 마찬가지로, 앞서 구현해놓은 sub-layer를 연결하면 되지만, 중간 Multi-Head Attention은 Query와 Key를 Encoder의 Output을 사용하기 때문에, 이 점을 명시해야 한다.
- Decoder는 Ecoder와 다르게, masking을 이용하여, mask를 인자로 받는 것도 주의해야 한다.
class Decoder(nn.Module):
def __init__(self, dim_num=512):
super().__init__()
self.masked_multihead = MultiHeadAttention(dim_num=dim_num)
self.residual_layer1 = AddLayerNorm()
self.multihead = MultiHeadAttention(dim_num=dim_num)
self.residual_layer2 = AddLayerNorm()
self.feed_forward = FeedForward(dim_num=dim_num)
self.residual_layer3 = AddLayerNorm()
def forward(self, o_q, o_k, o_v, encoder_output, mask):
masked_multihead_output = self.masked_multihead(o_q, o_k, o_v, mask)
residual1_output = self.residual_layer1(masked_multihead_output, o_q)
multihead_output = self.multihead(encoder_output, encoder_output, residual1_output, mask)
residual2_output = self.residual_layer2(multihead_output, residual1_output)
feedforward_output = self.feed_forward(residual2_output)
output = self.residual_layer3(feedforward_output, residual2_output)
return output
6. Transformer
- 전체 Transformer는 Input Embedding, Positional Encoding, Output Embedding, N개의 encoder와 N개의 decoder로 구성되어 있다.
[positional_encoding]
- positinal encoding은 짝수번째 token과 홀수번째 token이 각기 다른 식을 따른다. 아래 식에서 i는 hidden dimension 방향의 index이고, pos는 positional 방향(몇 번째 seq인지)을 의미한다.
- positional encoding은 크게, 두 부분에서 사용되는데, Input과 Output의 sequence length 길이가 다를 수 있기 때문에, 이것을 인자로 받는 형태로 구현했다.
- 마지막에 self.register_buffer는 추후, model parameter 학습 시, psotional encoding이 학습되지 않도록 막아주기 위한 용도이다.
def position_encoding(self, position_max_length=100):
position = torch.arange(0, position_max_length, dtype=torch.float).unsqueeze(1)
pe = torch.zeros(position_max_length, self.hidden_dim)
div_term = torch.pow(torch.ones(self.hidden_dim // 2).fill_(10000),
torch.arange(0, self.hidden_dim, 2) / torch.tensor(self.hidden_dim, dtype=torch.float32))
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
return pe
[input & output Embedding]
- Embedding은 nn.Embedding을 통해, 쉽게 구현할 수 있다.
- Embedding의 첫 번째 인자는 input 데이터의 total word 개수, 두 번째 인자는 hidden dimension의 수이다.
- total_word_num은 sequence dictionary에 존재하는 unique value의 개수를 의미한다. (전체 단어가 아님을 주의)
- 사실 편의를 위해, 공통 total_word_num을 사용했는데, 번역과 같은 경우, input의 단어 개수와 output의 단어 개수가 다를 수 있어, task에 따라서는 다른 인자를 받는 게 맞다.
self.input_data_embed = nn.Embedding(total_word_num, self.hidden_dim)
self.output_data_embed = nn.Embedding(total_word_num, self.hidden_dim)
[Transformer]
- Transformer의 Encoder 부분은 앞서 구현했던, Encoder를 N번 반복하는 구조로 구현되어 있다.
- Encoder 부분에 들어가는 query, key, value는 문장의 embedding 한 값으로 모두 같고, (참조를 위한 query와 key가 비효율적이다.) 전번째 encoder의 결과가 다음 encoder의 query, key, value가 된다.
- Decoder 부분도 비슷하지만, Encoder의 output이 사용된다는 점, Decoder 단에서는 다음 sequence를 볼 수 없기 때문에, 그 부분을 처리하기 위한 mask가 존재한다는 점이 다르다.
- Decoder에 masking으로 0 값을 넣어주었지만, 실제 학습해서는 매우 작은 값을 넣어주는 것이 학습 측면에서 유리하다고 한다.
- Encoder 부분과 Decoder 부분을 모두 거치면, 목적에 맞는 fully connected layer를 연결하여, output을 낸다.
class Transformer(nn.Module):
def __init__(self, encoder_num=6, decoder_num=6, hidden_dim=512, max_encoder_seq_length=100,
max_decoder_seq_length=100):
super().__init__()
self.encoder_num = encoder_num
self.hidden_dim = hidden_dim
self.max_encoder_seq_length = max_encoder_seq_length
self.max_decoder_seq_length = max_decoder_seq_length
self.input_data_embed = nn.Embedding(max_seq_length, self.hidden_dim)
self.Encoders = [Encoder(dim_num=hidden_dim) for _ in range(encoder_num)]
self.output_data_embed = nn.Embedding(max_seq_length, self.hidden_dim)
self.Decoders = [Decoder(dim_num=hidden_dim) for _ in range(decoder_num)]
self.last_linear_layer = nn.Linear(self.hidden_dim, max_seq_length)
def position_encoding(self, position_max_length=100):
...
def forward(self, input, output, mask):
input_embed = self.input_data_embed(input)
input_embed += self.position_encoding(self.max_encoder_seq_length)
q, k, v = input_embed, input_embed, input_embed
for encoder in self.Encoders:
encoder_output = encoder(q, k, v)
q = encoder_output
k = encoder_output
v = encoder_output
output_embed = self.output_data_embed(output)
output += self.position_encoding(self.max_decoder_seq_length)
output_embed = output_embed.masked_fill(mask.unsqueeze(-1), 0)
d_q, d_k, d_v = output_embed, output_embed, output_embed
for decoder in self.Decoders:
decoder_output = decoder(d_q, d_k, d_v, encoder_output, mask)
d_q = decoder_output
d_k = decoder_output
d_v = decoder_output
output = F.softmax(self.last_linear_layer(decoder_output), dim=-1)
return output
총평
- 실제 NLP 단어 예측 등, 데이터셋을 넣어보기 위해, dataloader와 학습 등을 연결해 봐야겠다.
- 특정 task를 풀기 위해, 데이터셋을 처리하기 위한 model을 짜는 것도 좋지만, 가끔은 논문을 그대로 구현해 보는 것도 좋을 것 같다. 특히, 그림과 글만 보고 구현을 하려고 하니, 내가 정확하게 알지 못했던 부분, 특히 머리로 이해하고 넘어간 부분을 완전히 알게 된 것 같아 좋다.
'Python' 카테고리의 다른 글
JAX: Just Another XLA 설명 (2) | 2023.08.07 |
---|---|
Pandas 성능 향상을 위한 방법들 (2) | 2023.07.21 |
Pytorch Profiler Tensorboard로 시각화 (1) | 2023.07.10 |
Pytorch Resource & 모델 구조 Profiler 도구 (torch profiler) (1) | 2023.07.09 |
Python 프로파일링을 위한 도구들 (Process, Memory, Execution Time) (5) | 2023.07.06 |