오늘 읽은 논문은 DCGAN, Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks 입니다.
DCGAN은 generator와 discriminator 구조에 CNN을 적용한 것입니다. 이미지 특징을 포착하는 데에 특화되어 있는 CNN으로 모델 구조를 구성하므로 기존 FC layer로 구성되어 있는 GAN보다 성능이 탁월합니다. DCGAN의 generator와 discriminator이 어떤 구조를 갖고 있는지 구현 코드와 함께 살펴보겠습니다.
전체 구현 코드는 아래 깃허브에서 살펴보실 수 있습니다.
Generator
모델 구조
# Generator: noise를 입력받아 가짜 이미지를 생성합니다.
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
nz = params['nz'] # noise 수, 100
ngf = params['ngf'] # conv filter 수, 64
img_channel = params['img_channel'] # 이미지 채널, 3
self.dconv1 = nn.ConvTranspose2d(nz,ngf*8,4, stride=1, padding=0, bias=False) # x4
self.bn1 = nn.BatchNorm2d(ngf*8)
self.dconv2 = nn.ConvTranspose2d(ngf*8,ngf*4, 4, stride=2, padding=1, bias=False) # x2
self.bn2 = nn.BatchNorm2d(ngf*4)
self.dconv3 = nn.ConvTranspose2d(ngf*4,ngf*2,4,stride=2,padding=1,bias=False) # x2
self.bn3 = nn.BatchNorm2d(ngf*2)
self.dconv4 = nn.ConvTranspose2d(ngf*2,ngf,4,stride=2,padding=1,bias=False) # x2
self.bn4 = nn.BatchNorm2d(ngf)
self.dconv5 = nn.ConvTranspose2d(ngf,img_channel,4,stride=2,padding=1,bias=False) # x2
def forward(self,x):
x = F.relu(self.bn1(self.dconv1(x)))
x = F.relu(self.bn2(self.dconv2(x)))
x = F.relu(self.bn3(self.dconv3(x)))
x = F.relu(self.bn4(self.dconv4(x)))
x = torch.tanh(self.dconv5(x))
return x
# check
x = torch.randn(1,100,1,1, device=device)
model_gen = Generator(params).to(device)
out_gen = model_gen(x)
print(out_gen.shape)
generator은 noise를 입력 받아 가짜 이미지를 생성합니다. batch,100,1,1 크기의 noise를 3x64x64 크기의 이미지를 생성하기 위해서는 feature map 크기를 64배 해야합니다. 즉, noise에 x2 up_sample을 6번 적용해야 합니다. DCGAN은 up_sample 방법으로 transposed Conv를 사용합니다.
Transposed Conv는 위 표처럼 kernel_size, stride, padding을 적용하여 출력값 크기를 조절할 수 있습니다. Transposed Conv에 대한 내용은 아래 블로그에 자세하게 설명되어 있습니다.
손실 함수
손실 함수는 기존의 gan과 동일합니다.
위 손실 함수를 최소화 하는 방향으로 학습을 진행합니다. 위 손실 함수가 최소값을 갖으려면 D(G(z)) = 1의 값을 가져야 합니다. 1은 True, G(z)은 generator이 생성한 가짜 이미지를 의미하므로 discriminator이 가짜 이미지를 진짜 이미지로 식별하는 경우에 D(G(z))=1이 되어 손실 함수는 최소값을 갖습니다. 즉, 손실 함수를 최소화 하기 위해서는 generator이 진짜같은 가짜 이미지를 생성해야 합니다.
p_fake = discriminator(generator(noise))) # discriminator에 가짜 이미지를 입력하여 확률값 출력
loss_g = torch.log(1.-p_fake).mean() # 출력한 확률값이 1이 되도록 손실 함수를 설계
Discriminator
# Discriminator: 진짜 이미지와 가짜 이미지를 식별합니다.
class Discriminator(nn.Module):
def __init__(self,params):
super().__init__()
img_channel = params['img_channel'] # 3
ndf = params['ndf'] # 64
self.conv1 = nn.Conv2d(img_channel,ndf,4,stride=2,padding=1,bias=False)
self.conv2 = nn.Conv2d(ndf,ndf*2,4,stride=2,padding=1,bias=False)
self.bn2 = nn.BatchNorm2d(ndf*2)
self.conv3 = nn.Conv2d(ndf*2,ndf*4,4,stride=2,padding=1,bias=False)
self.bn3 = nn.BatchNorm2d(ndf*4)
self.conv4 = nn.Conv2d(ndf*4,ndf*8,4,stride=2,padding=1,bias=False)
self.bn4 = nn.BatchNorm2d(ndf*8)
self.conv5 = nn.Conv2d(ndf*8,1,4,stride=1,padding=0,bias=False)
def forward(self,x):
x = F.leaky_relu(self.conv1(x),0.2)
x = F.leaky_relu(self.bn2(self.conv2(x)),0.2)
x = F.leaky_relu(self.bn3(self.conv3(x)),0.2)
x = F.leaky_relu(self.bn4(self.conv4(x)),0.2)
x = torch.sigmoid(self.conv5(x))
return x.view(-1,1)
# check
x = torch.randn(16,3,64,64,device=device)
model_dis = Discriminator(params).to(device)
out_dis = model_dis(x)
print(out_dis.shape)
discriminator은 입력 받은 이미지를 진짜 이미지와 가짜 이미지로 식별하는 이진 분류를 수행합니다. 모델 구조는 일반적인 CNN과 동일합니다.
손실 함수
손실 함수는 기존의 GAN과 동일합니다.
위 손실 함수를 최대화하는 방향으로 학습을 진행합니다. 위 손실 함수가 최대값을 갖으려면 D(x)=1, D(G(z))=0이 되어야 합니다. D(x)=1은 진짜 이미지를 진짜로 식별, D(G(z))=0은 generator이 생성한 가짜이미지를 가짜로 식별하는 것을 의미합니다.
p_real = discriminator(진짜 이미지) # 진짜 이미지를 입력받아 출력한 확률
p_fake = discriminator(가짜 이미지) # 가짜 이미지를 입력받아 출력한 확률
loss_real = -1 * torch.log(p_real) # p_real = 1이 되어야 최대값을 갖습니다.
loss_fake = -1 * torch.log(1.-p_fake) # p_fake = 0이 되어야 최대값을 갖습니다.
loss_d = (loss_real + loss_fake).mean()
Performance
침대 이미지 셋으로 학습된 generator이 생성한 가짜 이미지입니다. 정말 감쪽같네요
아래 결과는 흥미롭습니다. 사람 얼굴을 학습한 generator에 (안경쓴 남자를 출력하는 noise를 찾아서 평균을 취한것 - 안경이 없는 남자를 출력하는 남자를 출력하는 noise를 평균한 것 + 안경 없는 여자를 출력하는 noise의 평균) 의 연산을 수행하여 generator에 입력해주었더니 안경쓴 여자가 출력되었습니다.
참고자료