본문 바로가기
내 맘대로 읽는 논문 리뷰/NLP

Deeper Transformer with ADMIN (2)

by 동석2 2021. 3. 10.

(1) Very Deep Transformers for Neural Machine Translation (ADMIN 사용, deeper 모델 학습 결과 위주)

https://arxiv.org/pdf/2008.07772v2.pdf

(2) Understanding the Difficulty of Training Transformers (ADMIN 이론적 배경, 구체적 이론)

https://arxiv.org/pdf/2004.08249.pdf

 

1번 논문의 ADMIN의 이론적 배경. 어쩌다 인터뷰한 기업에서 하는 주제라 공부하려 읽었었는데, transformer의 문제를 파헤친 것이 흥미로워 져서.. 이론 리뷰까지 왔다.

우선 모든 내 리뷰가 그렇겠지만, 내가 이해한대로 기록하는 것인지라 잘못된 정보가 많을 것이다.... 그렇지만 이런 어려운 논문일수록 좀 심해질까봐, 일단 본 논문은 밑밥을 확실히 깔겠다. 잘못된 이해가 있으면 comment 남겨주시길..

+수학적 증명은 논문을 참고. 아직 Appendix는 설명할 만큼 이해가 안됨..

 

  • 서론, 배경

이야기는 What Complicates Transformer Training? 에서 시작한다. Transformer는 학습 시에 SGD를 보통 사용하지 않는다. 다른 RNN / CNN에선 잘 되었던 것이, transformer/attention 기반 모델에는 정확도가 나쁘게 수렴되었기 때문이다. 또한 trasnformer는 warmup을 굉장히 필요로 한다. 없으면 심각히는 발산하기도 한다.

여태까지는 이를 불균형한 gradient 때문이라고 생각을 했었지만, 저자는 실제로 중요한 영향을 미치는 것은 amplification affect, 즉 증폭 효과때문이라고 한다. 무엇이 증폭되는 것일까?

우선 알아야 할 것이 있다. Transformer는 residual network를 사용하는데, 여기서 residual network란 간단히 말해 input과 계산값을 더해 값을 내는 네트워크를 말한다. ResNet을 공부했다면 많이 봤을 구조다.

어쨌든 transformer에서도 F = self attention 혹은 feed forward network로써 residual network를 사용한다. 정확히는 그런 구조로만 구성되어있다. 앞으로 논문과 같이 h(x)를 shortcut branch, F(x)를 residual branch라고 하겠다.

바로 이 residual branch에 대한 강한 의존이 transformer training을 unstable하게 만들고, 이것이 small parameter perturbation(한글로 섭동인데, 대강 작은 애들이 모임으로써 큰 하나처럼 보이는? 현상을 말한다)을 증폭시켜 output에 significant한 방해를 준다는 것이 논문에서 말하는 증폭 효과다.

그럼 이 증폭 효과의 원인은 무엇일까. 차근차근 연구진들의 발자국을 따라가 보자.

 

  • Preliminaries

우선 Transformer의 residual network를 요약하면 다음과 같을 것이다.

이러한 방식을 Layer Norm을 계산 이후 한다는 점에서 Post-LN이라 하자.

LN을 계산 전에하는 방식도 있는데, 이는 Pre-LN이라 하며 다음과 같다.

기존에는 Post-LN보다 Pre-LN을 사용했는데, Post-LN의 경우 네트워크 깊이를 키울 경우 발산하는 문제가 발생하여, 덜 robust한 Pre-LN을 사용했던 것이다. (발산 이유는 나중에)

그런데 논문에선 Post-LN이 성능 포텐셜이 더욱 높다고 한다. 즉 수렴만 하면 성능은 Post-LN이 언제나 더 좋았다는 것이다. 이유는 뒤에서 나온다. 일단 패스하고..

 

  • Gradient가 튀는 것이 이유가 아닐까?

저자들은 Post-LN의 발산 이유만 잡으면, 네트워크 깊이가 깊은 Post-LN 방식 transformer가 SOTA를 찍지 않겠냐는 생각으로 이유를 찾았다. 그래서 먼저 생각한 것이 unstable한 gradient의 원인을 찾는 것이었다.

다음은 각 층마다의 gradient norm을 그래프로 나타낸 것이다. 위 그래프를 보면 PostLN Decoder의 Encoder Attention 부분에서만 gradient가 점차 커지는 것을 알 수 있다. 즉 gradient vanishing은 Post LN decoder Encoder Attention 부분 역전파과정에서만 일어난 다는 것이고, 이 것이 Post-LN만 학습이 발산하는 원인으로 연구진들은 우선 생각했다.

 

잠깐 이론적인 분석으로 나머지 네트워크 구조를 보면, Post-LN도 encoder에선 grdient vanishing문제가 없음이 논문 내 Theorem 1으로 알 수 있었다. (LN을 적용함에 따라 Gradient Variance가 항상 감소.-구체적 증명은 논문) Pre-LN에 대한 증명은 Appendix A-1에 있다. (증명 부는 일단 skip)

자, 아무튼 이제 Post-LN decoder에서만 gradient norm이 상승함에따라 역전파 때 gradient vanishing이 일어남을 알았다. 그렇다면 연구진들의 첫 생각대로, 정말 이 현상이 Post-LN의 발산 현상의 진짜 문제일까?

 

  • 발산 현상의 핵심 원인은 정말 Post-LN decoder의 gradient vanishing일까?

여기까지 보면 "야 저기 딱 gradient가 혼자서 저기서 튀네, decoder가 문제네!" 라고 할텐데, 실제 실험 결과는 의외였다.

Post-LN decoder가 문제였으니, encoder만 post-LN으로 바꾸면 괜찮겠네! 였지만, 결과는 역시 발산이었다. gradient가 사라지지도 않았는데, 왜 발산이 된걸까??

이 결과외에도, 어떤 조합을 쓰든 모든 attention 기반 module의 gradient가 unbalance하다는 사실을 논문은 말한다. 이 때문에 SGD같은 optimizer가 소용이 없고, adaptive한 Adam같은 것만 유효했다고 논문은 말해준다. 관련 그래프가 첨부되어 있는데 해상도가 별로다.

  • 진짜 원인이 무엇일까?

그럼 진짜 원인은? 저자들은 Post-LN과 Pre-LN의 early/late training 과정을 분석하며 새 원인을 찾아냈고, 이를 Amplification Effect라고 칭했다. 이제부터 이 증폭 효과에 대해 알아보자.

Pre-LN과 Post-LN의 차이가 무엇인가? 당연하겠지만, Layer Norm의 위치다.

Pre-LN의 경우 어떤 sublayer를 거치든, 한번의 같은 LN과정을 거친다. 곧 더해진 두 residual의 분포가 비슷할 것이고 magnitude에도 별 차이가 없을 것이다.

반면 Post-LN의 경우 Attn을 거치는 sublayer가 LN을 한번 거친 뒤 다시 LN을 거친다. 두 번째 LN 계산 때 차이가 생기는데, 이미 LN을 거친 residual보다 거치지 않은 residual이 정규화 되지 않은 상태기에, 합치는 과정에서 더 큰 magnitude를 가지게 되는 것이다.

이러한 현상이 실제로 학습에 어떤 영향을 미칠까? 다음은 j≤i일 때, i번째 layer output에 대한 j번째 layer의 비율을 B_ij라 할 때 그 분포 그래프이다. (수식으론 대강 output 표준편차 비라고 생각하면 된다.)

Post-LN의 경우 항상 맨 마지막 residual output의 비율이 거의 반을 넘을 정도로 큰 반면, Pre-LN의 경우 이전 layer랑 거의 비슷하다. 그리고 둘다 학습해가면서 직전 residual의 비율이 커진다.

이러한 현상이 영향을 미친다고 보는건데, Pre-LN의 경우 수식 상 magnitude 비율이 거의 같기 때문에, 학습에 limit가 있어 potential이 발휘되지 못하는 것이고, Post-LN의 경우 층이 깊어져도 새로운 층의 비율이 커 잠재적인 특징을 잘잡아 학습이 잘된다고 보는 것이다.

하지만 이러한 효과는 Post-LN에 단점도 있는데, parameter변화의 fluctation도 증폭시킨다는 것이다.

쉽게 이해해보자면, F는 output이다. W는 random하게 init된 parameter로 W*은 W + random bias 다.

논문의 Theorem 2에 따르면, weight update에 따른 output 변화는 이전 layer들의 B_jiC 의 합으로써 추정될 수 있다고 한다. (C는 constant)

그런데 앞에서 우리가 분석한것처럼, B_ij 중 마지막 layer의 B_ii는 Pre-LN에서는 층마다 거의 비슷하니 각각 1/i로 볼 수 있고, Post-LN에서는 거의 C에 근접한다고 볼 수 있다.(거의 마지막 layer값이므로)

그래서 논문은 다음과 같이 말한다.

즉 output의 변화가 학습할 수록 Pre-LN은 logN의 모양으로, Post-LN은 N의 모양으로 진행된다는 것이며, 이는 위 Figure4의 왼쪽그래프를 보면 실험적으로도 알 수 있다.

이러한 output의 큰 변화가, 학습 발산(불안정)의 가장 큰 원인이었던 것이다. 그래서 warmup과정이 transformer에서 중요했던 것이고. (O(N)으로 진행될 학습을 보다 느슨하게.)

논문은 이를 Weight parameter initailization에 변화를 주어, 해결하려 한다.

 

  • ADMIN-Adptive Model Initialization

논문은 Post-LN의 potential을 남기면서, 학습을 안정적으로 바꾸기 위해 output변화를 O(N)에서 O(logN)으로 바꾸려 한다. 이전 처음 ADMIN 기계번역실험 논문을 볼 때 언급했던 ADMIN이 바로 여기서 나온다.

ADMIN은 크게 두가지 과정으로 나뉜다. 이렇게 나눈 이유는 모델들이 서로 다르고, 학습 config가 다르니 Profiling 과정을 통해 이를 먼저 체크하고, Initialization을 통해 주된 Initialization을 실행하기 위함이다.

wi는 새로 추가된 parameter이다. xi = LN(bi)로써 sublayer를 해석하면 된다. fi는 attn or ffn.

Profiling

우선 이전처럼 wi를 일반적으로 initialization(wi=1과 같은 무난한 방법도 ok)한다. 그리고 parameter update없이 각 층 output variation을 기록한다.(영어로 profiling)

어차피 모든 층이 독립이고, 같은 방법으로 init하기에 한 배치정도면 충분하다고 되어있다.

Initialization

이제 profiling에서와 같은 방법으로 weight을 초기화 하되, 값을 아래 값으로 바꾼다.

이렇게 함으로써, early stage에서 마지막 residual의 영향을 Pre-LN과 같이 줄일 수 있다고 한다. B_ii = 1/i가 되어 O(logN)을 가질 수 있다는 것이다! 물론, late training에선 이전의 Post-LN과 같이, 마지막 층에 가중치를 많이 주도록 학습되기 때문에 Post-LN의 장점도 유지하면서, 발산을 막을 수 있다는 것이다. 그림을 보면 좀 더 이해가 쉬울 것이다. Figure 7과 비교해봐도 좋다.

신기하지 않은가. 앞서 Figure4의 세번째 그래프를 보면 Post-LN역시 ADMIN의 영향으로 logN의 그래프를 그림을 알 수 있다. 이로써 initialization에 의존하게 되는 문제를 피하면서, 제대로 학습될 수 있는 환경을 조성할 수 있였다.

 

  • Performance

뭐 앞서 논문에서도 봤지만, 기계번역에서 이 테크닉을 통해 여러 분야에서 SOTA를 찍었다.

이 테크닉을 통해, 보다 깊은 층의 Post-LN 형식 transfomer를 사용할 수 있게 됨으로써 성능이 이전보다 훨씬 좋은 모델을 개발할 수 있는 것을 기대하게 된다. 구체적인 분석을 통해 원인을 찾고, 또한 간단한 초기화 스텝(ADMIN)을 통해 이를 해결하려 한 연구진분들이 새삼 너무나 대단했던 논문이다. 논문에선 Theorem 2의 일반화도 얘기하셨다.

아직도 Deep Learning에는 수많은 문제들이 있는 것 같다. 공부할 게 새삼 너무나 많은 것 같다. 그래도 굉장히 재미있게 읽은 논문이었다! 뭔가 연구진들의 발자국을 따라가는 느낌이랄까..!

'내 맘대로 읽는 논문 리뷰 > NLP' 카테고리의 다른 글

Deeper Transformer with ADMIN (1)  (0) 2021.03.10
BART  (0) 2021.03.10
MT-DNN  (0) 2021.03.10