[논문 구현] PyTorch로 ResNext(2017) 구현하고 학습하기
안녕하세요! 이번 포스팅에서는 ResNext(2017)을 구현하고 학습까지 해보도록 하겠습니다. 작업 환경은 Google Colab에서 진행했습니다.
ResNext는 Group Convolution을 residual block에 적용하고, 그룹 수를 의미하는 cardinality 개념을 제안한 논문입니다. Group Convolution을 활용하면 Conv 연산에 필요한 연산량이 감소하므로, 동일한 연산량으로 더 많은 피쳐맵을 활용할 수 있는 이점이 있습니다. ResNext에서 사용한 Group Convolution이 현재 최신 모델에서도 이용되고 있는 것을 보면 영향력이 대단한 것 같네요ㅎㅎ 아래 포스팅에서 논문 리뷰를 확인하실 수 있습니다.
전체 코드는 여기에서 확인하실 수 있습니다.
우선 Google Colab에 mount를 합니다.
from google.colab import drive

# Mount Google Drive at an absolute path. The dataset path used later is
# '/content/resnext/MyDrive/data', so a relative mount point would only match
# it when the current working directory happens to be /content.
drive.mount('/content/resnext')
# import package
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
# dataset and transformation
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# display images
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
# utils
import numpy as np
from torchsummary import summary
import time
import copy
1. 데이터셋 불러오기
데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 이용하겠습니다. STL10 dataset은 10개의 label을 가지며 train dataset 5000개, test dataset 8000개로 구성됩니다.
데이터셋을 불러옵니다.
# Directory (on the mounted Drive) where the STL10 files are stored.
path2data = '/content/resnext/MyDrive/data'

# Create the directory together with any missing parents. os.mkdir would fail
# if an intermediate directory did not exist, and exist_ok=True removes the
# need for a separate existence check.
os.makedirs(path2data, exist_ok=True)

# Load both splits (downloading them if necessary); the 'test' split is used
# as the validation set.
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())

print(len(train_ds))
print(len(val_ds))
transformation을 정의하고, dataset에 적용합니다.
# Convert each image to a tensor, then resize it to 224 (the input size the
# model is checked with further below).
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
])

# Replace the transform that was set when the datasets were created.
train_ds.transform = transformation
val_ds.transform = transformation

# Build the dataloaders. Only the training data is shuffled; the validation
# data is read in a fixed order so evaluation (and the single-batch sanity
# check) is reproducible.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
샘플 이미지를 확인합니다.
# helper to display a sample image (or an image grid)
def show(img, y=None):
    """Plot an image tensor with matplotlib, optionally titled with its labels.

    Args:
        img: tensor exposing ``.numpy()``; its axes are reordered with
            ``(1, 2, 0)`` before plotting (presumably channel-first input).
        y: optional label(s); when given they are shown in the title as
            ``'labels:' + str(y)``.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    if y is None:
        return
    plt.title('labels:' + str(y))
# Fix the seeds so the same sample images are picked on every run.
np.random.seed(10)
torch.manual_seed(0)

grid_size = 4
rnd_ind = np.random.randint(0, len(train_ds), grid_size)

# Images and labels must come from the same dataset: reading the labels from
# val_ds (as before) paired training images with unrelated validation labels.
x_grid = [train_ds[i][0] for i in rnd_ind]
y_grid = [train_ds[i][1] for i in rnd_ind]

# Tile the sampled images into a single row and display them with their labels.
x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
plt.figure(figsize=(10, 10))
show(x_grid, y_grid)
2. 모델 구축하기
코드는 https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnext.py 를 참고했습니다.
ResNet에서 사용하는 BottleNeck과의 차이점은 3x3 conv 함수에 groups 인자를 활용한 것입니다. 파이토치에서는 group convolution 기능을 제공하여, groups 인자를 설정해주면 편리하게 group convolution을 사용할 수 있습니다.
class BottleNeck(nn.Module):
    """ResNeXt bottleneck block.

    The residual branch is 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each
    followed by batch normalization (with ReLU after the first two). Its output
    is added to the shortcut branch and passed through a final ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: base width of the block; the block outputs
            ``out_channels * expansion`` channels.
        stride: stride of the grouped 3x3 convolution (and of the shortcut
            projection when one is needed).
    """

    expansion = 4     # the block outputs out_channels * expansion channels
    Cardinality = 32  # number of groups in the grouped 3x3 convolution
    Basewidth = 64    # reference out_channels at which every group has `Depth` channels
    Depth = 4         # channels per group when out_channels == Basewidth

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        C = BottleNeck.Cardinality
        # Channels per group, scaled linearly with the width of the stage.
        D = int(BottleNeck.Depth * out_channels / BottleNeck.Basewidth)
        out_expanded = out_channels * BottleNeck.expansion

        self.conv_residual = nn.Sequential(
            nn.Conv2d(in_channels, C * D, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C * D),
            nn.ReLU(),
            # Grouped convolution: C groups of D channels each.
            nn.Conv2d(C * D, C * D, 3, stride=stride, padding=1, groups=C, bias=False),
            nn.BatchNorm2d(C * D),
            nn.ReLU(),
            nn.Conv2d(C * D, out_expanded, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_expanded),
        )

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection matches the shapes. It is kept as a
        # bare Conv2d so parameter names in the state_dict stay the same.
        self.conv_shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_expanded:
            self.conv_shortcut = nn.Conv2d(in_channels, out_expanded, 1, stride=stride, padding=0)

    def forward(self, x):
        """Return ReLU(residual(x) + shortcut(x))."""
        # The non-linearity after the addition was missing; without it the
        # stacked blocks pass un-rectified sums from one block to the next.
        return F.relu(self.conv_residual(x) + self.conv_shortcut(x))
# ResNext
class ResNext(nn.Module):
    """ResNeXt image classifier assembled from ``BottleNeck`` blocks.

    Args:
        nblocks: sequence of four ints giving the number of bottleneck blocks
            in the stages ``conv2`` .. ``conv5``.
        num_classes: number of outputs of the final linear layer.
        init_weights: when True, ``_initialize_weights`` is run at construction.
    """

    def __init__(self, nblocks, num_classes=10, init_weights=True):
        super().__init__()
        self.init_weights = init_weights
        # Channel count fed to the next bottleneck; advanced by _make_res_block.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        # padding=3 is the matching padding for a 7x7 kernel, so a 224x224
        # input is halved exactly to 112x112 (padding=2 produced 111x111).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Four stages; every stage after the first halves the resolution.
        self.conv2 = self._make_res_block(nblocks[0], 64, 1)
        self.conv3 = self._make_res_block(nblocks[1], 128, 2)
        self.conv4 = self._make_res_block(nblocks[2], 256, 2)
        self.conv5 = self._make_res_block(nblocks[3], 512, 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * BottleNeck.expansion, num_classes)

        # weights initialization
        if self.init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the stem, the four stages, global average pooling and the classifier."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avg_pool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def _make_res_block(self, nblock, out_channels, stride):
        """Build one stage of ``nblock`` bottlenecks.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        ``self.in_channels`` is updated after every block so the next block
        (and the next stage) receives the right input width.
        """
        strides = [stride] + [1] * (nblock - 1)
        res_block = nn.Sequential()
        for i, block_stride in enumerate(strides):
            # The module name is kept as-is so state_dict keys do not change.
            res_block.add_module('dens_layer_{}'.format(i), BottleNeck(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * BottleNeck.expansion
        return res_block

    # weights initialization function
    def _initialize_weights(self):
        """Initialize conv, batch-norm and linear layers in place.

        Conv2d: Kaiming-normal weights (fan_out, relu) and zero bias if present.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights drawn from N(0, 0.01), zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
def ResNext50(num_classes=10, init_weights=True):
    """Build a ResNeXt-50 model (stage depths 3, 4, 6, 3).

    Args:
        num_classes: number of outputs of the classifier (default 10).
        init_weights: whether to run the model's weight initialization.

    Returns:
        A ``ResNext`` instance.
    """
    return ResNext([3, 4, 6, 3], num_classes=num_classes, init_weights=init_weights)
모델을 확인합니다.
# Smoke-test the model: push a random batch through it and print the output size.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Random batch of 3 inputs, each with 3 channels and 224x224 spatial size.
x = torch.randn((3, 3, 224, 224)).to(device)
model = ResNext50().to(device)
output = model(x)
print('output size: ', output.size())
모델 summary를 출력합니다.
# Print the torchsummary report for a single 3x224x224 input. device.type is
# the device name as a string, which is what the `device` argument receives.
summary(model, (3, 224, 224), device=device.type)
3. 학습하기
학습에 필요한 함수를 정의하고, 학습을 해보도록 하겠습니다.
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Loss: cross-entropy summed over the batch (reduction='sum'); the epoch loop
# divides the accumulated sum by the dataset size afterwards.
loss_func = nn.CrossEntropyLoss(reduction='sum')

# Optimizer: Adam over all model parameters.
opt = optim.Adam(params=model.parameters(), lr=0.01)

# Scheduler: multiply the learning rate by 0.1 once the monitored value has
# not decreased for more than 10 consecutive epochs.
lr_scheduler = ReduceLROnPlateau(optimizer=opt, mode='min', factor=0.1, patience=10)
# read the current learning rate from an optimizer
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        opt: object exposing ``param_groups``, an iterable of dicts that each
            hold an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first group, or None when there are no groups.
    """
    first_group = next(iter(opt.param_groups), None)
    return None if first_group is None else first_group['lr']
# number of correct predictions in one mini-batch
def metric_batch(output, target):
    """Count the samples whose arg-max prediction equals the target label.

    Args:
        output: tensor of scores; the prediction is the arg-max along dim 1.
        target: tensor of labels, viewed to the shape of the predictions.

    Returns:
        The number of matching predictions as a Python number.
    """
    predicted = output.argmax(dim=1, keepdim=True)
    hits = predicted == target.view_as(predicted)
    return hits.sum().item()
# loss and metric for one mini-batch, with an optional optimization step
def loss_batch(loss_func, output, target, opt=None):
    """Compute the loss and the number of correct predictions for a mini-batch.

    Args:
        loss_func: callable ``loss_func(output, target)`` returning a tensor.
        output: model output for the batch.
        target: ground-truth labels for the batch.
        opt: optional optimizer; when given, gradients are zeroed, the loss is
            back-propagated and one optimizer step is taken.

    Returns:
        ``(loss_value, corrects)`` where ``loss_value`` is ``loss.item()`` and
        ``corrects`` comes from ``metric_batch``.
    """
    batch_loss = loss_func(output, target)
    n_correct = metric_batch(output, target)

    # Evaluation mode: nothing to update.
    if opt is None:
        return batch_loss.item(), n_correct

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), n_correct
# average loss and metric over one pass of a dataloader
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run the model over ``dataset_dl`` once and return the averaged loss and metric.

    Args:
        model: network applied to each batch.
        loss_func: loss callable forwarded to ``loss_batch``.
        dataset_dl: dataloader yielding ``(inputs, labels)``; its ``dataset``
            length is the normalization constant.
        sanity_check: when exactly ``True``, stop after the first batch.
        opt: optional optimizer forwarded to ``loss_batch`` (training when set).

    Returns:
        ``(loss, metric)``: the accumulated batch losses and correct counts,
        each divided by ``len(dataset_dl.dataset)``.
    """
    total_loss, total_correct = 0.0, 0.0
    n_samples = len(dataset_dl.dataset)

    for inputs, labels in dataset_dl:
        # Move the batch to the module-level `device`.
        inputs, labels = inputs.to(device), labels.to(device)
        batch_loss, batch_correct = loss_batch(loss_func, model(inputs), labels, opt)

        total_loss += batch_loss
        if batch_correct is not None:
            total_correct += batch_correct

        if sanity_check is True:
            break

    return total_loss / n_samples, total_correct / n_samples
# training / validation driver
def train_val(model, params):
    """Train ``model`` for several epochs, validating after each one.

    Args:
        model: network to train; it is modified in place.
        params: dict with the keys ``'num_epochs'``, ``'loss_func'``,
            ``'optimizer'``, ``'train_dl'``, ``'val_dl'``, ``'sanity_check'``,
            ``'lr_scheduler'`` and ``'path2weights'``.

    Returns:
        ``(model, loss_history, metric_history)``: the model loaded with the
        weights that gave the lowest validation loss, and two dicts with
        per-epoch ``'train'`` / ``'val'`` values.
    """
    num_epochs = params['num_epochs']
    loss_func = params['loss_func']
    opt = params['optimizer']
    train_dl = params['train_dl']
    val_dl = params['val_dl']
    sanity_check = params['sanity_check']
    lr_scheduler = params['lr_scheduler']
    path2weights = params['path2weights']

    loss_history = {phase: [] for phase in ('train', 'val')}
    metric_history = {phase: [] for phase in ('train', 'val')}

    # Best validation loss seen so far and a snapshot of the matching weights.
    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    start_time = time.time()

    for epoch in range(num_epochs):
        lr_before = get_lr(opt)
        print('Epoch {}/{}, current lr= {}'.format(epoch, num_epochs-1, lr_before))

        # Training pass (the optimizer is passed, so weights are updated).
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)

        # Validation pass (no optimizer, no gradient tracking).
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)

        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        # Snapshot and persist the weights whenever validation loss improves.
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print('Copied best model weights!')

        # If the scheduler changed the learning rate, restart from the best weights.
        lr_scheduler.step(val_loss)
        lr_changed = lr_before != get_lr(opt)
        if lr_changed:
            print('Loading best model weights!')
            model.load_state_dict(best_model_wts)

        elapsed_min = (time.time() - start_time) / 60
        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' % (train_loss, val_loss, 100*val_metric, elapsed_min))
        print('-'*10)

    # Hand back the model with the best weights found during training.
    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
# Settings consumed by train_val: epoch count, optimizer, loss, dataloaders,
# sanity-check flag, LR scheduler and the file the best weights are saved to.
params_train = dict(
    num_epochs=30,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# make sure the directory that will hold weights.pt exists
def createFolder(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    Failures are reported by printing ``'Error'`` instead of raising, so the
    helper stays best-effort.

    Args:
        directory: path of the directory to create.
    """
    try:
        # exist_ok=True makes the call a no-op for an existing directory.
        os.makedirs(directory, exist_ok=True)
    except OSError:
        # The handler previously named the non-existent `OSerror`, which
        # raised NameError as soon as any real error occurred.
        print('Error')


createFolder('./models')
학습 시작! 30 epoch 학습하겠습니다.
# Run training; returns the model with the best weights plus the loss and
# metric histories.
model, loss_hist, metric_hist = train_val(model=model, params=params_train)
loss, accuracy progress를 출력합니다.
# Plot the train/validation curves: first the loss, then the accuracy.
num_epochs = params_train['num_epochs']
epochs = range(1, num_epochs+1)

for title, history, ylabel in (
    ('Train-Val Loss', loss_hist, 'Loss'),
    ('Train-Val Accuracy', metric_hist, 'Accuracy'),
):
    plt.title(title)
    for split in ('train', 'val'):
        plt.plot(epochs, history[split], label=split)
    plt.ylabel(ylabel)
    plt.xlabel('Training Epochs')
    plt.legend()
    plt.show()
감사합니다.