Class-Prototype Discriminative Network for Generalized Zero-Shot Learning
https://ieeexplore.ieee.org/abstract/document/8966463
Zero-Shot Classification 논문.
논문의 모델은 3가지 network와 3가지 loss로 이루어져 있다.
3가지 Network
(1) feature extractor(ResNet101)
이미지로부터 visual feature를 추출하는 network
(2) generative network
class의 semantic representation을 입력 받아 visual prototype을 생성한다.
(3) metric network
visual feature와 prototype을 concat 한 것을 입력 받아 similarity score를 출력한다.
3가지 Loss
(1) Label Prediction Loss
metric network가 출력한 relation score와 원핫인코딩된 ground-truth label 사이 MSE를 계산한다
(2) Prototype-Sample Metric Loss(PSML)
추가적인 discriminative power를 활용하기 위한 loss이다. sample과 class-prototype 사이의 similarity가 높아지도록 한다.
(3) Class-Prototype Scatter Loss(CPSL)
class-prototype 사이의 simiarity를 감소시킨다. prototype의 discriminative한 성질을 얻기 위함.
최종 loss는 다음과 같다.
prediction은 샘플 x와 prototype의 relation score가 가장 높은 class를 선택한다.