이번 포스팅에서는 Conditional GAN을 PyTorch로 구현하고 MNIST dataset으로 학습한 후 generator이 생성한 가짜 이미지를 확인해보겠습니다.
논문 리뷰는 아래 포스팅에서 확인하실 수 있습니다.
[논문 읽기] 구현 코드로 살펴보는 CGAN(2014), Conditional Generative Adversarial Nets
오늘 읽은 논문은 CGAN(2014), Conditional Generative Adversarial Nets 입니다. GAN에 대한 배경지식이 있다고 가정하여 포스팅을 작성합니다. GAN 논문 리뷰는 아래 포스팅에서 살펴보실 수 있습니다. [논..
deep-learning-study.tistory.com
전체 코드는 아래 깃허브에서 확인하실 수 있습니다.
Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch
공부 목적으로 논문을 리뷰하고 파이토치로 구현하고 있습니다. Contribute to Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch development by creating an account on GitHub.
github.com
목차
1. dataset 불러오기
2. 모델 구축하기
3. 학습하기
4. Generator이 생성한 가짜 이미지 확인하기
우선, 필요한 라이브러리를 불러옵니다.
import torchvision.datasets as datasets
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
import time
%matplotlib inline
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1. dataset 불러오기
# 데이터 경로 지정
path2data = './data'
os.makedirs(path2data, exist_ok=True)
# Transformation 정의
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5],[0.5]),
])
# MNIST dataset 불러오기
train_ds = datasets.MNIST(path2data, train=True, transform=train_transform, download=True)
샘플 이미지를 확인합니다.
# 샘플 이미지 확인하기
img, label = train_ds.data, train_ds.targets
# 차원 추가
if len(img.shape) == 3:
img = img.unsqueeze(1) # B*C*H*W
# 그리드 생성
img_grid = utils.make_grid(img[:40], nrow=8, padding=2)
def show(img):
npimg = img.numpy()
npimg_tr = npimg.transpose((1,2,0)) # [C,H,W] -> [H,W,C]
plt.imshow(npimg_tr, interpolation='nearest')
show(img_grid)
데이터 로더를 생성합니다.
# 데이터 로더 생성
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
2. 모델 구축하기
코드는 https://github.com/eriklindernoren/PyTorch-GAN 를 참고했습니다.
파라미터를 설정합니다.
# 파라미터 설정
params = {'num_classes':10,
'nz':100,
'input_size':(1,28,28)}
Generator
# Generator: 가짜 이미지를 생성합니다.
# noise와 label을 결합하여 학습합니다..
class Generator(nn.Module):
def __init__(self, params):
super().__init__()
self.num_classes = params['num_classes'] # 클래스 수, 10
self.nz = params['nz'] # 노이즈 수, 100
self.input_size = params['input_size'] # (1,28,28)
# noise와 label을 결합할 용도인 label embedding matrix를 생성합니다.
self.label_emb = nn.Embedding(self.num_classes, self.num_classes)
self.gen = nn.Sequential(
nn.Linear(self.nz + self.num_classes, 128),
nn.LeakyReLU(0.2),
nn.Linear(128,256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256,512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512,1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Linear(1024,int(np.prod(self.input_size))),
nn.Tanh()
)
def forward(self, noise, labels):
# noise와 label 결합
gen_input = torch.cat((self.label_emb(labels),noise),-1)
x = self.gen(gen_input)
x = x.view(x.size(0), *self.input_size)
return x
# check
x = torch.randn(16,100,device=device) # 노이즈
label = torch.randint(0,10,(16,),device=device) # 레이블
model_gen = Generator(params).to(device)
out_gen = model_gen(x,label) # 가짜 이미지 생성
print(out_gen.shape)
Discriminator
# Discriminator: 가짜 이미지와 진짜 이미지를 식별합니다.
class Discriminator(nn.Module):
def __init__(self,params):
super().__init__()
self.input_size = params['input_size']
self.num_classes = params['num_classes']
self.label_embedding = nn.Embedding(self.num_classes, self.num_classes)
self.dis = nn.Sequential(
nn.Linear(self.num_classes+int(np.prod(self.input_size)),512),
nn.LeakyReLU(0.2),
nn.Linear(512,512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512,512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512,1),
nn.Sigmoid()
)
def forward(self, img, labels):
# 이미지와 label 결합
dis_input = torch.cat((img.view(img.size(0),-1),self.label_embedding(labels)),-1)
x = self.dis(dis_input)
return x
# check
x = torch.randn(16,1,28,28,device=device)
label = torch.randint(0,10,(16,), device=device)
model_dis = Discriminator(params).to(device)
out_dis = model_dis(x,label)
print(out_dis.shape)
가중치 초기화
# 가중치 초기화
def initialize_weights(model):
classname = model.__class__.__name__
# fc layer
if classname.find('Linear') != -1:
nn.init.normal_(model.weight.data, 0.0, 0.02)
nn.init.constant_(model.bias.data, 0)
# batchnorm
elif classname.find('BatchNorm') != -1:
nn.init.normal_(model.weight.data, 1.0, 0.02)
nn.init.constant_(model.bias.data, 0)
# 가중치 초기화 적용
model_gen.apply(initialize_weights);
model_dis.apply(initialize_weights);
3. 학습하기
# 손실 함수
loss_func = nn.BCELoss()
from torch import optim
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
# optimization
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1,beta2))
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1,beta2))
# 학습 파라미터
nz = params['nz']
num_epochs = 100
loss_history={'gen':[],
'dis':[]}
# 학습
batch_count = 0
start_time = time.time()
model_dis.train()
model_gen.train()
for epoch in range(num_epochs):
for xb, yb in train_dl:
ba_si = xb.shape[0]
xb = xb.to(device)
yb = yb.to(device)
yb_real = torch.Tensor(ba_si,1).fill_(1.0).to(device) # real_label
yb_fake = torch.Tensor(ba_si,1).fill_(0.0).to(device) # fake_label
# Genetator
model_gen.zero_grad()
noise = torch.randn(ba_si,100).to(device) # 노이즈 생성
gen_label = torch.randint(0,10,(ba_si,)).to(device) # label 생성
# 가짜 이미지 생성
out_gen = model_gen(noise, gen_label)
# 가짜 이미지 판별
out_dis = model_dis(out_gen, gen_label)
loss_gen = loss_func(out_dis, yb_real)
loss_gen.backward()
opt_gen.step()
# Discriminator
model_dis.zero_grad()
# 진짜 이미지 판별
out_dis = model_dis(xb, yb)
loss_real = loss_func(out_dis, yb_real)
# 가짜 이미지 판별
out_dis = model_dis(out_gen.detach(),gen_label)
loss_fake = loss_func(out_dis,yb_fake)
loss_dis = (loss_real + loss_fake) / 2
loss_dis.backward()
opt_dis.step()
loss_history['gen'].append(loss_gen.item())
loss_history['dis'].append(loss_dis.item())
batch_count += 1
if batch_count % 1000 == 0:
print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, loss_gen.item(), loss_dis.item(), (time.time()-start_time)/60))
loss history 시각화
# plot loss history
plt.figure(figsize=(10,5))
plt.title('Loss Progress')
plt.plot(loss_history['gen'], label='Gen. Loss')
plt.plot(loss_history['dis'], label='Dis. Loss')
plt.xlabel('batch count')
plt.ylabel('Loss')
plt.legend()
plt.show()
가중치 저장
# 가중치 저장
path2models = './models/'
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, 'weights_gen.pt')
path2weights_dis = os.path.join(path2models, 'weights_dis.pt')
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)
4. Generator이 생성한 가짜 이미지 확인하기
# 가중치 불러오기
weights = torch.load(path2weights_gen)
model_gen.load_state_dict(weights)
# evalutaion mode
model_gen.eval()
# fake image 생성
with torch.no_grad():
fixed_noise = torch.randn(16, 100, device=device)
label = torch.randint(0,10,(16,), device=device)
img_fake = model_gen(fixed_noise, label).detach().cpu()
print(img_fake.shape)
가짜 이미지 시각화
# 가짜 이미지 시각화
plt.figure(figsize=(10,10))
for ii in range(16):
plt.subplot(4,4,ii+1)
plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5), cmap='gray')
plt.axis('off')
학습이 더 필요하네요..ㅎㅎ
'논문 구현' 카테고리의 다른 글
[논문 구현] PyTorch로 Pix2Pix(2016) 구현하고 학습하기 (5) | 2021.05.20 |
---|---|
[논문 구현] PyTorch로 DCGAN(2015) 구현하고 학습하기 (0) | 2021.05.19 |
[논문 구현] PyTorch로 GAN(2014) 구현하고 학습하기 (3) | 2021.05.17 |
[논문 구현] PyTorch로 RetinaNet(2017) 구현하고 학습하기 (2) | 2021.05.06 |
[논문 구현] PyTorch로 YOLOv3(2018) 구현하고 학습하기 (6) | 2021.04.04 |