논문 읽기/Zero shot

[논문 읽기] Zero-shot Learning via Shared-Reconstruction-Graph Pursuit(2017)

AI 꿈나무 2021. 11. 22. 23:42
반응형

Zero-shot Learning via Shared-Reconstruction-Graph Pursuit

 PDFZero-Shot Classification, Zhao et al, arXiv 2017

 

Summary

 

 space shift problem을 정의한다. image feature space, attribute space, word vector space 사이의 knowledge structure가 다르다는 문제점이다. 기하학적인 inconsistent가 발생한다는 것인데, 각 space에서 structure가 서로 다른 데이터로 학습되기 때문에 발생한다. 즉, word embedding의 class 관계를 직접 image space로 transfer하면 문제점이 발생한다는 것.

 

 

 논문에서 제안하는 방법은 graph를 사용하여 image feature space와 semantic embedding space 사이를 연결하는 knowledge structure를 새롭게 탄생시킨다. 이 그래프를 사용하여 image prototype과 semantic prototype을 reconstruct 한다.

 

 image prototype은 다음과 같이 계산한다. 클래스에 해당하는 샘플의 평균값으로 prototype을 생성. 우리는 unseen image가 없으므로 unseen에 대한 prototype이 존재하지 않는데, 이를 graph를 이용하여 reconstruct 한다.

 

 

 semantic prototype은 word embedding을 사용한다. 하지만 image feature space와 inconsitent가 발생하므로 다음과 같이 재구성 한다.

 

 

 a는 그래프의 엣지를 나타내는데, 그래프를 구성해야지 semantic prototype을 재구성 할 수 있다.

 

 graph를 생성하려면 unseen에 대한 image prototype과 a(weight matrix)가 필요하다. 이 둘을 학습시켜야 한다.

 

 다음의 loss로 학습시킨다.

 

 

 문제점이 발생하는데, 학습시켜야할 변수가 unseen에 대한 image prototype, weight matrix a 두 가지 이므로 convex하지 않다. 따라서 alternative optimization을 사용한다. F를 고정시키고 a 먼저 학습 시킨후에, a를 고정시키고 F를 학습시킨다.

 

 

 unseen에 대한 F가 존재하지 않으므로 첫 번째 iteration에서 감마를 0으로 조정.

 

 

 

 이외에도 sparsity, locality regularization을 적용한다.

 

 sparsity regularization은 small dataset에서만 적용하는데, prototype을 다른 prototype의 linear combination으로 재구성하는 과정에서 strong한 connection이 존재하는 class에 대한 prototype만을 사용하겠다는 거다. LASSO 처럼 regularization term을 추가하여 값이 작은 a를 억제시킨다.

 

 

 이 D는 small dataset에서는 1인 identity matrix이고 large matrix에서는 distance에서 변하는 값을 사용한다.

 

 tiger와 unbrella는 적은 관계를 갖고 있다. 이에 해당하는 prototype에 높은 D 값을 줘서 강하게 regularization 하겠다는 의미이다.

 

 

 g는 거리에따라 증가하는 함수이다.

 


my github

 

Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch

공부 목적으로 논문을 리뷰하고 해당 논문 파이토치 재구현을 합니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.

github.com

 

반응형