안녕하세요, 오늘 읽은 논문은 An Empirical Study of Training Self-Supervised Vision Transformers 입니다.
해당 논문은 MoCov1/2보다 좋은 성능을 갖는 MoCov3을 제안하고, 이 MoCov3을 Vision Transfermers(ViT)에 적용하는 실험을 합니다. CNN 구조에 SSL을 적용하는 많은 연구가 이루어져 있지만 ViT 모델에는 어떻게 self-supervised learning을 적용해야하는지에 대해 많은 연구가 이루어지지 않았습니다. 저자는 다양한 실험을 통해 self-supervised ViT의 효과를 조사합니다. 또한 기존 self-supervised transformer 보다 좋은 성능을 보여줍니다.
실험을 위해 batch size, learning rate, optimizer 와 같은 학습에 필수적인 요소를 조사하며 self-supervised ViT에 발생하는 instability 문제를 발견합니다. 이 instability를 완화하기 위한 방법으로 simple trick을 제안하고, ViT의 다양한 모델 구조, 크기를 실험합니다.
그 결과로 contrastive learning을 사용하는 self-supervised transformer가 좋은 성능을 나타낼 수 있다는 것을 발견합니다. 또한 supervised ViT의 경우에 일정 크기 이상으로 모델의 크기가 커지면 성능에 악영향을 주었는데, self-superivsed ViT에서는 매우 큰 ViT 모델을 사용한 경우에 특정 task에 대하여 supervised 보다 성능이 뛰어나다는 것을 보여줍니다.
이외에도 SSL CNN과 ViT 모델을 비교하고 ViT에서 position embedding을 제거하는 실험을 합니다.
MoCov3
큰 batch size(4096)에서 benefit이 사라지기 때문에 MoCo의 가장 큰 특징인 queue를 더이상 사용하지 않습니다. encoder는 SimSiam에서 제안하는 backbone + pred mlp + proj mlp를 사용하며, momentum encoder는 backbone + pred mlp만을 사용합니다. 그리고 이 momentum encoder는 기존 MoCo와 동일하게 momentum update를 합니다.
loss는 InfoNCE loss를 사용하며 symmetric 성질을 갖도록 수정을 합니다.
이 MoCov3를 CNN에 적용하여 SOTA 성능을 달성합니다.
이제 이 MoCov3을 ViT에 적용하는 실험을 합니다.
Stability of Self-Supervised ViT Training
MoCov3을 ViT에 적용하였을 때, 학습의 instability 문제가 발생합니다.
contrastive learning 기반 self-supervsed learning은 batch size를 크게 하는 경우에 더 많은 negative sample을 활용할 수 있어, 성능에 도움을 주지만 ViT는 오히려 instability한 결과와 이로 인해 성능이 감소되는 것을 보여줍니다.
batch=2048는 72.6%의 성능과 부드러운 학습 곡선을 보여주지만, batch=4096, 6144는 불안정한 학습 곡선과 이로 인한 성능저하가 발생합니다. 이 instability 문제를 해결하기 위해 제안하는 한 가지 trick은 아래에서 살펴보도록 하겠습니다.
학습률에 따른 실험을 합니다. 학습률이 작은 경우에 training이 더 stable 하지만 under fitting이 됩니다. 아래 그림을 보면 lr=0.5e-4는 lr=1.0e-4 보다 1.8% 낮은 성능을 보여줍니다. 학습률을 크게 설정하는 경우에 less stable 되고 정확도도 낮아집니다. 또한 optimizer는 AdamW와 LAMB optimizer 두 개를 실험합니다. 실험 결과는 아래 그림에서 살펴보실 수 있습니다.
두 그림을 비교하면 LAMB optimizer가 학습률에 더 민감하다는 것을 보여줍니다. 따라서 해당 논문은 AdamW를 사용하여 실험을 진행합니다.
A Trick for Improving Stability
MoCov3을 ViT에 적용하는 경우에 instability 문제가 발생합니다. stability를 향상시키기 위해 simple trick을 제안합니다.
simple trick을 소개하기 전에 저자는 왜 instability가 발생하는지 추측을 합니다.
위 그림을 보면 첫 번째 레이어(patch projection)에 gradient spike가 발생하고 몇 epoch 후에 last layer에 gradient spike가 발생합니다. 이 결과를 보고 저자는 instability가 shallower layer에서 발생한다고 추축합니다. 따라서 학습 동안 patch projection을 freezing 합니다. 다른 말로하면 저자는 patch를 embedding하는 patch projection을 학습되지 않은 random patch projection으로 고정하여 사용합니다. 이는 stop-gradient를 적용하여 간단히 구현할 수 있습니다.
이 Random patch projection은 학습을 stabilize하게 하고, 더 부드러운 학습 곡선을 갖게 합니다. 또한 정확도도 향상되는 결과를 나타냅니다.
그리고 이 random patch projection이 다른 self-supervised learning + ViT 방법에도 효과적이라는 것을 보여줍니다.
실험을 통해 patch projection layer를 학습하는 것은 필수적이지 않다는 것을 보여줍니다. 기존 patch의 정보를 보존하기 위해 random projection이 충분하다고 말합니다.
Experiment
실험에 사용하는 ViT model variants 입니다.
ViT 모델을 학습하는데 걸리는 시간과 FLOP 입니다.
ViT에 여러 self-supervsed learning 방법을 적용한 결과입니다.
positional embedding choice 비교입니다.
ViT에서 class token에 따른 결과입니다. pool은 global average pooling, LN은 layerNorm을 의미합니다.
ViT 내의 MLP head에 BN을 적용한 결과입니다.
MoCov3에서 사용하는 prediction MLP 유무 실험 결과입니다.
momentum encoder에서 momentum 계수에 따른 실험 결과입니다.
training epoch에 따른 실험입니다.
SOTA 모델과의 비교입니다
transfer learning 결과입니다.
참고자료