[논문 구현] PyTorch로 ResNet(2015) 구현하고 학습하기
이번 포스팅에서는 PyTorch로 ResNet을 구현하고 학습까지 해보겠습니다.
논문 리뷰는 여기에서 확인하실 수 있습니다.
전체 코드는 여기에서 확인하실 수 있습니다.
github.com/Seonghoon-Yu/paper-implement-in-pytorch
작업 환경은 구글 코랩에서 진행했습니다.
1. 데이터셋 불러오기
데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 이용하겠습니다. STL10 dataset은 10개의 label을 가지며 train dataset 5000개, test dataset 8000개로 구성됩니다.
우선 필요한 라이브러리를 import 합니다.
# import package
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
from torch.optim.lr_scheduler import StepLR
# dataset and transformation
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# display images
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
# utils
import numpy as np
from torchsummary import summary
import time
import copy
데이터셋 다운받을 경로를 지정하고, 데이터셋을 불러옵니다.
# specify the data path
path2data = '/content/resnet/MyDrive/data'
# Create the directory together with any missing parent directories.
# os.mkdir would fail if a parent (e.g. MyDrive) does not exist yet, and
# exist_ok=True makes re-running this cell harmless.
os.makedirs(path2data, exist_ok=True)
# Load the dataset (downloaded on first use). Only ToTensor() is applied here,
# because the statistics below need raw tensors; the augmentation /
# normalization pipelines are assigned to .transform later.
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())
print(len(train_ds))
print(len(val_ds))
이미지에 Normalization을 적용하기 위해, 이미지 픽셀값의 평균, 표준편차를 계산합니다.
# To normalize the dataset, calculate the per-channel mean and std.
def _channel_stats(dataset):
    """Return per-channel ``(mean, std)`` NumPy arrays over all pixels of *dataset*.

    Each item is expected to be an ``(image, label)`` pair whose image is a
    channels-first tensor (as produced by the ``ToTensor()`` transform the
    datasets were loaded with above).

    The sums are accumulated in float64 in a single pass over the data. The
    std is therefore the pooled std over every pixel. Averaging per-image stds
    would ignore the variation *between* images and underestimate the spread.

    Raises:
        ValueError: if the dataset contains no pixels.
    """
    total = 0.0
    total_sq = 0.0
    n_pixels = 0
    for img, _ in dataset:
        arr = img.numpy().astype(np.float64)
        total = total + arr.sum(axis=(1, 2))
        total_sq = total_sq + np.square(arr).sum(axis=(1, 2))
        n_pixels += arr.shape[1] * arr.shape[2]
    if n_pixels == 0:
        raise ValueError('cannot compute channel statistics of an empty dataset')
    mean = total / n_pixels
    # Clamp tiny negative values caused by floating-point cancellation.
    var = np.maximum(total_sq / n_pixels - np.square(mean), 0.0)
    return mean, np.sqrt(var)


train_mean, train_std = _channel_stats(train_ds)
train_meanR, train_meanG, train_meanB = (float(m) for m in train_mean)
train_stdR, train_stdG, train_stdB = (float(s) for s in train_std)
val_mean, val_std = _channel_stats(val_ds)
val_meanR, val_meanG, val_meanB = (float(m) for m in val_mean)
val_stdR, val_stdG, val_stdB = (float(s) for s in val_std)
print(train_meanR, train_meanG, train_meanB)
print(val_meanR, val_meanG, val_meanB)
dataset에 적용할 transformation을 정의합니다.
# Define the image transformations.
# Both splits get the same preprocessing: tensor conversion, resize to 224,
# and normalization with the *training-set* channel statistics. Only the
# training pipeline adds random horizontal flips as augmentation.
_normalize = transforms.Normalize(
    mean=[train_meanR, train_meanG, train_meanB],
    std=[train_stdR, train_stdG, train_stdB],
)
train_transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    _normalize,
    transforms.RandomHorizontalFlip(),
])
val_transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    _normalize,
])
transformation을 dataset에 적용하고, dataloader를 생성합니다.
# apply transformation: replace the plain ToTensor() used for the statistics
train_ds.transform = train_transformation
val_ds.transform = val_transformation
# create DataLoader
# Shuffle only the training set. Validation order does not affect the
# epoch-level metrics, and a fixed order keeps evaluation reproducible.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
transformation이 적용된 샘플 이미지를 확인하겠습니다.
# display sample images
def show(img, y=None, color=True):
    """Plot an image tensor with matplotlib.

    Args:
        img: 3-D tensor, assumed channels-first (C, H, W). It is permuted to
            (H, W, C) because ``plt.imshow`` expects channels-last data.
        y: optional labels; when given, they are shown in the plot title.
        color: unused; kept only so existing calls remain valid.
    """
    plt.imshow(img.numpy().transpose((1, 2, 0)))
    if y is None:
        return
    plt.title('labels :' + str(y))
# Seed NumPy (index sampling) and PyTorch (presumably the random flips in
# train_transformation) so the same sample grid is drawn on every run.
np.random.seed(1)
torch.manual_seed(1)
grid_size = 4
# pick grid_size random indices into the training set
rnd_inds = np.random.randint(0, len(train_ds), grid_size)
print('image indices:',rnd_inds)
# NOTE(review): train_ds[i] is indexed twice per sample (once for the image,
# once for the label); each access presumably re-runs train_transformation,
# including its random flip.
x_grid = [train_ds[i][0] for i in rnd_inds]
y_grid = [train_ds[i][1] for i in rnd_inds]
# tile the images into one tensor, grid_size images per row, 2px padding
x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
show(x_grid, y_grid)
잘 적용되었네요!
2. 모델 구축하기
코드는 여기를 참고했습니다. github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnet.py
ResNet은 residual block이 겹겹이 쌓여 구성된 모델입니다.
ResNet-18, 34는 (그림) 왼쪽의 basic residual block을 사용하고, ResNet-50부터는 오른쪽의 BottleNeck block을 사용합니다.
각각의 residual block을 정의합니다.
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34.

    Residual path: 3x3 conv (carries the stride) -> BN -> ReLU -> 3x3 conv -> BN.
    The block returns ``relu(residual(x) + shortcut(x))``. The shortcut is the
    identity when the stride is 1 and the channel count is unchanged.
    Otherwise it is a strided 1x1 convolution + BatchNorm projection.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width_out = out_channels * BasicBlock.expansion
        needs_projection = stride != 1 or in_channels != width_out

        # Convolutions are bias-free because each is followed by a BatchNorm,
        # which has its own bias.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, width_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width_out),
        )

        if needs_projection:
            # Projection: match spatial size and channel count with a 1x1 conv.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
        else:
            # Identity mapping: an empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.residual_function(x)
        out = out + self.shortcut(x)
        return self.relu(out)
class BottleNeck(nn.Module):
    """Bottleneck residual block used by ResNet-50/101/152.

    Residual path: 1x1 conv (reduce) -> 3x3 conv (carries the stride) -> 1x1
    conv (expand to ``out_channels * expansion``), each followed by BN. A ReLU
    sits between them. The shortcut is the identity when the stride is 1 and the
    channel counts already match, otherwise a strided 1x1 conv + BN projection.
    """

    # Output channels = out_channels * expansion.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width_out = out_channels * BottleNeck.expansion

        self.residual_function = nn.Sequential(
            # 1x1: reduce to out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 3x3: spatial convolution, applies the block's stride
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 1x1: expand to out_channels * expansion
            nn.Conv2d(out_channels, width_out, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(width_out),
        )

        if stride == 1 and in_channels == width_out:
            # Identity mapping: an empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.residual_function(x)
        return self.relu(out + self.shortcut(x))
전체 구조입니다.
class ResNet(nn.Module):
    """ResNet: convolutional stem, four residual stages, global average pool, linear classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``BottleNeck``). It must
            expose an ``expansion`` attribute and accept
            ``(in_channels, out_channels, stride)``.
        num_block: number of blocks in each of the four stages (conv2_x..conv5_x).
        num_classes: number of outputs of the final linear layer.
        init_weights: if True, apply ``_initialize_weights`` after construction.
    """

    def __init__(self, block, num_block, num_classes=10, init_weights=True):
        super().__init__()
        # Channel count entering the next block; advanced by _make_layer.
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max-pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Residual stages; stages 3-5 downsample in their first block (stride 2).
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weight initialization
        if init_weights:
            self._initialize_weights()

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack residual blocks for one stage; only the first block gets ``stride``."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        for layer in (self.conv1, self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x, self.avg_pool):
            x = layer(x)
        # Flatten the pooled (N, C, 1, 1) map to (N, C) for the classifier.
        x = x.view(x.size(0), -1)
        return self.fc(x)

    # define weight initialization function
    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) convs; BN weight 1 / bias 0; linear weight N(0, 0.01), bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)
# Standard ResNet depths. num_classes / init_weights are forwarded to ResNet.
# The defaults equal ResNet's own defaults, so existing zero-argument calls
# behave exactly as before.
def resnet18(num_classes=10, init_weights=True):
    """Build ResNet-18 (BasicBlock, [2, 2, 2, 2] blocks per stage)."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, init_weights=init_weights)


def resnet34(num_classes=10, init_weights=True):
    """Build ResNet-34 (BasicBlock, [3, 4, 6, 3] blocks per stage)."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, init_weights=init_weights)


def resnet50(num_classes=10, init_weights=True):
    """Build ResNet-50 (BottleNeck, [3, 4, 6, 3] blocks per stage)."""
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes, init_weights=init_weights)


def resnet101(num_classes=10, init_weights=True):
    """Build ResNet-101 (BottleNeck, [3, 4, 23, 3] blocks per stage)."""
    return ResNet(BottleNeck, [3, 4, 23, 3], num_classes=num_classes, init_weights=init_weights)


def resnet152(num_classes=10, init_weights=True):
    """Build ResNet-152 (BottleNeck, [3, 8, 36, 3] blocks per stage)."""
    return ResNet(BottleNeck, [3, 8, 36, 3], num_classes=num_classes, init_weights=init_weights)
모델이 잘 구축됐는지 확인합니다.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet50().to(device)
# Smoke test: push a random batch of 3 images of 3x224x224 through the model.
# With the default num_classes=10 the printed size is torch.Size([3, 10]).
# NOTE(review): the model is still in train mode here, so this forward pass
# also updates the BatchNorm running statistics.
x = torch.randn(3, 3, 224, 224).to(device)
output = model(x)
print(output.size())
# Layer-by-layer summary from torchsummary for a 3x224x224 input.
summary(model, (3, 224, 224), device=device.type)
VGG-16은 800MB였는데, ResNet-50이 더 가볍네요ㅎㅎ 하지만 GoogLeNet보다는 무겁습니다.
3. 모델 학습하기
손실 함수, optimizer, lr_scheduler를 정의합니다.
# Sum (rather than average) the loss over each batch. loss_epoch divides the
# accumulated sum by the sample count to obtain the per-sample loss.
loss_func = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(model.parameters(), lr=0.001)
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Multiply the lr by 0.1 when the value passed to step() (the validation loss
# in train_val) has not decreased for 10 consecutive epochs.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10)
현재 lr을 계산하는 함수를 정의합니다.
# function to get current lr
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in opt.param_groups), None)
배치당 loss와 metric을 계산하는 함수를 정의합니다.
# function to calculate metric per mini-batch
def metric_batch(output, target):
    """Count correct top-1 predictions in a mini-batch.

    The predicted class is the argmax of ``output`` along dim 1. It is compared
    element-wise with ``target`` (reshaped to match), and the number of matches
    is returned as a Python number.
    """
    predicted = output.argmax(dim=1, keepdim=True)
    matches = predicted == target.view_as(predicted)
    return matches.sum().item()
# function to calculate loss per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    """Compute a batch's loss and correct-prediction count, optionally taking an optimizer step.

    When ``opt`` is given, gradients are zeroed, the loss is back-propagated
    and ``opt.step()`` is called.

    Returns:
        ``(loss.item(), number of correct predictions)``.
    """
    batch_loss = loss_func(output, target)
    n_correct = metric_batch(output, target)
    if opt is None:
        return batch_loss.item(), n_correct
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), n_correct
epoch당 loss와 metric을 계산하는 함수입니다.
# function to calculate loss and metric per epoch
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return ``(mean loss, accuracy)``.

    Args:
        model: network applied to each input batch (inputs and targets are moved
            to the module-level ``device`` first).
        loss_func: loss returning the *sum* over the batch (this file configures
            ``reduction='sum'``), so dividing by the sample count gives the mean.
        dataset_dl: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        sanity_check: if truthy, stop after the first batch (quick smoke test).
        opt: optimizer; when given, every batch also performs a training step.

    Both values are averaged over the samples actually processed. They therefore
    stay correct when ``sanity_check`` stops after one batch, or when a loader
    drops a trailing partial batch, instead of being divided by the full dataset
    size.

    Raises:
        ValueError: if the loader yields no samples.
    """
    running_loss = 0.0
    running_metric = 0.0
    n_samples = 0
    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        running_metric += metric_b
        n_samples += len(yb)
        if sanity_check:
            break
    if n_samples == 0:
        raise ValueError('loss_epoch: the data loader produced no samples')
    loss = running_loss / n_samples
    metric = running_metric / n_samples
    return loss, metric
이제, 위에서 정의한 함수를 활용해서 학습을 시작하는 함수를 정의합니다.
val_loss가 가장 낮을 때 모델의 가중치를 저장하는 코드를 구현했는데, GPU out of memory가 발생해서 주석 처리했습니다.
# function to start training
def train_val(model, params):
    """Train ``model`` and validate it after every epoch.

    ``params`` must provide:
        num_epochs, loss_func, optimizer, train_dl, val_dl, sanity_check,
        lr_scheduler (stepped with the validation loss each epoch), and
        path2weights (file the best weights are written to; falsy to skip saving).

    The weights with the lowest validation loss are kept and loaded back into
    ``model`` before returning. The snapshot is copied to CPU memory. Keeping a
    ``copy.deepcopy(model.state_dict())`` on the GPU is what caused the
    out-of-memory error.

    Returns:
        ``(model, loss_history, metric_history)``. Both histories are dicts with
        ``'train'`` and ``'val'`` lists holding one value per epoch.
    """
    num_epochs = params['num_epochs']
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {'train': [], 'val': []}
    metric_history = {'train': [], 'val': []}

    best_model_wts = None
    best_loss = float('inf')
    start_time = time.time()

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            # Snapshot on the CPU so the copy does not consume GPU memory.
            best_model_wts = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            if path2weights:
                torch.save(best_model_wts, path2weights)
            print('Copied best model weights!')

        lr_scheduler.step(val_loss)

        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, (time.time()-start_time)/60))
        print('-' * 10)

    # Restore the best-validation weights (load_state_dict copies them onto the
    # model's own device).
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
하이퍼 파라미터를 정의합니다.
# define the training parameters consumed by train_val
params_train = dict(
    num_epochs=20,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,  # True -> run a single batch per epoch (smoke test)
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# create the directory that stores weights.pt
def createFolder(directory):
    """Create ``directory`` (and any missing parents) if it does not exist yet.

    Creation failures (any ``OSError``, e.g. the path exists as a file or
    permission is denied) are reported and then ignored, so the notebook keeps
    running.
    """
    try:
        os.makedirs(directory, exist_ok=True)
    except OSError as err:
        print('Error: could not create directory {!r}: {}'.format(directory, err))
# Ensure ./models exists so train_val can write params_train['path2weights'] there.
createFolder('./models')
학습을 시작합니다. 저는 20 epoch만 진행하겠습니다. 코랩 환경에서 진행하다 보니 끊기는 경우가 종종 있어서 오랫동안 학습하기가 불안하네요..ㅎㅎ
# Train for params_train['num_epochs'] epochs; returns the model and the
# per-epoch loss / accuracy histories ({'train': [...], 'val': [...]}).
model, loss_hist, metric_hist = train_val(model, params_train)
학습이 잘 되네요
train loss, val loss 그래프와 정확도 그래프를 확인하겠습니다.
# Train-Validation Progress: one figure for loss, then one for accuracy,
# each with a train and a val curve over epochs 1..num_epochs.
num_epochs = params_train["num_epochs"]
epochs = range(1, num_epochs + 1)
for title, history, ylabel in (
    ("Train-Val Loss", loss_hist, "Loss"),
    ("Train-Val Accuracy", metric_hist, "Accuracy"),
):
    plt.title(title)
    for phase in ("train", "val"):
        plt.plot(epochs, history[phase], label=phase)
    plt.ylabel(ylabel)
    plt.xlabel("Training Epochs")
    plt.legend()
    plt.show()
아직 수렴하지 않았네요. 20 epoch로는 부족한 것 같습니다.