본문 바로가기
내 맘대로 읽는 논문 리뷰/RL

Neural Architecture Search with RL

by 동석2 2021. 4. 29.

Neural Architecture Search with Reinforcement Learning

https://arxiv.org/pdf/1611.01578.pdf

 

오랜만에 읽는 논문, 중간이 끝나서 조금 시간이 났다..

이제부터 강화학습 수업 프로젝트를 위해 NAS 관련 공부를 진행할 터라, 가장 처음에 나온 이 논문을 읽어보려 한다.

 

Intro

NAS는 neural network search의 약자로, Controller로 불리는 하나의 RNN을 통해, variable-length string으로 된 neural network 구조/connectivity를 scratch부터 학습시켜가며 최적의 network를 구축하는 방법이다. 학습에는 validation accuracy를 reward로 한 policy gradient방법을 사용했다.

이러한 방법으로 저자들은 기존의 사람의 직관으로 구축한 network들(resnet, VGG)등등보다 더 좋은 성능을 냈다. (굉장히 오래 걸리지만)

 

Method

generate model descriptions with controller RNN

유연성을 위해, controller는 RNN으로 설정했다고 한다. 위 그림을 보면 알 수 있듯, Conv layer의 경우 filter의 h,w, stride h,w, #f와 같은 hyperparam들을 output으로 내고, 이를 input으로 다시 예측하는 RNN 형식으로 network 구조를 형성한다.

모든 형성이 끝나면 이 구조(child-network)로 학습하여 val acc을 구해, RNN param을 backward로 학습하도록 한다. 학습 방법은 다음 섹션에서 나온다.

training with REINFORCE

논문이 사용한 policy gradient 방법은 REINFORCE 방법이다.

우리가 optimize할 것은 RNN param에 대한 reward(val acc) 기댓값인데, 우리의 reward는 미분불가능하기 때문에 policy gradient방법을 이용한다.

그 중 REINFORCE라는 이름의 방법을 사용하는데, 아래 식과 같이 업데이트한다.

REINFORCE를 짚고 가면, MC policy gradient 방법으로 마지막 (여기선 T) time step부터 앞으로 가면서 Return G_t를 accumulate하며 구하여, 각 step마다 policy gradient를 적용하는 방법이다.

여기서 log prob을 사용하는 것은 log값이 well-scaled이기 떄문으로, prob의 범위와 맞기 때문에 사용했다고 한다.

논문은 문제가 너무 variant가 크기 때문에, 총 emperical approximation에서 G_t에서 이전 step moving average를 빼어 계산함으로써 variant를 줄이는 방법을 사용했다.

논문의 approximation equation : R_k가 이번 구조의 validation accuracy(reward), b가 이전 구조의 reward들의 moving average이다. b는 어차피 이번 구조와 관련이 없기 때문에, b를 빼줌으로써 R_k의 편향을 어느정도 방지하는 역할을 하도록 한다.

또한 효율적인 training을 위해 (하나하나 학습 엄청오래걸림) S개의 서버를 두고 K개의 controller를 두어, 각각 m개의 구조를 parallel하게 학습한 뒤 gradient를 모아 server에서 param을 학습하도록 햇다.

architecture complexity : skip connection & layers

앞에서 layer선택 시, 이전 layer에 대한 것은 고려하지 않았었다. (skip / residual..) 이런 구조를 고려하기 위해 이전 layer에서 다음 layer와의 연결 확률을 sigmoid를 이용해 다음과 같이 층마다 Anchor point를 정의해 사용하였다.

현재 hidden/이전 hidden에 대한 sigmoid 출력값이다. W와 v역시 trainable parameter로, 큰 변화없이 기존 REINFORCE로 학습된다.

이 확률을 통해 이전 N-1개의 layer 중 어떤 층을 연결할 지를 결정한다. 차원이 안 맞을 수도 있는데, 이 경우 연결되지 않은 층도 마지막에 합치거나, zero-padding등을 활용해 맞춘다.

또한 앞에선 conv layer만 고려하였으나, lr, batch norm과 같은 추가 layer들도 추가하기 위해 맨 앞에 layer type을 예측하는 층을 추가하기도 한다.

generate RNN

이번엔 RNN 구조를 만드는 방법을 알아보자. 위 예시에선 LSTM과 비슷하게, cell state를 통해 다음 output과 hidden state를 도출하는 recurrent network를 만든다.

우선 세 층에 대한 connection/activation을 먼저 controller를 통해 찾고, cell injection에 쓸 연산 블록을 구한다. 이후엔 cell indices를 구하는데, 나온 index의 순서에 맞춰 왼쪽 tree구조를 합치는 방식이다.

실제 예시그림인 오른쪽 그림을 보면,

(1) leaf node인 0번째, 1번째 블록을 controller대로 계산

(2) c_(t-1)을 연산하는 index = cell indices의 두번째 블록 = 0, 곧 0번째 블록과 c_(t-1)을 연산

(3) 이렇게 나온 a_new와 남은 1번 블록을 2번째 블록으로 연산, h_t 도출

(4) c_t를 연산하는 index = cell indices의 첫번째 블록 = 1, 곧 1번째 블록에서 c_t 도출 (이 때 activation이전 output을 사용한다)

사실 구체적으로 c_t를 연산하는 공식?에 관해선 제대로 설명이 안되있어서, 우선 대강 이런 식으로 recurrent network도 만들 수 있다는 것에 집중하면 될 것 같다. LSTM 방식과 비슷하게 계산방식을 만든건지..

 

Experiment

실제 실험에서, 연구진들은 Conv layer와 skip connection, bn layer정도의 search space에서 학습을 진행했다. convolution layer의 경우 filter #/H/W 범위와 stride범위도 일정하게 정해진 채로 학습이 진행되었다,

실험에는 2-layer LSTM controller를 대략 100개정도(server는 20개 가량) 두고, 8개씩 child network를 돌렸는데, 어느정도 규모냐면 800개 가량의 GPU를 동시에 사용한 규모이다. 실제로 학습 시간도 엄청 걸렸다고... reward는 best validation accuracy를 통해 부여했다.

이렇게 12800개 정도의 구조를 트레이닝 후, best validation model을 grid search하여 가장 좋은 모델을 찾은 결과, 기존의 모델 구조보다 더 좋은 성능을 냈다고 한다.

실험은 세 방식으로 진행되었는데, 먼저 stride, pooling 예측 없이 진행했을 때는 depth와 accuracy의 balance를 지킨, 괜찮은 모델이 나왔다. 또한 skip connection을 빼거나 등의 실험에선 빼면 7%까지 rate이 올라가는 등의 모습도 볼 수 있었다.

두 번째로 stride 예측 시엔 별 차이가 없었고, pooling 예측을 포함한 세 번째 결과에선 보다 적은 깊이로 4.47%라는 당시 SOTA모델들과 비슷한 성능을 내는 모델을 볼 수 있었다. 좀 더 param 수를 늘리고 filter를 키우면 제일 좋은 성능을 보였다. (DenseNet-BC의 경우 1x1 Conv로 param 크기를 줄여서, 비교가 어렵다고 한다. 저자 말로는..)

앞은 CIFAR-10 dataset에 대한 결과이고, RNN 구축에 관한 실험을 위해 Penn Treebank dataset에서도 실험을 해보았다고 한다.

이 실험에선 앞서 본 RNN 구축과 비슷하게 combination = { add, element_mult }, activation = {relu, tanh, sigmoid, identity }로 search space를 설정한 뒤 network 구조를 찾는다. 이 실험의 reward는 constant c / validation perplexity 로서 구했다.

이 실험에서도, 가장 좋은 성능을 보였다. 같은 parameter 크기에서도 그랬고, 더 키워서 더좋은 성능도 냈다. 이렇게 찾아낸 cell 구조는 초반에는 LSTM과 비슷한 연산이었고, 이후에 달라진다. (마지막에 그림을 첨부하겠다)

+실험들

  • 위에서 찾은 cell 구조로 다른 task에도 적용한 결과, 여전히 좋았음.
  • 다른 network (GNMT framework)에 해당 cell 적용해도 좋았음 (transfer learning good)
  • max,sin같은 함수를 더 넣어 찾아도 좋음 (bigger search space)
  • random search와의 accuracy 차이도 굉장히 컷음. (policy gradient 확인)

Conclusion

이번 논문은 NAS 분야를 제대로 제시한 첫 논문으로, 강화학습을 통해 기존 구조들보다 더 좋은 모델을 찾을 수 있다는 것을 보여주었다. 한편 실험에서 알 수 있듯 굉장히 많은 training cost를 요구하는데, 사실 성능은 결국 이 search space와 training 속도에 많이 달려있기 때문에 이후 논문들은 NAS를 어떻게 하면 효율적이고 빠르게 할 수 있을지에 집중하게 된다. 앞으로는 이러한 논문들을 하나씩 볼 예정이다.

NAS를 통해 찾은 LSTM을 대신하는 cell 구조들. 왼쪽위가 LSTM, 오른쪽이 찾은 cell이다. 가장 아래는 sin,max operation을 포함하여 찾은 cell구조이다.

 

'내 맘대로 읽는 논문 리뷰 > RL' 카테고리의 다른 글

ProxylessNAS  (0) 2021.05.10