본문 바로가기
내 맘대로 읽는 논문 리뷰/CV

TResNet - ASLLoss

by 동석2 2021. 3. 10.

Asymmetric Loss For Multi-Label Classification

arxiv.org/pdf/2009.14119.pdf

 

MVC Dataset을 통해 공부.

MVC dataset은 15년에 나온 Multi-label classification에 적합한 dataset이다. 264개의 label에 대해 binary class로 명시한 truth값을 가졌기 때문이다.

이에 최근에는 Multi-label classification을 어떻게 접근을 했을지 보기 위해, papers-with-code에서 해당 논문을 읽게 되었다.

ASL loss는 이름부터 Asymmetric loss로, 비대칭적으로 계산한다는 것을 생각하면 편하다.

이전에 읽었던 RetinaNet에서는 다음과 같은 Focal Loss를 제시했었다.

이는 흔히 쓰던 Cross Entropy Loss를 One-stage Object detection에 사용할 때 생기는 foreground-background class 간의 불균형을 해결하기 위해 나온 제안으로, 일치하지 않는 Anchor값이 훨씬 많기에 바로잡기 위해 background class의 loss 계산에 gamma라는 지수를 통해 해당 loss값을 줄여 불균형을 해소하려 한 시도였다.

다시 말해 Easily Classified 된 negative class들의 magnitude를 줄이고, 어려운 문제들을 집중하게 해준다.

예로 pt가 작을 때,(어려운 문제) (1-pt)가 1과 가깝기에 기존 CELoss와 차이가 거의 없다.

하지만 pt가 크면,(쉬운 문제) (1-pt)가 0과 가깝기에 기존 CELoss보다 크기가 훨씬 작아지는 것이다.

더욱 구체적인 이야기는 RetinaNet 리뷰를 참고하면 좋다.

 

이러한 idea를 multi-label 분류에도 이용한 것이다. multi-label classification 역시 불균형이 존재한다. 지금 학습하려는 MVC dataset만 보아도, 264개나 되는 label 수를 가졌으나 실제로 한 이미지당 가지는 label 갯수는 많아봐야 2~30개다. 그러니 0의 label이 1보다 많으니 점차 모델이 negative하게 학습될 확률이 커지는 것이다.

논문은 focal loss에서 그치지 않고, 여러 idea를 추가하는데, 우선 focal loss는 pos/neg sample을 가리지 않고 모두 적용하는데, 이는 희귀한 positive sample을 무시할 수도 있다는 이야기를 한다.

그래서 positive와 negative sample에 대해 다르게, 비대칭적으로, loss를 계산한다.

focal loss와의 차이는 gamma가 둘로 나눠졌다는 것이다. r-를 r+보다 크게 설정함으로써 원래의 focal loss효과에 더해 positive sample은 더 챙기게끔 한 것이다.

 

사용한 또 하나의 방법은 Probability Shifting이다. 위의 계산법을 통해 neg sample에 굉장한 penalty를 주었는데, 이에 멈추지 않고 아예 일정 치 이하가 되면 negative sample의 prob를 0으로 만들어버리는 것이다.

즉 아주 쉬운 negative sample의 영향을 0으로 만든다는 것과 같은 뜻이다.

여기서 m은 hyperparameter로, 다시 말해 negative sample의 probability를 m만큼 shifting하는 것과 같다.

그래서 이것까지 적용하면 최종 ASL loss가 된다. 과연 이 loss가 실제에서도 먹힐까? 논문은 Gradient Analysis를 통해 주장을 뒷받침한다.

 

다음 그래프는 negative sample들에 대한 gradient를 나타낸 그래프이다.

p<0.2가 쉬운 예시, p가 커질 수록 어려운 예시로 보면 된다.

기존의 CE Loss를 보면, gradient가 직선이다. 그에 반해 ASL loss는 어려운 예시일 때 굉장히 높은 grad를 가지며, 아주 쉽거나 아주 어려운 경우엔 0에 수렴하는 값을 가진다. 아주 쉬운 경우는 우리가 asymmetric하게 loss를 계산함을 통해 이해하였다. 그런데 아주 어려운 경우도 일정 p*이상이 되면 loss가 0이 된다. 어떻게 된거지?

아주 어려운 경우에 대해 논문은 Mislabeled를 생각한 경우라고 말하는데, 실제로 아주 어려운 multi-label 문제, 불균형한 데이터셋에 대한 문제들은 잘못 label된 데이터하나가 아주 큰 영향을 줄 수 있다. 어려울 수록 mislabel 데이터가 나올 확률이 크기도 하고.

이 현상은 꼭 ASL이 아니더라도, Probability shifting에 의해 만들어지는 것이다. 그래프로 확인할 수 있을 것이다.

다음은 mean probability analysis 이다.

pt는 pos면 p, neg면 (1-p)로 계산한 확률이다. 즉 CE, Focal loss 모두 neg sample에 너무 많은 가중치를 준다는 것이다. 반면 ASL loss의 경우 pos와 neg sample에 대한 가중치가 거의 비슷하게 주어, 올바르게 학습될 확률이 앞의 두 loss방식보다 훨씬 클 것이다.

이제 실제 결과를 보자. 모두 성능이 더 좋았기에 그래프는 하나만 넣었다.

3가지 모델로 실험한 결과, ASL loss가 확연히 좋은 성능을 보이는 걸 알 수 있었다.

 

논문 막바지 부록을 보면, ASL loss 방식이 multi-label 분류 태스크에 한정하지 않더라도, focal loss보다 좋은 성능을 냈다고 한다. 꽤나 고무적이며, 어떤 데이터셋이든 asymmetric 한 분포를 가진다면 굳이 복잡한 모델 구조를 specific하게 설정하지 않더라도 이 Loss방식을 통해 Optimization을 진행하면 좋은 성능을 기대할 수 있을 것 같다. 

'내 맘대로 읽는 논문 리뷰 > CV' 카테고리의 다른 글

SORT  (0) 2021.03.10
A Survey on Moving Object Detection & Tracking methods  (0) 2021.03.10
Early-Learning Regularization : ELRloss  (0) 2021.03.10
DG-STA  (0) 2021.03.10
RetinaNet  (0) 2021.03.10