논문 읽기/GAN

[논문 읽기] 구현 코드로 살펴보는 CGAN(2014), Conditional Generative Adversarial Nets

AI 꿈나무 2021. 5. 18. 15:29
반응형

 오늘 읽은 논문은 CGAN(2014), Conditional Generative Adversarial Nets 입니다.

 

 GAN에 대한 배경지식이 있다고 가정하여 포스팅을 작성합니다. GAN 논문 리뷰는 아래 포스팅에서 살펴보실 수 있습니다.

 

[논문 읽기] 코드로 살펴보는 GAN(2014), Generative Adversarial Nets

 오늘 읽은 논문은 GAN, Generative Adversarial Nets 입니다. 파이토치 코드와 함께 살펴보도록 하겠습니다.  GAN은 Generator, Discriminator 두 개의 신경망으로 이루어져 있습니다. Generator  Generator은..

deep-learning-study.tistory.com

 

 CGAN은 GAN에 조건부 데이터를 입력하여 성능을 향상시킨 모델입니다.

 

 조건부 데이터는 무엇을 의미할까요?

 

CGAN

 기존 GAN은 generator에 noise와 discriminator에 img만 입력해주었습니다. 조건부 데이터는 레이블을 갖는 noise와 레이블을 갖는 img을 말합니다. 즉, CGAN은 GAN에 레이블이라는 추가적인 정보를 주어서 성능을 향상시키는 것입니다.

 

 

 이미지와 레이블을 GAN의 입력값으로 전달하고, GAN 내부에서 이미지와 레이블을 결합합니다. 어떻게 결합하는지 PyTorch 코드로 살펴보겠습니다.

 

generator

# Generator: 가짜 이미지를 생성합니다.
# noise와 label을 결합하여 학습합니다..

class Generator(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes'] # 클래스 수, 10
        self.nz = params['nz'] # 노이즈 수, 100
        self.input_size = params['input_size'] # (1,28,28)

        # noise와 label을 결합할 용도인 label embedding matrix를 생성합니다.
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        self.gen = nn.Sequential(
            nn.Linear(self.nz + self.num_classes, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128,256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024,int(np.prod(self.input_size))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # noise와 label 결합
        gen_input = torch.cat((self.label_emb(labels),noise),-1)
        x = self.gen(gen_input)
        x = x.view(x.size(0), *self.input_size)
        return x
        
# check
x = torch.randn(16,100,device=device) # 노이즈
label = torch.randint(0,10,(16,),device=device) # 레이블
model_gen = Generator(params).to(device)
out_gen = model_gen(x,label) # 가짜 이미지 생성
print(out_gen.shape)

 

discriminator

# Discriminator: 가짜 이미지와 진짜 이미지를 식별합니다.
class Discriminator(nn.Module):
    def __init__(self,params):
        super().__init__()
        self.input_size = params['input_size']
        self.num_classes = params['num_classes']

        self.label_embedding = nn.Embedding(self.num_classes, self.num_classes)

        self.dis = nn.Sequential(
            nn.Linear(self.num_classes+int(np.prod(self.input_size)),512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512,512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512,1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # 이미지와 label 결합
        dis_input = torch.cat((img.view(img.size(0),-1),self.label_embedding(labels)),-1)

        x = self.dis(dis_input)
        return x

# check
x = torch.randn(16,1,28,28,device=device)
label = torch.randint(0,10,(16,), device=device)
model_dis = Discriminator(params).to(device)
out_dis = model_dis(x,label)
print(out_dis.shape)

 

 신경망 내부에 embedding layer를 추가하여 label을 embedding 한 후에 forward 부분에서 embedding 된 label과 noise를  결합합니다.

 

손실함수는 기존의 GAN과 동일하게 사용합니다.

 

손실 함수

(1) generator 손실 함수

 generator은 위 손실 함수는 최소화하는 방향으로 학습을 진행합니다. generator이 생성한 가짜 이미지를 discriminator이 진짜라고 판단하면 D(G(z)) = 1이 됩니다. 따라서 generator은 D(G(z))=1이 되도록 학습을 진행하며, D(G(z))=1이 된다면 위 손실함수는 최소값을 갖습니다.

 

p_fake = discriminator(generator(noise))) # discriminator에 가짜 이미지를 입력하여 확률값 출력
loss_g = torch.log(1.-p_fake).mean() # 출력한 확률값이 1이 되도록 손실 함수를 설계

 

(2) discriminator 손실 함수

 

 discriminator은 위 손실 함수를 최대화 하는 방향으로 학습을 진행합니다. 진짜 데이터 x를 discriminator가 진짜라고 판단하면 D(x) = 1의 값을 출력합니다. 반대로 가짜 데이터 G(z)를 discriminator가 가짜라고 판단하면 D(G(z)) = 0의 값을 출력합니다. 즉, 위 손실 함수를 최대화하는 방향으로 학습하는 것은 가짜 데이터를 가짜 데이터로 식별하고, 진짜 데이터는 진짜 데이터를 식별할 수 있도록 파라미터를 갱신합니다.

 

 만약 generator이 진짜같은 가짜 이미지를 생성하게 된다면 진짜와 가짜를 구별할 수 없어 D(x) = D(G(z)) = 1/2 값을 갖습니다.

 

p_real = discriminator(진짜 이미지) # 진짜 이미지를 입력받아 출력한 확률
p_fake = discriminator(가짜 이미지) # 가짜 이미지를 입력받아 출력한 확률


loss_real = -1 * torch.log(p_real)   # p_real = 1이 되어야 최대값을 갖습니다.
loss_fake = -1 * torch.log(1.-p_fake) # p_fake = 0이 되어야 최대값을 갖습니다.
loss_d    = (loss_real + loss_fake).mean()

 

Performence

Generator이 생성한 가짜 이미지

 

 CGAN을 PyTorch로 구현하고 MNIST dataset으로 학습한 후에, generator이 생성한 가짜 이미지를 확인하는 예제 코드는 아래 깃허브에서 확인하실 수 있습니다.

 

Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch

공부 목적으로 논문을 리뷰하고 파이토치로 구현하고 있습니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.

github.com

 


참고자료

[1] https://arxiv.org/abs/1411.1784

반응형