논문 읽기/GAN

[논문 읽기] 코드로 살펴보는 GAN(2014), Generative Adversarial Nets

AI 꿈나무 2021. 5. 17. 22:24
반응형

 오늘 읽은 논문은 GAN, Generative Adversarial Nets 입니다. 파이토치 코드와 함께 살펴보도록 하겠습니다.

 

 GAN은 Generator, Discriminator 두 개의 신경망으로 이루어져 있습니다.

 

Generator

 Generator은 무작위로 생성한 noise로 가짜 이미지를 생성합니다.

 

Generator PyTorch 구현 코드

# generator: noise를 입력받아 이미지를 생성합니다.
class Generator(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.nz = params['nz'] # 입력 노이즈 벡터 수, 100
        self.img_size = params['img_size'] # 이미지 크기, 1x28x28

        self.model = nn.Sequential(
            *self._fc_layer(self.nz, 128, normalize=False),
            *self._fc_layer(128,256),
            *self._fc_layer(256,512),
            *self._fc_layer(512,1024),
            nn.Linear(1024,int(np.prod(self.img_size))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_size)
        return img

    # fc layer
    def _fc_layer(self, in_channels, out_channels, normalize=True):
        layers = []
        layers.append(nn.Linear(in_channels, out_channels)) # fc layer
        if normalize:
            layers.append(nn.BatchNorm1d(out_channels, 0.8)) # BN
        layers.append(nn.LeakyReLU(0.2)) # LeakyReLU
        return layers

# check
params = {'nz':100,
          'img_size':(1,28,28)}
x = torch.randn(16,100).to(device) # random noise
model_gen = Generator(params).to(device)
output = model_gen(x) # noise를 입력받아 이미지 생성
print(output.shape)

 

Generator의 목적은 Discriminator이 식별하지 못하는 가짜 이미지를 생성하는 것입니다. 따라서 목적 함수는 Discriminator이 가짜 이미지를 실제 이미지로 식별할 확률이 1이 되도록 설계합니다.

 

 Generator은 아래의 목적 함수를 최소화 하는 방향으로 학습합니다. 가짜 이미지를 1(진짜 이미지)로 식별한다면 목적 함수는 0이 됩니다. 목적 함수를 최소화 하도록 generator을 학습시킨 다면, 진짜 같은 가짜 이미지를 생성해낼 수 있습니다.

 

 

 generator은 진짜 이미지 분포를 추정하여, noise를 입력 받아 추정한 분포로 이미지를 생성하는 것으로 생각해볼 수 있습니다. 진짜 이미지 분포와 근사화한 분포를 추정한다면, 진짜 같은 이미지를 생성해낼 수 있습니다.

 

 

 검은 점선은 train_data 분포, 초록 선은 추정한 분포를 나타냅니다. 학습이 진행될 수록 추정한 분포가 train_data 분포와 같아지는 것을 확인할 수 있습니다. 이 경우에 discriminator이 구별하지 못하는 가짜 이미지를 생성할 수 있습니다.

 

 

목적 함수 PyTorch 코드

p_fake = discriminator(generator(noise))) # discriminator에 가짜 이미지를 입력하여 확률값 출력
loss_g = torch.log(1.-p_fake).mean() # 출력한 확률값이 1이 되도록 손실 함수를 설계

 

Discriminator

 Discriminator은 학습 이미지와 generator이 생성한 가짜 이미지를 분류합니다. 

 

Discriminator PyTorch 코드

# discriminator: 진짜 이미지와 가짜 이미지를 분류합니다.
class Discriminator(nn.Module):
    def __init__(self,params):
        super().__init__()
        self.img_size = params['img_size'] # 이미지 크기, 1x28x28

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_size)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0),-1)
        x = self.model(x)
        return x

# check
x = torch.randn(16,1,28,28).to(device)
model_dis = Discriminator(params).to(device)
output = model_dis(x)
print(output.shape)

 

 Discriminator의 목적은 가짜 이미지를 가짜로 분류하고, 학습 데이터는 진짜 이미지로 분류하는 것입니다. 목적 함수는 진짜 이미지를 진짜로 식별하고, 가짜 이미지를 가짜로 식별할 수 있도록 설계합니다. 아래의 목적 함수가 최대값을 갖도록 학습합니다.

 

 

 D(x)는 진짜 이미지, D(G(z))는 가짜 이미지를 의미합니다. D(x)=1, D(G(z))=0이 되어야 목적 함수는 최대값을 갖습니다.

 

 Generator이 진짜 같은 이미지를 생성해낸다면 discriminator은 항상 p=1/2 값을 출력하게 됩니다. 식별하지 못하기 때문에 진짜, 가짜 둘 중 하나의 값을 랜덤으로 출력하기 때문입니다.

 

목적 함수 PyTorch 코드

p_real = discriminator(진짜 이미지) # 진짜 이미지를 입력받아 출력한 확률
p_fake = discriminator(가짜 이미지) # 가짜 이미지를 입력받아 출력한 확률


loss_real = -1 * torch.log(p_real)   # p_real = 1이 되어야 최대값을 갖습니다.
loss_fake = -1 * torch.log(1.-p_fake) # p_fake = 0이 되어야 최대값을 갖습니다.
loss_d    = (loss_real + loss_fake).mean()

 

Performance

 

 Generator이 충분히 학습된 후에 generator이 생성하는 가짜 이미지 입니다.

 

 GAN PyTorch 구현 코드는 아래 포스팅에서 확인하실 수 있습니다.

 

 

[논문 구현] PyTorch로 GAN(2014) 구현하고 학습하기

 안녕하세요! 이번 포스팅에서는 PyTorch로 구현한 GAN을 MNIST dataset으로 학습한 후, 학습된 generator이 생성한 가짜 이미지를 확인해보겠습니다.  작업 환경은 Google Colab에서 진행합니다. 전체 코드

deep-learning-study.tistory.com


참고자료

[1] http://intelligence.korea.ac.kr/members/wschoi/seminar/tutorial/mnist/pytorch/gan/GAN-%ED%8A%9C%ED%86%A0%EB%A6%AC%EC%96%BC/

[2] https://arxiv.org/pdf/1406.2661.pdf

반응형