[논문 구현] PyTorch로 Xception(2017) 구현하고 학습하기
안녕하세요! 이번에 구현해볼 모델은 Xception(2017)입니다. 작업은 Google Colab 환경에서 진행했습니다.
Xception은 Depthwise separable convolution을 활용해 연산량을 줄인 만큼 층을 깊게 쌓은 모델인데요. 자세한 논문 리뷰는 아래 게시글에서 확인하실 수 있습니다!
전체 코드는 여기에서 확인하실 수 있습니다.
1. 데이터셋 불러오기
데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 이용하겠습니다. STL10 dataset은 10개의 label을 가지며 train dataset 5000개, test dataset 8000개로 구성됩니다.
우선 Google Colab에 Google Drive를 mount 합니다.
from google.colab import drive

# Mount Google Drive at an absolute path so the mount point always matches the
# absolute data path used later ('/content/xception/MyDrive/data'), regardless
# of the notebook's current working directory.
drive.mount('/content/xception')
필요한 라이브러리를 import 합니다.
# import package
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
# dataset and transformation
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# display images
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
# utils
import numpy as np
from torchsummary import summary
import time
import copy
데이터셋을 불러옵니다.
# specify path to data
path2data = '/content/xception/MyDrive/data'

# create the directory, including any missing parents; exist_ok=True makes
# this a no-op when it already exists and avoids the check-then-create race
# of os.path.exists() followed by os.mkdir()
os.makedirs(path2data, exist_ok=True)

# load dataset (downloaded into path2data on first use)
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())

# report the number of samples in each split
print(len(train_ds))
print(len(val_ds))
transformation을 정의하고, 데이터 셋에 적용합니다.
# define transformation: convert to tensor first, then resize with size 299
# (the input size used for the Xception model below)
transformation = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(299)]
)

# apply transformation to dataset (replaces the transform set at load time)
train_ds.transform = val_ds.transform = transformation
dataloader를 생성합니다.
# make dataloaders
# training batches are reshuffled every epoch
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
# validation only accumulates loss/accuracy over all batches, so the order is
# irrelevant; keep it fixed for reproducible evaluation
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
샘플 이미지를 확인합니다.
# check sample images
def show(img, y=None):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: tensor exposing ``.numpy()``; axis 0 is moved to the last
            position before plotting (presumably a (C, H, W) image —
            confirm against callers).
        y: optional label(s); when given, shown in the title as
            ``'labels:' + str(y)``.
    """
    # move axis 0 to the end so imshow receives a channel-last array
    channel_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channel_last)
    if y is None:
        return
    plt.title('labels:' + str(y))
# fix the seeds so the same sample images are picked on every run
np.random.seed(11)
torch.manual_seed(0)

grid_size = 4
rnd_ind = np.random.randint(0, len(train_ds), grid_size)

# fetch each sample once and take BOTH image and label from train_ds;
# reading the labels from val_ds would pair the images with unrelated labels
samples = [train_ds[i] for i in rnd_ind]
x_grid = [image for image, _ in samples]
y_grid = [label for _, label in samples]

# tile the images into a single row and display them with their labels
x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
plt.figure(figsize=(10,10))
show(x_grid, y_grid)
2. 모델 구축하기
Xception 모델을 구축하겠습니다. Xception은 EntryFlow, MiddleFlow, ExitFlow 세 구조로 이루어져 있습니다.
Depthwise separable convolution 클래스를 정의합니다.
# Depthwise Separable Convolution
class SeparableConv(nn.Module):
    """Depthwise separable convolution: 3x3 depthwise conv + 1x1 pointwise conv.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels produced by the pointwise conv.

    Both convolutions are bias-free. The 3x3 conv uses stride 1 and padding 1,
    so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # attribute name (including its spelling) is kept so that module /
        # state_dict key names stay the same
        self.seperable = nn.Sequential(
            # groups=in_channels gives one 3x3 filter per input channel;
            # without it this would be an ordinary dense 3x3 convolution
            nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1,
                      groups=in_channels, bias=False),
            # pointwise 1x1 conv mixes channels and sets the output width
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        )

    def forward(self, x):
        """Apply the depthwise conv followed by the pointwise conv to ``x``."""
        return self.seperable(x)
EntryFlow를 정의합니다.
# EntryFlow
class EntryFlow(nn.Module):
    """Entry flow: a two-convolution stem followed by three residual stages.

    Channel widths, as set by the layers below: 3 -> 32 -> 64 in the stem,
    then 64 -> 128 -> 256 -> 728 across the residual stages. Every stage adds
    a SeparableConv branch ending in a stride-2 max pool to a stride-2 1x1
    convolution shortcut.
    """

    def __init__(self):
        super().__init__()

        # stem: stride-2 3x3 conv, then an unpadded 3x3 conv, each with BN + ReLU
        stem_layers = [
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        # the first stage has no leading ReLU because the stem already ends in one
        self.conv2_residual = self._residual_branch(64, 128, leading_relu=False)
        self.conv2_shortcut = self._projection_shortcut(64, 128)

        self.conv3_residual = self._residual_branch(128, 256)
        self.conv3_shortcut = self._projection_shortcut(128, 256)

        self.conv4_residual = self._residual_branch(256, 728)
        self.conv4_shortcut = self._projection_shortcut(256, 728)

    @staticmethod
    def _residual_branch(in_channels, out_channels, leading_relu=True):
        """Build [ReLU,] SepConv-BN-ReLU-SepConv-BN followed by a stride-2 max pool."""
        layers = [nn.ReLU()] if leading_relu else []
        layers += [
            SeparableConv(in_channels, out_channels),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            SeparableConv(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _projection_shortcut(in_channels, out_channels):
        """Build the stride-2 1x1 conv + BN used as the skip connection."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=2, padding=0),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Run the stem, then sum residual and shortcut branches stage by stage."""
        x = self.conv1(x)
        stages = (
            (self.conv2_residual, self.conv2_shortcut),
            (self.conv3_residual, self.conv3_shortcut),
            (self.conv4_residual, self.conv4_shortcut),
        )
        for residual, shortcut in stages:
            x = residual(x) + shortcut(x)
        return x
MiddleFlow를 정의합니다.
# MiddleFlow
class MiddleFlow(nn.Module):
    """One middle-flow block: three ReLU-SepConv-BN units plus an identity skip.

    All three SeparableConv layers map 728 channels to 728 channels.
    """

    def __init__(self):
        super().__init__()

        # three identical (ReLU, SeparableConv, BatchNorm) units in sequence
        units = []
        for _ in range(3):
            units.extend([
                nn.ReLU(),
                SeparableConv(728, 728),
                nn.BatchNorm2d(728),
            ])
        self.conv_residual = nn.Sequential(*units)

        # an empty Sequential returns its input unchanged (identity shortcut)
        self.conv_shortcut = nn.Sequential()

    def forward(self, x):
        """Return the identity shortcut plus the residual branch output."""
        skip = self.conv_shortcut(x)
        return skip + self.conv_residual(x)
ExitFlow를 정의합니다.
# ExitFlow
class ExitFlow(nn.Module):
    """Exit flow: one residual stage (728 -> 1024), two SepConv units, global pooling.

    Args:
        num_classes: accepted for interface compatibility; this module does
            not use it (no classifier layer is built here).

    The output of ``forward`` comes from an adaptive average pool with output
    size (1, 1) applied to the 2048-channel feature map.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # residual branch ending in a stride-2 max pool
        residual_layers = [
            nn.ReLU(),
            SeparableConv(728, 1024),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            SeparableConv(1024, 1024),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        self.conv1_residual = nn.Sequential(*residual_layers)

        # stride-2 1x1 projection so the shortcut matches the residual branch
        shortcut_layers = [
            nn.Conv2d(728, 1024, 1, stride=2, padding=0),
            nn.BatchNorm2d(1024),
        ]
        self.conv1_shortcut = nn.Sequential(*shortcut_layers)

        # widen 1024 -> 1536 -> 2048, each unit followed by BN + ReLU
        head_layers = [
            SeparableConv(1024, 1536),
            nn.BatchNorm2d(1536),
            nn.ReLU(),
            SeparableConv(1536, 2048),
            nn.BatchNorm2d(2048),
            nn.ReLU(),
        ]
        self.conv2 = nn.Sequential(*head_layers)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        """Sum residual and shortcut, apply the final SepConv units, then pool."""
        merged = self.conv1_residual(x) + self.conv1_shortcut(x)
        features = self.conv2(merged)
        return self.avg_pool(features)
위에서 정의한 클래스를 활용해서 Xception을 정의합니다. _make_middle_flow 함수를 정의해서 middleflow를 8번 쌓았습니다.
# Xception
class Xception(nn.Module):
    """Xception classifier: EntryFlow -> 8 x MiddleFlow -> ExitFlow -> Linear.

    Args:
        num_classes: number of output logits of the final linear layer.
        init_weights: when True, ``_initialize_weights`` is applied to every
            Conv2d, BatchNorm2d and Linear module after construction.
    """

    def __init__(self, num_classes=10, init_weights=True):
        super().__init__()
        self.init_weights = init_weights

        self.entry = EntryFlow()
        self.middle = self._make_middle_flow()
        self.exit = ExitFlow()

        # the exit flow yields 2048 pooled features per sample
        self.linear = nn.Linear(2048, num_classes)

        # weights initialization (this branch was previously a no-op)
        if self.init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        x = self.entry(x)
        x = self.middle(x)
        x = self.exit(x)
        # flatten everything after the batch dimension for the linear layer
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def _make_middle_flow(self):
        """Stack 8 MiddleFlow blocks, registered as ``middle_block_0`` .. ``middle_block_7``."""
        middle = nn.Sequential()
        for i in range(8):
            middle.add_module('middle_block_{}'.format(i), MiddleFlow())
        return middle

    def _initialize_weights(self):
        """Initialize conv, batch-norm and linear parameters in place.

        Conv2d weights use Kaiming-normal (fan_out, relu) with zero bias,
        BatchNorm2d gets weight 1 and bias 0, and Linear gets N(0, 0.01)
        weights with zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
모델이 잘 구축됐는지 확인합니다.
# check model: push one random batch through the network and print the output size
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# dummy batch of 3 inputs, created before the model (same order as before)
x = torch.randn(3, 3, 299, 299).to(device)
model = Xception().to(device)

output = model(x)
print('output size:', output.size())
summary를 출력합니다.
# print summary of the model for a single (3, 299, 299) input
summary(model, input_size=(3, 299, 299), device=device.type)
모델이 무겁네요ㅎㅎ
3. 학습하기
학습에 필요한 함수를 정의하고, 학습을 해보겠습니다.
# define loss function, optimizer, lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# summed (not averaged) loss per batch
loss_func = nn.CrossEntropyLoss(reduction='sum')

opt = optim.Adam(model.parameters(), lr=0.01)

# multiply the learning rate by 0.1 after 8 epochs without improvement of the
# monitored value (mode='min')
lr_scheduler = ReduceLROnPlateau(opt, mode='min', patience=8, factor=0.1)
# get current lr
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when ``opt.param_groups`` is empty.
    """
    first_group_lrs = (group['lr'] for group in opt.param_groups)
    return next(first_group_lrs, None)
# calculate the metric per mini-batch
def metric_batch(output, target):
    """Count how many predictions in the batch match ``target``.

    Args:
        output: score tensor; the prediction is the argmax along dim 1.
        target: tensor of labels with as many elements as predictions.

    Returns:
        Number of correct predictions as a Python int.
    """
    predicted = output.argmax(dim=1)
    hits = predicted == target.view_as(predicted)
    return hits.sum().item()
# calculate the loss per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    """Compute loss and correct-count for one batch; optimise when ``opt`` is given.

    Args:
        loss_func: callable taking ``(output, target)`` and returning the loss tensor.
        output: model output for the batch.
        target: ground-truth labels for the batch.
        opt: optional optimizer; when provided, gradients are zeroed, the
            loss is back-propagated and one optimizer step is taken.

    Returns:
        ``(loss_value, num_correct)`` where ``loss_value`` is ``loss.item()``.
    """
    batch_loss = loss_func(output, target)
    num_correct = metric_batch(output, target)

    # evaluation mode: nothing to update
    if opt is None:
        return batch_loss.item(), num_correct

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), num_correct
# calculate the loss per epoch
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return the mean loss and accuracy.

    Args:
        model: network applied to each input batch.
        loss_func: loss passed through to ``loss_batch``.
        dataset_dl: iterable yielding ``(inputs, targets)`` batches.
        sanity_check: when truthy, stop after the first batch.
        opt: optional optimizer forwarded to ``loss_batch``; training steps
            are taken only when it is given.

    Returns:
        ``(loss, metric)``: accumulated loss and correct-count, each divided
        by the number of samples actually processed. ``(0.0, 0.0)`` is
        returned when the loader yields no samples.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    running_loss = 0.0
    running_metric = 0.0
    # count processed samples so the averages stay correct when the pass
    # stops early (sanity_check) or the loader does not cover the whole dataset
    num_seen = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)

        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        num_seen += yb.size(0)

        if sanity_check:
            break

    # avoid ZeroDivisionError on an empty loader
    if num_seen == 0:
        return 0.0, 0.0

    loss = running_loss / num_seen
    metric = running_metric / num_seen
    return loss, metric
# function to start training
def train_val(model, params):
    """Train and validate ``model`` for a number of epochs, keeping the best weights.

    Args:
        model: network to train; it is modified in place.
        params: dict with the keys ``num_epochs``, ``loss_func``,
            ``optimizer``, ``train_dl``, ``val_dl``, ``sanity_check``,
            ``lr_scheduler`` and ``path2weights``.

    Returns:
        ``(model, loss_history, metric_history)``; the model holds the weights
        with the lowest validation loss, and each history is a dict with
        per-epoch lists under ``'train'`` and ``'val'``.
    """
    (num_epochs, loss_func, opt, train_dl, val_dl,
     sanity_check, lr_scheduler, path2weights) = (
        params[key] for key in (
            'num_epochs', 'loss_func', 'optimizer', 'train_dl', 'val_dl',
            'sanity_check', 'lr_scheduler', 'path2weights',
        )
    )

    loss_history = {'train': [], 'val': []}
    metric_history = {'train': [], 'val': []}

    lowest_val_loss = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    started_at = time.time()

    for epoch in range(num_epochs):
        lr_before = get_lr(opt)
        print('Epoch {}/{}, current lr= {}'.format(epoch, num_epochs-1, lr_before))

        # training pass (optimizer supplied, so weights are updated)
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        # validation pass without gradient tracking
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        # remember and persist the weights with the lowest validation loss
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print('Copied best model weights!')

        # when the scheduler changed the learning rate, restart from the best weights
        lr_scheduler.step(val_loss)
        if lr_before != get_lr(opt):
            print('Loading best model weights!')
            model.load_state_dict(best_state)

        elapsed_min = (time.time()-started_at)/60
        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, elapsed_min))
        print('-'*10)

    model.load_state_dict(best_state)
    return model, loss_history, metric_history
학습에 필요한 파라미터를 설정합니다.
# define the training parameters (the keys are the ones read by train_val)
params_train = dict(
    num_epochs=30,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# check the directory to save weights.pt
def createFolder(directory):
    """Create ``directory`` (and missing parents) if it does not exist yet.

    This is best-effort: an OS-level failure is reported by printing
    'Error' instead of raising.
    """
    try:
        # exist_ok=True makes an existing directory a no-op and removes the
        # check-then-create race of os.path.exists() + os.makedirs()
        os.makedirs(directory, exist_ok=True)
    except OSError:
        # the previous spelling 'OSerror' raised NameError on any real failure
        print('Error')


createFolder('./models')
학습을 시작합니다. 30 epochs 동안 학습하겠습니다.
# run training; the returned model holds the best weights, and the histories
# hold per-epoch loss and accuracy
model, loss_hist, metric_hist = train_val(model=model, params=params_train)
모델 파라미터 수가 많아서 학습을 더 시켜야겠네요. 학습 해보는 것에 의미를 두었기 때문에 수렴할 때 까지 학습하진 않았습니다.
loss, accuracy progress를 출력합니다.
# Train-Validation progress
num_epochs = params_train['num_epochs']
epoch_axis = range(1, num_epochs+1)

# one figure per history: loss first, then accuracy
for fig_title, y_label, history in (
    ('Train-Val Loss', 'Loss', loss_hist),
    ('Train-Val Accuracy', 'Accuracy', metric_hist),
):
    plt.title(fig_title)
    # draw the training curve first, then the validation curve
    for split in ('train', 'val'):
        plt.plot(epoch_axis, history[split], label=split)
    plt.ylabel(y_label)
    plt.xlabel('Training Epochs')
    plt.legend()
    plt.show()