논문 구현
[논문 구현] PyTorch로 PreAct-ResNet(2016) 구현하고 학습하기
AI 꿈나무
2021. 3. 20. 23:10
반응형
이번에 구현해볼 모델은 PreAct-ResNet입니다. 논문 리뷰는 아래 포스팅에서 확인하실 수 있습니다.
전체 코드는 여기에서 확인하실 수 있습니다.
1. 데이터셋 불러오기
데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 이용하겠습니다. STL10 dataset은 10개의 label을 가지며 train dataset 5000개, test dataset 8000개로 구성됩니다.
우선 Colab에 Google Drive를 mount하겠습니다.
# Mount Google Drive at an absolute path so the mount location does not depend
# on the notebook's current working directory. The dataset path used further
# down ('/content/pre-act-resnet/MyDrive/data') expects exactly this location.
from google.colab import drive
drive.mount('/content/pre-act-resnet')
필요한 패키지를 import 합니다.
# import package
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
# dataset and transformation
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# display images
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
# utils
import numpy as np
from torchsummary import summary
import time
import copy
dataset을 불러옵니다.
# specify the data path
path2data = '/content/pre-act-resnet/MyDrive/data'

# create the directory, including any missing parents; exist_ok makes this a
# no-op when it already exists and avoids a check-then-create race
os.makedirs(path2data, exist_ok=True)

# load dataset (download=True is passed, so the files are fetched into
# path2data when needed)
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())

# report the number of samples in each split
print(len(train_ds))
print(len(val_ds))
transformation 객체를 정의하고 데이터셋에 적용합니다.
# define image transformation: ToTensor first, then Resize(224)
transformation = transforms.Compose([transforms.ToTensor(), transforms.Resize(224)])

# attach the same transformation to the training and the validation dataset
train_ds.transform = val_ds.transform = transformation
dataloader를 생성합니다.
# create dataloader
# training batches are drawn in random order
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
# validation data is only evaluated, so its order is kept fixed; this makes
# partial passes (e.g. a one-batch sanity check) reproducible
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False)
샘플 이미지를 확인하겠습니다.
# display sample images
def show(img, y=None, color=True):
    """Draw an image tensor with matplotlib, optionally titled with its labels.

    ``img`` is converted with ``.numpy()`` and its axes are reordered from
    (0, 1, 2) to (1, 2, 0) before being passed to ``plt.imshow``
    (presumably channel-first input -- confirm against the caller).
    When ``y`` is given, the title is ``'labels:' + str(y)``.
    ``color`` is accepted but not used.
    """
    plt.imshow(img.numpy().transpose(1, 2, 0))
    if y is None:
        return
    plt.title('labels:' + str(y))
# fix the seeds so the same sample indices are drawn on every run
np.random.seed(3)
torch.manual_seed(0)

grid_size = 4
rnd_ind = np.random.randint(0, len(train_ds), grid_size)

# read every picked sample from the dataset only once, then split the
# (image, label) pairs into two lists (make_grid is given a list of images)
samples = [train_ds[i] for i in rnd_ind]
x_grid = [image for image, _ in samples]
y_grid = [label for _, label in samples]

plt.figure(figsize=(10,10))
# tie the number of images per row to grid_size instead of a hard-coded 4
x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
show(x_grid, y_grid)
2. 모델 구축하기
코드는 https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/preactresnet.py 을 참고했습니다.
class BottleNeck(nn.Module):
    """Bottleneck residual unit with BatchNorm and ReLU placed before each conv.

    The residual branch is BN -> ReLU -> 1x1 conv (carries ``stride``),
    BN -> ReLU -> 3x3 conv (padding 1), BN -> ReLU -> 1x1 conv that outputs
    ``out_channels * expansion`` channels. The shortcut is an empty
    ``nn.Sequential`` unless the stride or the channel count changes, in
    which case a strided 1x1 convolution is used instead.
    """

    # channel multiplier applied by the last convolution of the residual branch
    expansion = 4

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        expanded_channels = out_channels * BottleNeck.expansion

        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 1, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, expanded_channels, 1),
        )

        # a 1x1 convolution is only needed when the two branches would
        # otherwise disagree in spatial size or channel count
        shapes_differ = stride != 1 or in_channels != expanded_channels
        if shapes_differ:
            self.shortcut = nn.Conv2d(in_channels, expanded_channels, 1, stride=stride)
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return the sum of the shortcut branch and the residual branch."""
        return self.shortcut(x) + self.residual(x)
class PreActResNet(nn.Module):
    """ResNet-style classifier assembled from ``BottleNeck`` units.

    Layout: ``conv1`` (3x3 conv, BatchNorm, ReLU, 3x3 max-pool with stride 2),
    four stages ``conv2`` .. ``conv5`` with base widths 64, 128, 256 and 512,
    adaptive average pooling to 1x1 and a linear classifier.

    Args:
        num_blocks: sequence of four ints, the number of units per stage.
        num_classes: size of the classifier output.
        init_weights: when true, ``_initialize_weights`` is run on construction.
    """

    def __init__(self, num_blocks, num_classes=10, init_weights=True):
        super().__init__()
        # running input width; advanced by _make_layers as stages are built
        self.in_channels = 64

        stem_layers = [
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, 2, 1),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        # (base width, stride of the first unit) for conv2 .. conv5
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layers(num_blocks[idx], width, first_stride)
            setattr(self, 'conv{}'.format(idx + 2), stage)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512 * BottleNeck.expansion, num_classes)

        # weights Initialization
        if init_weights:
            self._initialize_weights()

    def _make_layers(self, num_blocks, out_channels, stride):
        """Build one stage; only its first unit uses ``stride``, the rest use 1."""
        units = []
        for unit_stride in [stride] + [1] * (num_blocks - 1):
            units.append(BottleNeck(self.in_channels, out_channels, unit_stride))
            # every unit outputs the expanded width, which feeds the next unit
            self.in_channels = out_channels * BottleNeck.expansion
        return nn.Sequential(*units)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the classifier."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        x = self.avg_pool(x)
        # collapse everything except the batch dimension
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def _initialize_weights(self):
        """Re-initialise the Conv2d, BatchNorm2d and Linear parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
def PreAct_ResNet50(num_classes=10, init_weights=True):
    """Build a PreActResNet with 3, 4, 6 and 3 bottleneck units in its four stages.

    Args:
        num_classes: size of the classifier output; forwarded to PreActResNet.
        init_weights: forwarded to PreActResNet.

    Both defaults match PreActResNet's own, so ``PreAct_ResNet50()`` builds
    the same network as before.
    """
    return PreActResNet([3,4,6,3], num_classes=num_classes, init_weights=init_weights)
PreAct_ResNet50을 불러오고, summary를 출력합니다.
# use the GPU when CUDA is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# build the network, move it to the chosen device and print its summary
model = PreAct_ResNet50()
model = model.to(device)
summary(model, (3,224,224), device=device.type)
3. 학습하기
학습에 필요한 함수를 정의합니다.
# define loss function, optimizer, lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# cross-entropy summed (not averaged) over the samples of a batch
loss_func = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(params=model.parameters(), lr=1e-3)
# scheduler in 'min' mode with factor 0.1 and patience 10
lr_scheduler = ReduceLROnPlateau(optimizer=opt, mode='min', factor=0.1, patience=10)
# get current lr
def get_lr(opt):
    """Return the 'lr' entry of the optimizer's first parameter group.

    Returns None when ``opt.param_groups`` is empty.
    """
    first_group = next(iter(opt.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
# calculate the metric per mini-batch
def metric_batch(output, target):
    """Count the predictions in a mini-batch that match the targets.

    The prediction for each sample is the index of the largest value along
    dimension 1 of ``output``. The count is returned as a Python number.
    """
    predicted = output.argmax(dim=1)
    matches = predicted == target.view_as(predicted)
    return matches.sum().item()
# calculate the loss per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    """Compute loss and correct-prediction count for one mini-batch.

    When an optimizer is passed, its gradients are zeroed, the loss is
    back-propagated and one optimizer step is taken; without one, nothing
    is updated.

    Returns:
        (loss as a Python float, number of correct predictions).
    """
    batch_loss = loss_func(output, target)
    n_correct = metric_batch(output, target)

    is_training = opt is not None
    if is_training:
        opt.zero_grad()
        batch_loss.backward()
        opt.step()

    return batch_loss.item(), n_correct
# calculate the loss per epoch
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return per-sample loss and metric.

    Args:
        model: network called on every batch.
        loss_func: loss passed on to ``loss_batch``. The per-sample average
            below assumes it returns a sum over the batch -- confirm if a
            different reduction is used.
        dataset_dl: iterable yielding ``(inputs, targets)`` batches.
        sanity_check: when exactly ``True``, stop after the first batch.
        opt: optimizer handed to ``loss_batch``; ``None`` means no updates.

    Returns:
        (loss, metric): accumulated loss and accumulated correct count, each
        divided by the number of samples actually processed.

    Raises:
        ZeroDivisionError: if ``dataset_dl`` yields no samples.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    running_loss = 0.0
    running_metric = 0.0
    # Count the processed samples instead of using len(dataset_dl.dataset):
    # with sanity_check only one batch is seen, and dividing that batch's
    # totals by the full dataset size would understate loss and accuracy.
    n_seen = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)

        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)

        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        n_seen += yb.size(0)

        if sanity_check is True:
            break

    loss = running_loss / n_seen
    metric = running_metric / n_seen
    return loss, metric
# function to start training
def train_val(model, params):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: network to train; switched between ``train()`` and ``eval()``
            mode inside the loop.
        params: dict with the keys 'num_epochs', 'loss_func', 'optimizer',
            'train_dl', 'val_dl', 'sanity_check', 'lr_scheduler' and
            'path2weights'.

    Returns:
        (model, loss_history, metric_history). The two histories are dicts
        with 'train' and 'val' lists holding one value per epoch. The
        returned model carries the weights of the last epoch.

    Whenever the validation loss improves, the model's ``state_dict`` is
    written to ``params['path2weights']`` (skipped when that value is empty),
    so the best checkpoint survives the run.
    """
    num_epochs = params['num_epochs']
    loss_func = params['loss_func']
    opt = params['optimizer']
    train_dl = params['train_dl']
    val_dl = params['val_dl']
    sanity_check = params['sanity_check']
    lr_scheduler = params['lr_scheduler']
    path2weights = params['path2weights']

    loss_history = {'train': [], 'val': []}
    metric_history = {'train': [], 'val': []}

    best_loss = float('inf')
    start_time = time.time()

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs-1, current_lr))

        # training pass: the optimizer is handed down, so weights are updated
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        # validation pass: no optimizer and no gradient tracking
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            if path2weights:
                # make sure the target folder exists before writing
                weights_dir = os.path.dirname(path2weights)
                if weights_dir:
                    os.makedirs(weights_dir, exist_ok=True)
                torch.save(model.state_dict(), path2weights)
            print('Get Best val loss!!')

        # the scheduler is driven by the validation loss
        lr_scheduler.step(val_loss)

        print('train loss: %.6f, val_loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, (time.time()-start_time)/60))
        print('-'*10)

    return model, loss_history, metric_history
training 파라미터를 정의합니다.
# define the training parameters
# (these are the keys train_val reads from its params argument)
params_train = dict(
    num_epochs=20,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# create the directory that stores weights.pt
def createFolder(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    Failures are reported on stdout instead of being raised, so a folder that
    cannot be created does not abort the notebook at this point.
    """
    try:
        # exist_ok makes the call a no-op for an already existing directory
        os.makedirs(directory, exist_ok=True)
    except OSError:
        # the original handler named the non-existent 'OSerror', which raised
        # NameError whenever directory creation actually failed
        print('Error: failed to create directory ' + str(directory))


createFolder('./models')
학습 시작!
# run training; train_val returns the model plus the per-epoch loss and metric histories
model, loss_hist, metric_hist = train_val(model, params_train)
loss와 accuracy progress를 출력하겠습니다.
# Train-Val progress
num_epochs=params_train['num_epochs']
# x axis shared by both figures: epochs numbered from 1
epoch_axis = range(1, num_epochs + 1)

# one figure per tracked quantity: loss progress first, then accuracy progress
for title, history, y_label in (
    ('Train_Val Loss', loss_hist, 'Loss'),
    ('Train-Val Accuracy', metric_hist, 'Accuracy'),
):
    plt.title(title)
    for split in ('train', 'val'):
        plt.plot(epoch_axis, history[split], label=split)
    plt.ylabel(y_label)
    plt.xlabel('Training Epochs')
    plt.legend()
    plt.show()
반응형