논문 읽기/GAN

[논문 읽기] 구현 코드로 살펴보는 Pix2Pix(2016), Image-to-Image Translation with Conditional Adversarial Networks

AI 꿈나무 2021. 5. 20. 01:06
반응형

 PyTorch 코드와 함께 Pix2Pix를 살펴보도록 하겠습니다.

 

 Pix2Pix는 image를 image로 변환하도록 generator을 학습합니다. 예를 들어, generator의 입력값으로 스케치 그림을 입력하면 완성된 그림이 나오도록 학습할 수 있습니다. 기존 GAN과 비교하여 설명하자면, Pix2Pix는 기존 GAN의 noise 대신에 스케치 그림을 입력하여 학습을 하는 것입니다.

 

 어떻게 generator이 image to image를 생성하도록 학습시킬 수 있는지 살펴보겠습니다.

 

 일반적으로 generator은 스케치를 입력받아 가짜 이미지를 출력합니다. 이 가짜 이미지를 discriminator이 완성된 그림으로 식별하도록 목적 함수를 설계하여 학습을 진행하면 서서히 generator은 완성된 가짜 그림을 출력합니다. 기존 GAN과 비교하여 설명하면 noise를 스케치 그림으로 입력한다고 생각하면 됩니다.

 

 그림과 함께 설명하도록 하겠습니다.

 

Generator

 generator은 아래 그림을 입력으로 취합니다.

 

 generator의 목표는 아래 그림과 같은 출력값을 생성하여 discriminator을 속이는 것이 목적입니다.

 

 

 따라서 손실 함수는 다음과 같이 설계합니다.

 

loss_func = nn.BCELoss() # 이진 분류

fake_img = model_gen(mask) # mask를 입력 받아 가짜 이미지 생성
out_dis = model_dis(fake_img, conditional_data) # 가짜 이미지를 discriminator이 식별

g_loss = loss_func(fake_img, real_label) # discriminator이 가짜 이미지를 1로 식별하도록 학습

 

 가짜 이미지가 진짜 이미지와 일치해야 g_loss가 최소값을 갖으므로 진짜 이미지같은 가짜 이미지를 생성합니다.

 

 pix2pix는 위 코드의 손실 함수에 pixel loss를 추가합니다. 생성한 가짜 이미지와 진짜 이미지사이의 L1 loss를 계산합니다. 즉, 각 pixel값의 차이를 계산하는 것입니다.

 

loss_func_gan = nn.BCELoss() # 이진 분류 손실 함수
loss_func_pix = nn.L1Loss() # L1 손실 함수

lambda_pixel = 100 # 가중치

fake_b = model_gen(real_a) # 가짜 이미지 생성
out_dis = model_dis(fake_b, real_b) # 가짜 이미지 식별

gen_loss = loss_func_gan(out_dis, real_label) # 생성된 가짜 이미지를 discriminator이 진짜로 식별하도록
pixel_loss = loss_func_pix(fake_b, real_b) # 생성된 가짜 이미지와 진짜 이미지의 pixel값 차이

g_loss = gen_loss + lambda_pixel * pixel_loss # generator loss

 

 위 손실 함수의 의미를 해석해보겠습니다.

 

 손실 함수를 l1 loss만 사용하면 generator은 blur가 적용된 row-frequency(저해상도) 이미지를 생성합니다. 아래 그림에서 l1 loss만을 사용한 결과를 살펴볼 수 있습니다. 기존의 gan loss에 l1 loss를 추가하면 gan loss가 high frequency에 집중하도록 할 수 있습니다.

 

 

 Pix2Pix는 이미지 생성 성능을 높이기 위해서 generator 모델 구조로 UNet을 사용합니다. generator 코드를 한번 살펴보겠습니다. https://github.com/eriklindernoren/PyTorch-GAN 코드를 참고하여 재구현 했습니다.

# UNet
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels)),

        layers.append(nn.LeakyReLU(0.2))

        if dropout:
            layers.append(nn.Dropout(dropout))

        self.down = nn.Sequential(*layers)

    def forward(self, x):
        x = self.down(x)
        return x
        
class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels,4,2,1,bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU()
        ]

        if dropout:
            layers.append(nn.Dropout(dropout))

        self.up = nn.Sequential(*layers)

    def forward(self,x,skip):
        x = self.up(x)
        x = torch.cat((x,skip),1)
        return x
        
# generator: 가짜 이미지를 생성합니다.
class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64,128)                 
        self.down3 = UNetDown(128,256)               
        self.down4 = UNetDown(256,512,dropout=0.5) 
        self.down5 = UNetDown(512,512,dropout=0.5)      
        self.down6 = UNetDown(512,512,dropout=0.5)             
        self.down7 = UNetDown(512,512,dropout=0.5)              
        self.down8 = UNetDown(512,512,normalize=False,dropout=0.5)

        self.up1 = UNetUp(512,512,dropout=0.5)
        self.up2 = UNetUp(1024,512,dropout=0.5)
        self.up3 = UNetUp(1024,512,dropout=0.5)
        self.up4 = UNetUp(1024,512,dropout=0.5)
        self.up5 = UNetUp(1024,256)
        self.up6 = UNetUp(512,128)
        self.up7 = UNetUp(256,64)
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(128,3,4,stride=2,padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8,d7)
        u2 = self.up2(u1,d6)
        u3 = self.up3(u2,d5)
        u4 = self.up4(u3,d4)
        u5 = self.up5(u4,d3)
        u6 = self.up6(u5,d2)
        u7 = self.up7(u6,d1)
        u8 = self.up8(u7)

        return u8

 

Discriminator

 Pix2Pix는 discriminator을 patch gan을 사용합니다. patch gan은 출력값으로 하나의 scalar 값을 출력하는 것이 아니라 이미지를 분할한 피쳐맵을 출력합니다. 좀 더 구체적으로 설명하자면 256x256 크기의 이미지를 입력 받은 경우에 30x30의 출력값을 생성하는 것입니다. 원래 이미지가 30x30 feature map으로 분할하여 각 pixel을 real, fake 식별하는 것입니다.

 

 또한 patch gan은 조건부 gan 이므로 조건부 데이터를 입력으로 받습니다. 아래 그림 두 개를 한꺼번에 입력으로 받는 것입니다.

 

 discriminator의 구현 코드를 살펴보겠습니다.

class Dis_block(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
    
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        x = self.block(x)
        return x


# Discriminator은 patch gan을 사용합니다.
# Patch Gan: 이미지를 16x16의 패치로 분할하여 각 패치가 진짜인지 가짜인지 식별합니다.
# low-frequency에서 정확도가 향상됩니다.

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()

        self.stage_1 = Dis_block(in_channels*2,64,normalize=False)
        self.stage_2 = Dis_block(64,128)
        self.stage_3 = Dis_block(128,256)
        self.stage_4 = Dis_block(256,512)

        self.patch = nn.Conv2d(512,1,3,padding=1) # 16x16 패치 생성

    def forward(self,a,b):
        x = torch.cat((a,b),1)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.stage_4(x)
        x = self.patch(x)
        x = torch.sigmoid(x)
        return x

 

 위 코드는 3x256x256크기의 두 이미지를 입력 받아, 하나로 결합한 뒤에 1x16x16 feature map을 출력하도록 설계한 것입니다. 즉, 기존 이미지를 16x16의 patch로 분할한 것이라고 생각해볼 수 있습니다. 각 patch가 real인지 fake인지 식별합니다.

 

 왜 patch gan을 사용했을 까요?? patch gan을 사용하면 high-frequency의 정확도가 향상됩니다. high-frequency는 고해상도로 생각할 수 있으며, high-frequency의 정확도가 향상된다는 말은 디테일한 부분이 향상된다는 것으로 해석할 수 있습니다. 분할된 patch를 기준으로 real, fake를 식별하므로 좀 더 디테일한 결과를 출력할 수 있는 것입니다.

 

 이미지를 많이 분할할 수록 high-frequency가 향상됩니다.

 

 discriminator의 손실 함수는 기존의 gan과 동일합니다. 주의할 점은 real label과 fake label을 분할된 피쳐맵 1x16x16과 동일한 크기로 생성해야 합니다.

 

patch = (1,16,16)
loss_func_gan = nn.BCELoss() # 이진 분류 손실 함수

# patch label, 16x16 크기의 label 생성
real_label = torch.ones(ba_si, *patch, requires_grad=False).to(device)
fake_label = torch.zeros(ba_si, *patch, requires_grad=False).to(device)

out_dis = model_dis(real_b, real_a) # 진짜 이미지 식별, 1x16x16 크기 벡터 출력
real_loss = loss_func_gan(out_dis,real_label)

out_dis = model_dis(fake_b.detach(), real_a) # 가짜 이미지 식별, 1x16x16 크기 벡터 출력
fake_loss = loss_func_gan(out_dis,fake_label)
        
d_loss = (real_loss + fake_loss) / 2.

 

Performance

 아래 그림은 pix2pix를 구현하여 facade dataset으로 학습시킨 generator이 생성한 가짜 이미지입니다. 오른쪽 그림이 가짜 이미지입니다. 100epoch만 학습시켰는데도 얼추 비슷한 이미지를 생성하고 있습니다.

 

 

 pix2pix를 구현하고 학습한 코드는 아래 깃허브에서 확인하실 수 있습니다.

 

Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch

공부 목적으로 논문을 리뷰하고 파이토치로 구현하고 있습니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.

github.com


참고자료

[1] https://arxiv.org/abs/1611.07004

반응형