지난 lecture까지 거치면서, 우리는 가장 핵심적인 NLP의 전략인 RNN / word encoding -> seq2seq -> attention 등을 이어 보았다.
이번 lecture는 이러한 구조를 어떻게 최적화하였는지, 세부적인 테크닉들에 집중하며 최근까지 가장 많이 쓰인 transformer model까지 들여다본다.
Recurrent model Issues
지난 lecture까지 보았던 recurrent model의 근본적 issue들을 따져보면 다음과 같다.
- linear interaction distance
LSTM에서 어느정도 극복을 해보려했으나, 어쨌든 근본적으로 선형적 구조이기에
거리가 멀어질 수록 정보의 전달이 어렵다.
- lack of parallelizability
이러한 선형 구조는 역전파/순전파시에도 문제가 되어, 동시에 여러 time 처리가 불가능해 complexity 상승을 유발한다.
그렇다면, 저번 lecture에서 배운 attention만을 순수하게 이용하면 어떨까?
attention을 통해 여러 단어들을 한번씩 모두 attention score를 구하는 것이다.
Self-attention
단순하게 보면 attention을 자기 자신들에 여러번 반복하여 새로운 embedding을 구하는 것이다.
이렇게 구하는 embedding들은 다른 모든 단어 embedding의 weighted sum이기에, 충분히 문장 embedding으로써의 기능을 할 수도 있을 것이다.
이러한 방식을 self-attention mechanism이라 한다.
위 그림과 같이, 'learned'라는 query에 대해 각 key들의 attention score를 구하여, attention distribution을 구한다.
이렇게 구한 각 weight와 각 value를 곱하여 query에 대한 output을 도출하는 것이 각 단어의 attention 과정이다.
잠깐 착각하는 것으로는, hidden state같은 것이 없기에 attention만으로는 parameter가 어딨냐 할 수 있지만, 초반 lecture와 같이 각 단어의 임베딩을 parameter로 학습한다고 생각하면 된다.
대신 attention score만 계속 반복하는 것은 단지 linear operation들에 불과하기에, feed-forward network와 함께 사용하여야 한다.
물론 앞에서 RNN이 선형 구조라 문제가 있다고는 했지만, 지금까지의 self-attention은 아예 선형 구조없이 계산하기에 문장 내 단어의 위치를 학습할 수가 없다.
알다시피 문장 내 단어의 위치 역시 문장 해석에 정말 중요한 요소이기에, 이를 반영하기 위해 input으로
word embedding 뿐만이아닌, positional embedding을 더하여 학습한다. (해당 임베딩은 learnable하게 만들수도, 단지 index 형식으로 1,2,3,,,,번째로 fixed하게 만들수도 있다. 각 장단이 있다.)
또한 학습때 길이가 다른 데이터를 학습할 때, 혹은 테스팅 때 예측해야할 단어 뒤를 무시해야 할 때 같은 경우
masking을 사용하여, 관계없는 위치의 임베딩들을 attention weight=0으로 만들면,
우리의 self-attention network가 학습가능해진 것이다.
아래에 마지막으로 정리한 그림을 첨부한다.
Transformer
이제 transformer 모듈을 알아본다.
transformer는 기본적으로 masked self-attention을 사용하나, 하나 더 발전하여 multi-head로 attention을 계산한다.
단순히 말해 여러 번 attention을 계산하여 사용한다. 이는 실제 문장에서도 '같은 맥락'이라 비슷한 단어가 있고, '비슷한 syntax(알파벳?)'이라 비슷한 단어가 있기에 다양한 관점에서 집중할 부분들을 찾아 사용하기 위함이다.
head를 나누는 것은 행렬을 쪼개어 계산하는 것과 같아서, 나누지않고 한번에 계산하는 것보다 계산이 efficient하기도 하다.
또한 attention score가 너무 클 경우 softmax한 output값에 비해 gap이 커져 gradient가 작아질 수 있어, softmax이전에 scaling도 진행한다.
또 transformer에서 추가되는 테크닉은 두 가지가 더 있다. 하나씩 간단히 살펴보면
- residual connection
ResNet에서 자주보았을 텐데, attention output과 input을 더하여 최종 output을 구하는 방식이다.
이는 input에 순수하게 1의 gradient을 제공하기에, 학습 효과를 많이 높일 수 있다.
- layer normalization
각 층을 지날때마다 각 단어 임베딩 value값을 정규화하는 과정이다. 이를 통해 쓸데 없이 수치가 크거나 작아서 생기는 불필요한 정보를 줄이는 효과를 얻을 수 있다.
최종적인 transformer는 seq2seq처럼 encoder, decoder로 나뉜다.
decoder는 아래와 같고, encoder는 저기서 'masked'만 빼면 된다. (보통 encoder의 경우 LM 이 아닌 문장순서를 고려안하는 NMT같이 전체 문장 encoding을 필요로 하는데 쓰이기 때문.
대신 decoder의 경우 LM처럼 문장 순서가 중요하여 뒷부분 masking이 필요할때 사용.)
encoder-decoder모형은 seq2seq의 진화판이라 생각하면 될 듯 하다.
한 부분만 보면, decoder중간에 encoder output을 사용하는 부분의 경우 decoder input을 query로 하여 encoder output의 key, value와 attention하여 encoder의 value값을 통해 attention score를 구한다고 생각하면 된다.
이러한 transformer들은, 기존 RNN의 non-parallelazibility를 극복했기에 빠른 학습이 가능해져, 주로 큰 corpus에 pre-training까지 거칠 수 있게 되어 매우 큰 폭의 성능 향상을 가능케 했다. (어떤 task든지),. (linear distance도 극복했으니, 그냥 학습만 해도 성능이 높았다.)
그래서 현재까지도 해당 모델은 널리 쓰이고 있다. 고칠 점은 무엇이 있을까?
일단 하나는 계산량 감소이다. self attention의 경우 multi-head로 줄이기는 했으나, big model 학습에는 아직도 계산량이 많다. 한번에 작업하는 글의 길이가 2배가 되면 self-attention은 4배로 계산량이 늘어나기 때문이다.
뭐 sequence length dimension을 value, key에 적게 할당하여 사용하는 등의 시도가 좋았으나(2020)
실제 환경에선 오히려 더 계산량의 줄이는 시도는 쓰이지 않는다. 애초에 너무 커지면 잘 작동도 안하고.
게다가 놀랍게도 여러 modification이 발표 이후 굉장히 많았지만, 눈에 띄는 성능 향상은 없었다고 한다..!
정말 transformer가 완벽한 모델일지... 신기하면서도 어려운 것 같다.
이후 관련 논문도 참조해보려 한다.
* Do Transformer Modifications Transfer Across Implementations and Applications?
https://arxiv.org/abs/2102.11972
해당 논문에선 성능이 거진 비스무리한 것에 대해, cherry-picking을 피하면서 hyperparameter도 동결할 필요가 있고, 발표 전 최대한 많은 task(NLP를 너머서도)에 실험을 해봐야 진정으로 모델간 정확한 비교를 하여 보다 robust한 future architecture improvement가 가능할 것이라 본다.
정말 교과서적으로 동의는 하나, 이렇게까지해도 비슷하다는 것은 논문을 내는 입장에서도, (혹은 회사 입장일수도.) 많은 실험을 못해보거나 그정도까지 성능 향상은 어떻게 해도 안나왔다는 뭐랄까 조금은 비관적인? 결과로 받아들여졌던 것 같다.
'모아 읽은 보따리 > cs224n' 카테고리의 다른 글
L10: Pretraining (0) | 2023.02.07 |
---|---|
L7: Attention (0) | 2023.02.04 |
L6: LSTMs, NMT (0) | 2023.02.03 |
L5: Language Modeling, RNN (0) | 2023.02.02 |
L4: Dependency Parsing (0) | 2023.02.01 |