오늘 읽은 논문은 GAN, Generative Adversarial Nets 입니다. 파이토치 코드와 함께 살펴보도록 하겠습니다.
GAN은 Generator, Discriminator 두 개의 신경망으로 이루어져 있습니다.
Generator
Generator은 무작위로 생성한 noise로 가짜 이미지를 생성합니다.
Generator PyTorch 구현 코드
# generator: noise를 입력받아 이미지를 생성합니다.
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
self.nz = params['nz'] # 입력 노이즈 벡터 수, 100
self.img_size = params['img_size'] # 이미지 크기, 1x28x28
self.model = nn.Sequential(
*self._fc_layer(self.nz, 128, normalize=False),
*self._fc_layer(128,256),
*self._fc_layer(256,512),
*self._fc_layer(512,1024),
nn.Linear(1024,int(np.prod(self.img_size))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_size)
return img
# fc layer
def _fc_layer(self, in_channels, out_channels, normalize=True):
layers = []
layers.append(nn.Linear(in_channels, out_channels)) # fc layer
if normalize:
layers.append(nn.BatchNorm1d(out_channels, 0.8)) # BN
layers.append(nn.LeakyReLU(0.2)) # LeakyReLU
return layers
# check
params = {'nz':100,
'img_size':(1,28,28)}
x = torch.randn(16,100).to(device) # random noise
model_gen = Generator(params).to(device)
output = model_gen(x) # noise를 입력받아 이미지 생성
print(output.shape)
Generator의 목적은 Discriminator이 식별하지 못하는 가짜 이미지를 생성하는 것입니다. 따라서 목적 함수는 Discriminator이 가짜 이미지를 실제 이미지로 식별할 확률이 1이 되도록 설계합니다.
Generator은 아래의 목적 함수를 최소화 하는 방향으로 학습합니다. 가짜 이미지를 1(진짜 이미지)로 식별한다면 목적 함수는 0이 됩니다. 목적 함수를 최소화 하도록 generator을 학습시킨 다면, 진짜 같은 가짜 이미지를 생성해낼 수 있습니다.
generator은 진짜 이미지 분포를 추정하여, noise를 입력 받아 추정한 분포로 이미지를 생성하는 것으로 생각해볼 수 있습니다. 진짜 이미지 분포와 근사화한 분포를 추정한다면, 진짜 같은 이미지를 생성해낼 수 있습니다.
검은 점선은 train_data 분포, 초록 선은 추정한 분포를 나타냅니다. 학습이 진행될 수록 추정한 분포가 train_data 분포와 같아지는 것을 확인할 수 있습니다. 이 경우에 discriminator이 구별하지 못하는 가짜 이미지를 생성할 수 있습니다.
목적 함수 PyTorch 코드
p_fake = discriminator(generator(noise))) # discriminator에 가짜 이미지를 입력하여 확률값 출력
loss_g = torch.log(1.-p_fake).mean() # 출력한 확률값이 1이 되도록 손실 함수를 설계
Discriminator
Discriminator은 학습 이미지와 generator이 생성한 가짜 이미지를 분류합니다.
Discriminator PyTorch 코드
# discriminator: 진짜 이미지와 가짜 이미지를 분류합니다.
class Discriminator(nn.Module):
def __init__(self,params):
super().__init__()
self.img_size = params['img_size'] # 이미지 크기, 1x28x28
self.model = nn.Sequential(
nn.Linear(int(np.prod(self.img_size)), 512),
nn.LeakyReLU(0.2),
nn.Linear(512,256),
nn.LeakyReLU(0.2),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(x.size(0),-1)
x = self.model(x)
return x
# check
x = torch.randn(16,1,28,28).to(device)
model_dis = Discriminator(params).to(device)
output = model_dis(x)
print(output.shape)
Discriminator의 목적은 가짜 이미지를 가짜로 분류하고, 학습 데이터는 진짜 이미지로 분류하는 것입니다. 목적 함수는 진짜 이미지를 진짜로 식별하고, 가짜 이미지를 가짜로 식별할 수 있도록 설계합니다. 아래의 목적 함수가 최대값을 갖도록 학습합니다.
D(x)는 진짜 이미지, D(G(z))는 가짜 이미지를 의미합니다. D(x)=1, D(G(z))=0이 되어야 목적 함수는 최대값을 갖습니다.
Generator이 진짜 같은 이미지를 생성해낸다면 discriminator은 항상 p=1/2 값을 출력하게 됩니다. 식별하지 못하기 때문에 진짜, 가짜 둘 중 하나의 값을 랜덤으로 출력하기 때문입니다.
목적 함수 PyTorch 코드
p_real = discriminator(진짜 이미지) # 진짜 이미지를 입력받아 출력한 확률
p_fake = discriminator(가짜 이미지) # 가짜 이미지를 입력받아 출력한 확률
loss_real = -1 * torch.log(p_real) # p_real = 1이 되어야 최대값을 갖습니다.
loss_fake = -1 * torch.log(1.-p_fake) # p_fake = 0이 되어야 최대값을 갖습니다.
loss_d = (loss_real + loss_fake).mean()
Performance
Generator이 충분히 학습된 후에 generator이 생성하는 가짜 이미지 입니다.
GAN PyTorch 구현 코드는 아래 포스팅에서 확인하실 수 있습니다.
참고자료