논문 읽기/Zero shot

[논문 읽기] CaGNet(2020), Context-aware Feature Generation for Zero-shot Semantic Segmentation

AI 꿈나무 2021. 10. 5. 20:56
반응형

Context-aware Feature Generation for Zero-shot Semantic Segmentation

 PDFZero-shot segmentation, Zhangxuan, at al, arXiv 2020

 

Summary

 unseen object를 segment하는 zero-shot semantic segmentation 논문이다. 해당 논문은 genetation에 context-aware feature를 주입하여 unseen image를 더 정확하게 생성하여 성능을 끌어 올린다. pixel-wise contextual information을 포착하기 위해 contextual module을 제안하는데, 이 semantic word embedding과 contextual information을 함께 generator에 전달하여 더 다양한 image를 생성할 수 있다. 그러면 핵심 아이디어는 contextual information을 생성하는 contextual module로 볼 수 있다.

 

ZS3Net

 generative based method로 segmentation을 수행하는 이전 논문은 ZS3Net인데, ZS3Net은 두 가지 단점이 있다.

 

(1) ZS3Net은 다양한 feature을 생성하기 위하여 하나의 semantic word embedding에 random noise를 추가한다. 이 방법은 generator이 random noise를 무시하고 각 semantic word embedding에 해당하는 다양성이 제한된 채로 이미지를 생성한다.

 

(2) ZS3Net(GC)는 spatial object arragnement를 encode 하기 위하여 relational graph를 활용하지만 contextual cue는 object-level만 고려하고 spatial object arrangement에 제한되어 있다. 따라서 unseen feature를 생성할 때, relational graph는 unseen categories는 포함하지 않는다.

 

CaGNet, Contextual-aware feature Generation model

 GaGNet은 feature을 생성할 때, pixel-wise contextual information을 고려한다.

 

 pixel의 contextual information은 주변 픽셀로부터 추론된 정보를 의미한다. 이는 object arrangment만을 고려하려 제한된 ZSNet과는 다른 점이다. ZS3Net은 generator의 입력값으로 random noise와 semantic word embedding을 입력으로 받는데, CagNet은 semantic word embedding과 pixel-wise contextual latent code를 generator로 전달한다. contextual latent code는 논문제어 제안하는 Contextual Module(CM)으로부터 얻어진다. CM은 segmentation backbone의 출력값을 입력으로 받아 pixel-wise real feature와 이에 해당하는 모든 픽셀에 대한 pixel-wise contextual latent code를 출력한다. 

 

 CM 내부에는 context selector이 구현되어 있는데 이는 different pixel에 대한 contextual information의 서로 다른 scale에 가중치를 적응적으로 가하기 위함이다. 충분한 contextual information이 generator로 전달되어 feature generation의 모호함을 출기 때문에, semantic word embedding과 pixel-wise contextual latent code가 pixel-wise real feature을 reconstruct 될 것이다. 다른 말로 하면, 출력 pixel-wise feature와 입력 pixel-wise contextual latent code 사이의 one-to-one correspondence를 구축한다. 입력 latent code와 출력값 사이의 bijection은 mode collapse problem을 완화할 수 있다. 따라서 논문에서 제안하는 모델은 contextual latent code를 변화하면서 하나의 semantic word embedding으로부터 더 다양한 feature을 생성할 수 있다. 또한 random sampling을 통해 다양한 contextual latent code를 얻기 위해 contextual latent code가 gaussian distribution을 follow 하도록 한다. 그러므로 segmentation network와 feature generation network은 contextual module와 classifier에 의해 연결된다.

 

 CaGNet이 ZS3Net보다 나은 점은 1) object lebel contextual information 대신 더 유익한 pixel-wise contextual information을 활용한다. 2) contextual information을 stochastic sampling을 지원하는 latent code로 encode 한다. 따라서 unseen feature을 생성할 때, unseen category의 명확한 contextual information이 필요하지 않다.

 

overview

 

 CaGNet은 임의의 segmentation network가 적용된다. 논문에서는 Deeplabv2를 사용하는데 backbone E와 classifier C로 이루어져 있다. 입력 이미지가 주어지면 backbone은 이의 real feature map을 출력하고, classifier로 전달되어 segmentation 결과를 얻는다.

 

 segmentation 모델이 unseen categoriy에 대한 object를 segment 하기 위하여 generagor G가 unseen category를 위한 feature를 생성하도록 한다. 위 그림을 보면 generator는 semantic word embedding map과 latent code map을 입력으로 받아 fake feature를 출력한다. 그리고나서 공유된 1x1 conv layer를 지닌 classifier C와 discriminator D는 real/fake feature를 입력 받아 discriminator과 segmentation results를 각각 출력한다. 주의할 점은 C가 feature generation network와 segmentation network와 공유된다는 것이다. generator이 다양한 context-aware feature를 생성하기 위해 Contextual Module을 segmentation backbone 이후에 추가한다. 이 CM은 contextual information을 얻기 위함이다. contextual information은 latent code로 encode되어 G의 안내자 역할을 한다. 그ㄹ므로 segmentation network {E, CM, C}와 feature generation network {CM, G, D, C}는 CM과 C에 의해 연결된다.  

 

Contextual Module(CM)

 

 CM은 segmentation backbone 출력값 Fn(hxwxl)을 입력받아 Fn의 각 픽셀에 대한 pixel-wise contextual information을 취합할 목적이다. pixel의 pixel-wise contextual information은 주변 pixel의 aggregated information을 의미한다. 이를 달성하기 위하여 CM은 Fn을 입력 받아 Fn과 동일한 크기의 contextu map을 생성한다. context map의 각 pixel-wise vector는 Fn에 대한 pixel의 pixel-wise contextual information을 포함한다. CM의 dilated design을 고려하여 두 가지 원식을 고려한다. 1) multi-scale contextu는 더 나은 feature generation을 위해 보존되어야 한다. 2) contexts와 pixel 사이의 one-to-one correspondence는 유지되어야 한다. 이는 pooling layer가 사용되지 않는 것을 의미한다. 이 원칙에 기반하여 dilated conv layer를 사용한다. 왜냐하면 이는 spatial resolution 손실 없이 receptive fields를 확장할 수 있기 때문이다. 위 그림에서 보듯이 3개의 dilated conv를 사용한다. 연속적은 context map을 적용하는 것은 서로 다른 scale의 contextual information을 포착할 수 있다. 깊은 context map은 더 큰 receptive fields를 갖기 땜ㄴ이다.

 

 또, context selector를 사용한다. 3가지 context map을 통합하기 위함이다. 직관적으로 different pixel의 feature는 작은 receptive field 또는 큰 receptive field 의 contextual information에 의해 지배된다. contextual selector는 각 pixel에 대한 적합한 scale의 contextual information을 선택하기 위하여 서로 다른 pixel에 대한 scale weight를 적응적으로 학습한다. 

 

 Contextual latent code를 얻기 위하여 contextual selector의 출력값에 1x1 conv를 적용하여 uz와 oz를 얻는다. uz와 oz는 pixel-wise vector를 의미한다. 그리고나서 contextual latent code z_ni는 가우시안 분포 N(uz, oz)로부터 z = uz + eoz를 사용하여 얻어진다. inference 동안에 stochastic sampling을 위하여 가우시안 분포가 되도록 KL-divergence loss를 사용한다.

 

 

 이 pixel-wise contextual latent code가 이 pixel의 contextual information을 encode한다고 가정한다. 예를 들어, 나무 근처의 고양이 안에 존재하는 pixel이 주어지면 이의 contextual latent code는 cat의 근처 local region, 고양이의 상대적인 위치, 고양이의 posture, tree 같은 배경 객체를 encode 한다.

 

 추가적으로 n번째 이미지에 대한 모든 z_ni를 latent code map으로 encode 한다. 이 말은 CM이 zn에 sigmoid를 취하여 fn과 곱하고 fn에 더하여 새로운 feature map을 출력한다. 이 방법으로 segmentation network에서 residual attention module 역할을 수행한다. 

 

Context-aware Feature Generator

 label에 해당하는 word embedding map W와 contextual latent code Z가 concat되어 ganerator로 입력된다. 우선 seen 으로 학습을 진행하는데 detail 한 feature를 생성하기 위해 reconstruction loss가 사용된다.

 

 

 생성된 feature를 regulate하기 위하여 classification loss와 adversarial loss또한 사용한다. 따라서 down sample된 label map y는 x와 동일한 resolution을 갖으며 y는 x와 pixel-wise한 one-to-one correspondense 관계를 갖는다.

 

 

 

 이렇게 학습된 generator는 latent code z와 semantic word embedding w를 G로 전달하여 unseen, seen의 카테고리에 해당하는 feature를 생성한다. 

 

Optimization

 optimization 과정은 two step으로 이루어져 있다.

 

1) training

 segmentation network와 feature generation network는 seen class의 segmentation mask와 image data로 학습된다. 모든 네트워크 모듈 {E, CM, G, D, C}는 갱신되고 손실함수는 다음과 같이 구성된다.

 

 

2) fine-tuning

 segmentation network가 unseen categories에도 일반화되기 위하여 seen과 unseen data를 모두 사용하여 학습한다. unseen과 seen의 pixel-wise word embedding을 무작위로 쌓아서 m-th word embedding map을 구축한다. 이에 해당하는 label map도 준비합니다. 대략적으로 각 word embedding map의 seen과 unseen pixel을 동일하게 사용한다. 이 방법이 모델의 성능을 향상시켜줌. 그리고나서 가우시안 분포로부터 추출한 pixel wise latent code map과 semantic word embedding으로 fake feature map을 생성한다. 이 과정에서 E와 CM은 freeze 한다. real visual feature이 없기 때문이다. 단지 G, D, C만이 업데이트 관다. 따라서 손실 함수는 다음과 같이 사용한다.

 

 

 backbone E의 초기값은 pre-trained resnet 101을 사용하고 training step에서 충분히 학습한다. 그리고 real feature와 fake feature에 기반한 network optimization을 balance 하기 위해 매 100 반복마다 training step과 fine tunning step을 번갈아 실행한다.

 

Experiment

 

 


my github

 

Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch

공부 목적으로 논문을 리뷰하고 해당 논문 파이토치 재구현을 합니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.

github.com

 

반응형