논문 구현
[논문 구현] PyTorch로 GAN(2014) 구현하고 학습하기
AI 꿈나무
2021. 5. 17. 22:34
반응형
안녕하세요! 이번 포스팅에서는 PyTorch로 구현한 GAN을 MNIST dataset으로 학습한 후, 학습된 generator이 생성한 가짜 이미지를 확인해보겠습니다.
작업 환경은 Google Colab에서 진행합니다.
전체 코드는 아래 깃허브에서 확인하실 수 있습니다.
목차
1. 데이터셋 불러오기
2. 모델 구축하기
3. 학습하기
4. generator이 생성한 가짜 이미지 확인하기
우선, 필요한 라이브러리를 불러오겠습니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
import os
import numpy as np
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1. 데이터셋 불러오기
# 데이터 경로 지정
path2data = './data'
os.makedirs(path2data, exist_ok=True) # 폴더 생성
# MNIST dataset 불러오기
train_ds = datasets.MNIST(path2data, train=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]), download=True)
# 샘플 이미지 확인
img, label = train_ds[0]
plt.imshow(to_pil_image(img),cmap='gray')
# 데이터 로더 생성
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
# check
for x, y in train_dl:
print(x.shape, y.shape)
break
2. 모델 구축하기
모델 코드는 https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 을 참고했습니다.
generator
# generator: noise를 입력받아 이미지를 생성합니다.
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
self.nz = params['nz'] # 입력 노이즈 벡터 수, 100
self.img_size = params['img_size'] # 이미지 크기, 1x28x28
self.model = nn.Sequential(
*self._fc_layer(self.nz, 128, normalize=False),
*self._fc_layer(128,256),
*self._fc_layer(256,512),
*self._fc_layer(512,1024),
nn.Linear(1024,int(np.prod(self.img_size))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_size)
return img
# fc layer
def _fc_layer(self, in_channels, out_channels, normalize=True):
layers = []
layers.append(nn.Linear(in_channels, out_channels)) # fc layer
if normalize:
layers.append(nn.BatchNorm1d(out_channels, 0.8)) # BN
layers.append(nn.LeakyReLU(0.2)) # LeakyReLU
return layers
# check
params = {'nz':100,
'img_size':(1,28,28)}
x = torch.randn(16,100).to(device) # random noise
model_gen = Generator(params).to(device)
output = model_gen(x) # noise를 입력받아 이미지 생성
print(output.shape)
discriminator
# discriminator: 진짜 이미지와 가짜 이미지를 분류합니다.
class Discriminator(nn.Module):
def __init__(self,params):
super().__init__()
self.img_size = params['img_size'] # 이미지 크기, 1x28x28
self.model = nn.Sequential(
nn.Linear(int(np.prod(self.img_size)), 512),
nn.LeakyReLU(0.2),
nn.Linear(512,256),
nn.LeakyReLU(0.2),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(x.size(0),-1)
x = self.model(x)
return x
# check
x = torch.randn(16,1,28,28).to(device)
model_dis = Discriminator(params).to(device)
output = model_dis(x)
print(output.shape)
가중치를 초기화 합니다.
# 가중치 초기화
def initialize_weights(model):
classname = model.__class__.__name__
# fc layer
if classname.find('Linear') != -1:
nn.init.normal_(model.weight.data, 0.0, 0.02)
nn.init.constant_(model.bias.data, 0)
# batchnorm
elif classname.find('BatchNorm') != -1:
nn.init.normal_(model.weight.data, 1.0, 0.02)
nn.init.constant_(model.bias.data, 0)
# 가중치 초기화 적용
model_gen.apply(initialize_weights);
model_dis.apply(initialize_weights);
3. 학습하기
# 손실 함수
loss_func = nn.BCELoss()
from torch import optim
# 최적화 파라미터
lr = 2e-4
beta1 = 0.5
opt_dis = optim.Adam(model_dis.parameters(),lr=lr,betas=(beta1,0.999))
opt_gen = optim.Adam(model_gen.parameters(),lr=lr,betas=(beta1,0.999))
real_label = 1.
fake_label = 0.
nz = params['nz']
num_epochs = 100
loss_history={'gen':[],
'dis':[]}
batch_count = 0
start_time = time.time()
model_dis.train()
model_gen.train()
for epoch in range(num_epochs):
for xb, yb in train_dl:
ba_si = xb.size(0)
xb = xb.to(device)
yb_real = torch.Tensor(ba_si,1).fill_(1.0).to(device)
yb_fake = torch.Tensor(ba_si,1).fill_(0.0).to(device)
# Generator
model_gen.zero_grad()
noise = torch.randn(ba_si,nz, device=device) # 노이즈 생성
out_gen = model_gen(noise) # 가짜 이미지 생성
out_dis = model_dis(out_gen) # 가짜 이미지 판별
loss_gen = loss_func(out_dis, yb_real)
loss_gen.backward()
opt_gen.step()
# Discriminator
model_dis.zero_grad()
out_real = model_dis(xb) # 진짜 이미지 판별
out_fake = model_dis(out_gen.detach()) # 가짜 이미지 판별
loss_real = loss_func(out_real, yb_real)
loss_fake = loss_func(out_fake, yb_fake)
loss_dis = (loss_real + loss_fake) / 2
loss_dis.backward()
opt_dis.step()
loss_history['gen'].append(loss_gen.item())
loss_history['dis'].append(loss_dis.item())
batch_count += 1
if batch_count % 1000 == 0:
print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, loss_gen.item(), loss_dis.item(), (time.time()-start_time)/60))
Loss history를 출력합니다.
# plot loss history
plt.figure(figsize=(10,5))
plt.title('Loss Progress')
plt.plot(loss_history['gen'], label='Gen. Loss')
plt.plot(loss_history['dis'], label='Dis. Loss')
plt.xlabel('batch count')
plt.ylabel('Loss')
plt.legend()
plt.show()
가중치를 저장합니다.
# 가중치 저장
path2models = './models/'
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, 'weights_gen.pt')
path2weights_dis = os.path.join(path2models, 'weights_dis.pt')
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)
4. Generator이 생성한 가짜 이미지 확인
# 가중치 불러오기
weights = torch.load(path2weights_gen)
model_gen.load_state_dict(weights)
# evaluation mode
model_gen.eval()
# fake image 생성
with torch.no_grad():
fixed_noise = torch.randn(16, 100, device=device)
img_fake = model_gen(fixed_noise).detach().cpu()
print(img_fake.shape)
가짜 이미지 시각화
# 가짜 이미지 시각화
plt.figure(figsize=(10,10))
for ii in range(16):
plt.subplot(4,4,ii+1)
plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5),cmap='gray')
plt.axis('off')
좀 더 많은 학습이 필요하네요..ㅎㅎ
반응형