CMT: Convolutional Neural Networks Meet Vision Transformers
PDF, Vision Transformer, Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing, Yunhe Wang, arXiv 2021
Summary
CMT는 ViT에 CNN 구조를 추가하여 성능을 개선한 모델입니다. long-dependency 정보를 포착하는데 특화되어 있는 ViT와 local feature을 modeling 하는데에 장점이 있는 CNN 구조를 결합하면 더 좋은 성능을 보여줄 수 있다는 것을 보여줍니다.
현재 ViT를 scaling up하여 SOTA 성능을 기록하고 있는 ViT-G를 제외하고, CMT는 동일한 파라미터 내에 순수한 모델 구조만으로 SOTA 성능을 달성하고 있는 것 같습니다. 위 그림에서 CMT가 vision transformer의 variant인 CeiT, CvT, PVT, DeiT의 성능을 능가합니다.
CMT의 가장 핵심적인 특징은 5가지로 생각됩니다.
(1) CMT STEM
CMT는 모델 앞단에 STEM을 추가합니다. 이미지를 패치로 짤라서 transformer에 전달하는 ViT와 다르게 STEM을 거쳐서 해상도를 감소시키고 local feature 정보를 학습하여 transformer에 전달합니다.
위 그림을 살펴보면 STEM을 거쳐서 해상도가 1/2가 되고 stride 2인 2x2 conv를 거쳐서 입력 해상도 대비 1/4 크기의 해상도를 transformer에 전달합니다.
아래 실험은 DeiT에 Stem을 추가하여 성능 향상 실험결과입니다.
(2) Stage-wise architecture
stage 앞단에 stride=2 2x2 conv를 추가하여 해상도를 감소시킵니다. 따라서 각 stage마다 multi-scale 정보를 학습할 수 있습니다. stage 구조를 활용하면 모델의 성능을 향상시킬 수 있습니다.
(3) Local Perception Unit(LPU)
ViT는 image를 patch로 짤라 flatten하여 tranformer에 전달하기 때문에 local relation을 무시했습니다. 이를 개선하기 위하여 transformer에 전달하기 전에 Depthwise conv를 적용합니다. Depthwise Conv는 각각의 채널에 독립적으로 3x3 conv를 수행하는 conv를 의미합니다.
(4) Lightweitht Multi-head Self-attention
입력값을 fc layer에 전달하여 Key와 Value를 생성하기 전에 stride=k인 depthwise conv를 적용합니다. stride=k conv에 의해 해상도가 낮아져 연산량이 감소할 뿐만 아니라 local 정보를 학습할 수 있습니다.
또한 Softmax 연산내에 있는 energy에 relative position bias B를 더해줍니다. B는 무작위 초기화된 학습가능한 파라미터 입니다.
(5) Inverted Residual FFN(IRFFN)
MHSA 이후 FFN을 Inverted Residual FFN으로 변경합니다.
Inverted Residual FFN은 1x1 conv로 채널을 확장하여 3x3 DW Conv 연산을 수행하고 다음 1x1 conv에서 원래 채널로 되돌립니다. 이는 MobileNetv3에서 제안되었던 방법인데 ReLU의 정보 손실을 방지하기 위해 제안되었으며 실제로 성능향상이 있다는 것을 MobileNetV3 논문에서 보여줍니다.
아마 기존의 FFN이 채널을 확장하여 중간 fc layer 연산을 수행하므로 동일한 방식으로 conv를 적용하지 않았나 싶네요
Experiment
my github