안녕하세요, 오늘 읽은 논문은 Swin Transformer: Hierarchical VIsion Transformer using Shifted Windows 입니다.
Swin Transformer는 transformer 구조를 object detection에 적용한 모델입니다. text에 비해서 image는 어떻게 patch로 분할하느냐에 따라서 엄청나게 다양한 variant가 존재하고 이미지들의 resolution이 크다는 차이점이 있습니다. 이 차이첨을 다루기 위해 Sfited Windows를 사용하여 hierarchical transformer로 representation을 학습합니다.
shifted window를 활용한 hierarchical transformer는 어떤 장점이 있을까요? 바로 object detection에서 사용하는 FPN 또는 Segmentation의 U-Nset처럼 계층적인 정보를 활용할 수 있습니다. 객체 검출에서 다양한 scale의 object를 검출하기 위해 FPN의 feature pyramid 구조를 사용합니다. 또한 Segmentation에서 detail한 정보를 활용하기 위해 피쳐맵을 축소했다가 다시 확대하여 mask를 생성합니다. 이처럼 object detection과 segmentation은 hierarchical representation을 활용하는 것이 중요합니다.
Swin transformer에서 제안하는 Shifted window는 각 window내에 여러 patch로 구성되어 있고, 이 window에 대해서만 self-attention을 계산합니다. 전체 영역에 대한 attention이 아닌, window로 묶인 영역만 attention을 계산하므로, 다르게 생각해보면 transformer에 inductive bias를 가한것으로도 생각해볼 수 있습니다(개인적인 생각입니다) ㅎㅎ
또한 stage가 진행될 수록 patch들은 병합되어 더 큰 patch로 이루어진 window를 사용합니다. 이는 좀 더 큰 객체를 잘 검출할 수 있다는 것으로 이해해볼 수 있습니다. 즉, stage가 진행될 수록 patch는 병합되고, 병합된 patch로 이루어진 window는 이미지 내에서 더 큰 영역을 담당합니다.
반대로 초기 계층에서는 window 내에 포함되어 있는 patch의 크기가 작기 때문에, 이미지 내에서 적은 영역을 담당합니다. 이는 작은 객체를 검출하는데에 용이하다고 생각해볼 수 있습니다.
또한 swin transformer의 또 다른 장점은 연산량에 입니다. ViT의 self-attention은 입력 이미지에 qudratic한 연산량을 지니지만, swin transformer는 이미지에대해 선형적인 연산량을 갖습니다. 이 덕분에 비교적 적은 연산량으로 모델의 크기를 키울 수 있을 뿐만아니라, inference가 빨라진다는 장점이 있습니다.
ViT와 Swin Transformer
ViT는 현재까지도 Classification 분야에서 SOTA 성능을 기록하고 있습니다. 반면에 Swin Transformer는 Classification에서 성능은 뒤쳐지면서 OD와 Segmentation 분야에서 SOTA 성능을 기록합니다. classification 성능이 뛰어난 ViT를 backbone으로 사용하여 OD와 Segmentation task를 수행하면 Swin Transformer보다 성능이 뒤쳐지는데 이것은 왜그럴까요??
Swin Transformer는 계층적인 구조를 활용합니다. CNN에서 layer가 깊어지면서 입력 이미지의 resolution을 축소하는 것처럼 Swin Transformer도 layer가 깊어지면서 이미지의 resolution을 변경합니다. 서로 다른 scale information을 갖고 있으므로 OD의 FPN 구조를 사용할 수 있습니다. FPN 구조를 사용하여 muti-scale information 정보를 활용하기 때문에 OD와 Segmentation task에서의 성능이 ViT보다 우수할 수밖에 없습니다.
Overall Architecture
Swin transformer의 전체 구조입니다. 구조를 살펴보시면, stage가 진행될 때 마다 patch merging이 진행됩니다. 또한 Swin transformer block은 일반적인 multi-head self attention과 shifted windowing multi-head self attention으로 이루어져 있습니다.
연산량입니다. M은 window size를 의미합니다.
layer가 진행될수록 window의 위치가 shift 됩니다. 물론 W-MSA에서는 shift를 하지 않습니다.
window는 window size // 2 만큼 shift를 하는데, 이미지 범위를 벗어나면 padding을 하는 것이 아니라 다음의 trick을 사용합니다.
또 다른 trick은 self-attention을 계산할 때 relative position bias를 더해줍니다. 이 덕분에 positional embedding을 사용하지 않아도 되고, 이 relative position bias는 윈도우 내에서 patch의 상대적인 위치를 나타냅니다.
Performance
논문의 세부 구현사항이 이해 안되는 부분이 많네요...ㅎㅎ
WMSA는 전체 patch에 대하여 self-attention을 수행하는가, 아니면 window에 대하여 self-attention을 수행하는지 궁금하네요. 나중에 코드 뜯어보면서 공부해야겠습니다.
참고자료