Isometric Propagation Network for Generalized Zero-Shot Learning
https://arxiv.org/abs/2102.02038
Zero-Shot Learning이 잘 되기 위해서는 semantic space와 visual space 사이 alignment가 중요하다.
흔히 domain shift problem이라고 불리는데, semantic information(word embedding or attributes)와 visual feature(이미지로부터 resnet 추출한 feature) 는 학습이 진행된 domain이 다르므로 class간의 relation에 차이가 존재한다. 이 domain shift problem을 해결하기 위해 여러 논문이 제안되었는데, 이 논문은 좀 더 정교하게 설계한 framework라고 볼 수 있다.
논문의 method는 두 가지(Isometric Propagation Network, Episodic training)로 구성되어 있다.
(1) Isometric Propagation Network
구성 요소는 visual prototype, semantic prototype, category graph 이다.
two space 상의 class 별 prototype이 갱신되가고 graph 또한 갱신된다. 따라서 초기값이 필요하다.
visual prototype의 초기값은 class representation의 평균값으로 사용. 아래 식에서 W는 semantic prototype과 차원을 동일하게 해주는 transformation matrix이다.
semantic prototype은 semantic representation을 사용한다.
prototype을 초기화했으면 graph 또한 초기화 한다. node는 class이며 edge는 prototype의 similarity가 일정값 이상한 경우 생성된다. visual space에서의 graph와 semantic space의 graph가 존재한다. 함수 c는 cosine similarity를 계산하는 함수이다.
이제 prototype과 graph를 갱신해야 한다.
다음 step에서 prototype은 node와 연결된 edge prototype의 가중합으로 갱신된다. attention module이 사용되는데 이따가 설명하겠다. 아래 식과 같이 semantic prototype도 동일하게 갱신된다.
attention module은 다음과 같이 정의된다. cosine similarty 를 계산하고 soft max로 전달해주는 형태.
prototype을 원하는 step 까지 갱신했으면, visual prototype과 semantic prototype을 concat하여 최종 prototype을 생성한다.
이 prototype이 classification에 사용된다.
image와 prototype을 MLP로 전달하는데 왜 visual feature가 아닌 image를 사용할까? 의문.
두 space상에 생성된 prototype들 사이의 relation이 동일해지도록 consistency loss를 추가한다. 이는 두 space상의 prototype 사이 KL divergence 를 사용.
(2) Episodic training
meta learning에서 사용되는 episodic training이다 training set에서 sub set을 설정해 test로 사용한다. 이렇게 학습시키면 새로운 task에 쉽게 적응할 수 있다. 실제 실험 결과도 성능 향상을 보여줌.
재밌는 점은 train loss가 아니라 val loss를 기준으로 update 한다. meta learning에서 자주 사용하는 방법인거 같은데, overfitting을 방지하고 generalization 성능을 높여준다.
위 loss에 앞서 설명한 consistency loss를 추가한 것이 전체 loss이다.
Experiment
흥미로운 결과만 가져와봄,
visual space, semantic space 상의 prototype을 사용하지 않고, 한가지만 사용한 실험에서 semantic prototype 만을 사용한 것이 성능이 더 좋았다. 물론 둘다 사용한게 더 좋았음 ㅎㅎ. visual prototype은 unseen에 대한 sample을 사용하지 않는 반면에, semantic prototype은 unseen에 대한 semantic representation 정보를 지니고 있기 때문이라고 설명한다. 즉, 둘중 하나의 prototype을 사용해야 한다면 semantic representation을 활용해 prototype을 만드는 것이 낫다는 말!! 내 연구에 참고를 할 수 있을 듯? 근데 사실 당연한 말이긴 하다. zero shot은 semantic 정보를 활용해서 seen으로 학습한 지식을 unseen으로 transfer 하는게 기본 컨셉. 두개의 prototype을 사용했을 때 성능이 높은 이유는, visual prototype이 visual feature와 semantic feature을 연결해주고 추가적인 정보를 제공하기 때문이라고 설명한다.
재밌는 점은 consistency loss에 관한 부분이다. visual attention module을 semantic space에 적용하여 similarity가 동일해지도록 했는데, 이는 feature도 동일해지므로 성능에 악영향. consistency loss는 output similarity만 동일해지도록 할 뿐, similarity 계산 전 feature는 다르다. 즉, 두 space상의 각 class에 대한 feature는 꼭 같은 필요가 없다는 것 !!!
https://github.com/Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch