안녕하세요! 이번에 PyTorch로 구현해볼 모델은 MobileNetV1입니다. MobileNetV1은 모델 경량화를 위해 Depthwise separable convolution을 활용하여 연산량을 줄인 모델입니다. 자세한 논문 리뷰는 아래 포스팅에서 확인하실 수 있습니다.
전체 코드는 여기에서 확인하실 수 있습니다.
1. 데이터셋 불러오기
데이터셋은 torchvision 패키지에서 제공하는 STL10 dataset을 이용하겠습니다. STL10 dataset은 10개의 label을 가지며, train dataset 5000개와 test dataset 8000개로 구성됩니다.
우선, Google Colab에 mount를 하겠습니다.
from google.colab import drive
# Mount Google Drive at an absolute mount point. A relative mount point such
# as 'mobilenet' is resolved against the current working directory, while
# path2data below hard-codes '/content/mobilenet/MyDrive/...'; using the
# absolute path keeps the two consistent regardless of the cwd.
drive.mount('/content/mobilenet')
필요한 라이브러리를 import 합니다.
# import package
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
# dataset and transformation
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# display images
from torchvision import utils
import matplotlib.pyplot as plt
%matplotlib inline
# utils
import numpy as np
from torchsummary import summary
import time
import copy
dataset을 불러옵니다.
# specify path to data
path2data = '/content/mobilenet/MyDrive/data'

# Create the directory, including any missing parents, if needed.
# os.mkdir would raise FileNotFoundError when a parent directory is missing.
# exist_ok=True also removes the race between an exists() check and the creation.
os.makedirs(path2data, exist_ok=True)

# load dataset
# The 'test' split is used as the validation set in this post.
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())
print(len(train_ds))
print(len(val_ds))
transformation을 정의하고, dataset에 적용합니다.
# Preprocessing pipeline: convert each image to a tensor, then apply
# transforms.Resize(224).
to_tensor = transforms.ToTensor()
resize_224 = transforms.Resize(224)
transformation = transforms.Compose([to_tensor, resize_224])

# Both splits share the same pipeline (no augmentation).
for ds in (train_ds, val_ds):
    ds.transform = transformation
dataloader를 생성합니다.
# make dataloaders
# Shuffle only the training data. The validation pass sums loss/correct counts
# over every sample, so batch order does not change the result; shuffling it
# only makes the batches non-reproducible.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
샘플 이미지를 확인합니다.
# check sample images
def show(img, y=None):
    """Display an image tensor with matplotlib, optionally titled with labels.

    Args:
        img: a 3-D tensor laid out as (channels, height, width). It is
            converted to NumPy and transposed to (height, width, channels)
            for ``plt.imshow``.
        y: optional labels. When given, the title is ``'labels:' + str(y)``.
    """
    hwc = img.numpy().transpose(1, 2, 0)
    plt.imshow(hwc)
    if y is None:
        return
    plt.title('labels:' + str(y))
# Fix the seeds so the same sample images are drawn on every run.
np.random.seed(10)
torch.manual_seed(0)

grid_size = 4
rnd_ind = np.random.randint(0, len(train_ds), grid_size)

# Load each (image, label) pair once. Taking the labels from train_ds, the
# same dataset the images come from, keeps the title matched to the pictures.
samples = [train_ds[i] for i in rnd_ind]
x_grid = [img for img, _ in samples]
y_grid = [label for _, label in samples]

x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
plt.figure(figsize=(10, 10))
show(x_grid, y_grid)
2. 모델 구축하기
코드는 https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/mobilenet.py 를 참고했습니다.
MobileNet의 전체 구조입니다.
MobileNetV1은 Depthwise separable convolution을 반복하여 쌓은 구조입니다. 따라서 Depthwise separable convolution 클래스를 정의합니다.
# Depthwise Separable Convolution
class Depthwise(nn.Module):
    """Depthwise separable convolution block.

    It applies a 3x3 per-channel (depthwise) convolution, then a 1x1 (pointwise)
    convolution. Each conv is followed by BatchNorm2d and ReLU6.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the 1x1 convolution.
        stride: stride of the 3x3 depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # groups=in_channels gives one 3x3 filter per input channel.
        # bias is omitted because BatchNorm follows.
        dw_conv = nn.Conv2d(
            in_channels, in_channels, 3,
            stride=stride, padding=1, groups=in_channels, bias=False,
        )
        # 1x1 convolution that mixes channels and sets the output width.
        pw_conv = nn.Conv2d(
            in_channels, out_channels, 1,
            stride=1, padding=0, bias=False,
        )
        self.depthwise = nn.Sequential(dw_conv, nn.BatchNorm2d(in_channels), nn.ReLU6())
        self.pointwise = nn.Sequential(pw_conv, nn.BatchNorm2d(out_channels), nn.ReLU6())

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
BasicConv2d 클래스도 정의하면 편리하게 모델을 구축할 수 있습니다.
# Basic Conv2d
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Extra keyword arguments (e.g. ``stride``, ``padding``) are passed
    unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
위에서 정의한 클래스를 활용해서 MobileNetV1을 구축하겠습니다. MobileNet에는 모델의 크기를 조절하는 하이퍼파라미터가 있습니다. 그중 각 층의 채널 수에 곱해지는 width multiplier(α)를 width_multiplier 인자로 받아 채널 수를 조절하도록 구축합니다.
# MobileNetV1
class MobileNet(nn.Module):
    """MobileNetV1: a stem convolution followed by stacked depthwise separable blocks.

    Every layer width is a base width (32, 64, ..., 1024) multiplied by
    ``width_multiplier`` and truncated with ``int``.

    Args:
        width_multiplier: channel-width multiplier (alpha). 1 gives the full
            model and smaller values give thinner networks.
        num_classes: output size of the final linear layer.
        init_weights: when True, re-initialize weights with
            ``_initialize_weights``.
    """

    def __init__(self, width_multiplier, num_classes=10, init_weights=True):
        super().__init__()
        self.init_weights = init_weights
        alpha = width_multiplier

        def ch(n):
            # Scale a base channel count by alpha (truncated to int).
            return int(n * alpha)

        # Spatial sizes in the comments below assume a 224x224 input.
        # down sample: 224 -> 112
        self.conv1 = BasicConv2d(3, ch(32), 3, stride=2, padding=1)
        self.conv2 = Depthwise(ch(32), ch(64), stride=1)
        # down sample: 112 -> 56
        self.conv3 = nn.Sequential(
            Depthwise(ch(64), ch(128), stride=2),
            Depthwise(ch(128), ch(128), stride=1),
        )
        # down sample: 56 -> 28
        self.conv4 = nn.Sequential(
            Depthwise(ch(128), ch(256), stride=2),
            Depthwise(ch(256), ch(256), stride=1),
        )
        # down sample: 28 -> 14, followed by five 512 -> 512 blocks
        self.conv5 = nn.Sequential(
            Depthwise(ch(256), ch(512), stride=2),
            *[Depthwise(ch(512), ch(512), stride=1) for _ in range(5)],
        )
        # down sample: 14 -> 7
        self.conv6 = nn.Sequential(
            Depthwise(ch(512), ch(1024), stride=2),
        )
        # The final 1024 -> 1024 block keeps the 7x7 resolution (stride 1).
        # In the MobileNetV1 architecture, this block's output is 7x7 and feeds
        # a 7x7 average pool. Stride 2 here would shrink the map to 4x4.
        self.conv7 = nn.Sequential(
            Depthwise(ch(1024), ch(1024), stride=1),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(ch(1024), num_classes)

        # weights initialization
        if self.init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Map an (N, 3, H, W) image batch to (N, num_classes) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)  # flatten (N, C, 1, 1) -> (N, C)
        x = self.linear(x)
        return x

    # weights initialization function
    def _initialize_weights(self):
        """Initialize parameters in place.

        - Conv2d: Kaiming-normal weights (fan_out, relu); biases are set to 0.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: weights ~ N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
def mobilenet(alpha=1, num_classes=10):
    """Build a MobileNet with width multiplier ``alpha`` and ``num_classes`` outputs."""
    net = MobileNet(width_multiplier=alpha, num_classes=num_classes)
    return net
구축한 모델을 확인합니다.
# Shape check: push a random batch of three 3x224x224 inputs through a
# full-width (alpha=1) model, using CUDA when it is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = torch.randn(3, 3, 224, 224).to(device)
model = mobilenet(alpha=1).to(device)
output = model(x)
print('output size:', output.size())
model summary를 출력합니다.
# Print the torchsummary layer table for a 3x224x224 input on the model's device type.
summary(model, input_size=(3, 224, 224), device=device.type)
3. 학습하기
학습에 필요한 함수를 정의합니다.
# define loss function, optimizer, lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Cross-entropy summed over the batch (not averaged).
loss_func = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(model.parameters(), lr=0.01)
# Multiply the learning rate by 0.1 when the monitored value (stepped with
# the validation loss) stops decreasing, with patience=10.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10)
# get current lr
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None if the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in opt.param_groups), None)
# calculate the metric per mini-batch
def metric_batch(output, target):
    """Count correct predictions in a mini-batch.

    Args:
        output: model outputs; the predicted class is the argmax over dim 1.
        target: ground-truth class indices, reshaped to match the predictions.

    Returns:
        Number of samples whose prediction equals ``target``, as a Python int.
    """
    predicted = torch.argmax(output, dim=1, keepdim=True)
    hits = predicted == target.view_as(predicted)
    return hits.sum().item()
# calculate the loss per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    """Score one mini-batch and, when ``opt`` is given, take an optimizer step.

    If ``opt`` is given, gradients are zeroed, the loss is backpropagated and
    ``opt.step()`` is called.

    Returns:
        (loss value via ``.item()``, number of correct predictions)
    """
    batch_loss = loss_func(output, target)
    n_correct = metric_batch(output, target)
    if opt is None:
        return batch_loss.item(), n_correct
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), n_correct
# calculate the loss per epochs
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return per-sample loss and accuracy.

    Each batch is moved to the module-level ``device``, passed through
    ``model``, and scored with ``loss_batch``, which also takes an optimizer
    step when ``opt`` is given.

    Args:
        model: network to evaluate or train.
        loss_func: loss that sums over the batch (``reduction='sum'``), so
            dividing by the sample count gives a per-sample mean.
        dataset_dl: iterable of ``(inputs, targets)`` batches.
        sanity_check: if truthy, stop after the first batch (quick smoke test).
        opt: optimizer. When None, no parameters are updated.

    Returns:
        (mean loss per sample, fraction of correct predictions), both averaged
        over the samples actually processed.

    Raises:
        ValueError: if ``dataset_dl`` yields no samples.
    """
    running_loss = 0.0
    running_metric = 0.0
    # Count the samples actually processed instead of using len(dataset).
    # With sanity_check only one batch runs, and dividing its sums by the
    # full dataset size would report a loss/accuracy that is far too small.
    num_samples = 0
    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        num_samples += yb.size(0)
        if sanity_check:
            break

    if num_samples == 0:
        raise ValueError('dataset_dl yielded no samples')
    loss = running_loss / num_samples
    metric = running_metric / num_samples
    return loss, metric
# function to start training
def train_val(model, params):
    """Train ``model``, validate after every epoch, and keep the best weights.

    ``params`` must contain 'num_epochs', 'loss_func', 'optimizer',
    'train_dl', 'val_dl', 'sanity_check', 'lr_scheduler' and 'path2weights'.

    When the validation loss beats the best so far, the weights are copied in
    memory and saved to 'path2weights'. The scheduler is stepped with the
    validation loss. If that step changes the learning rate, the best weights
    are reloaded. The best weights are loaded into ``model`` before returning.

    Returns:
        (model, loss_history, metric_history). Each history is a dict with
        'train' and 'val' lists holding one value per epoch.
    """
    keys = ('num_epochs', 'loss_func', 'optimizer', 'train_dl', 'val_dl',
            'sanity_check', 'lr_scheduler', 'path2weights')
    (num_epochs, loss_func, opt, train_dl, val_dl,
     sanity_check, lr_scheduler, path2weights) = [params[k] for k in keys]

    loss_history = {'train': [], 'val': []}
    metric_history = {'train': [], 'val': []}

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    start_time = time.time()

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr= {}'.format(epoch, num_epochs - 1, current_lr))

        # training phase
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        # validation phase: no optimizer and no gradient tracking
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        improved = val_loss < best_loss
        if improved:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print('Copied best model weights!')

        lr_scheduler.step(val_loss)
        if get_lr(opt) != current_lr:
            # The scheduler changed the lr, so continue from the best weights.
            print('Loading best model weights!')
            model.load_state_dict(best_model_wts)

        elapsed_min = (time.time() - start_time) / 60
        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' % (train_loss, val_loss, 100 * val_metric, elapsed_min))
        print('-' * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
학습에 필요한 파라미터를 설정합니다.
# define the training parameters (consumed by train_val)
params_train = dict(
    num_epochs=30,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# check the directory to save weights.pt
def createFolder(directory):
    """Create ``directory``, including missing parents, if it does not exist.

    An OSError during creation is reported by printing 'Error' rather than
    being raised.

    Args:
        directory: path of the directory to create.
    """
    try:
        # exist_ok=True makes an existing directory a no-op and avoids the
        # race between a separate exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)
    except OSError:  # correct capitalization matters: a misspelled name would raise NameError here
        print('Error')


createFolder('./models')
학습을 시작하겠습니다. epoch 수는 30으로 설정했습니다.
# Train with params_train. train_val returns the model (best weights loaded)
# plus per-epoch loss and accuracy histories.
trained = train_val(model, params_train)
model, loss_hist, metric_hist = trained
train loss가 val loss에 비해 많이 낮네요. 오버피팅이 발생했습니다. 데이터 수를 늘리거나 모델을 손봐야 합니다. MobileNet은 경량화에 치중한 모델이어서 shortcut connection을 사용하지 않는데, 이것도 하나의 원인일 수 있겠네요. 드롭아웃을 추가하는 것도 좋은 방법입니다.
loss, accuracy progress를 확인하겠습니다.
# train-val progress
num_epochs = params_train['num_epochs']
epochs = range(1, num_epochs + 1)


def _plot_progress(title, history, ylabel):
    """Plot the 'train' and 'val' curves of ``history`` against ``epochs`` and show the figure."""
    plt.title(title)
    for split in ('train', 'val'):
        plt.plot(epochs, history[split], label=split)
    plt.ylabel(ylabel)
    plt.xlabel('Training Epochs')
    plt.legend()
    plt.show()


# plot loss progress
_plot_progress('Train-Val Loss', loss_hist, 'Loss')
# plot accuracy progress
_plot_progress('Train-Val Accuracy', metric_hist, 'Accuracy')
감사합니다.
'논문 구현' 카테고리의 다른 글
[논문 구현] PyTorch로 ResNext(2017) 구현하고 학습하기 (1) | 2021.03.29 |
---|---|
[논문 구현] PyTorch로 Residual Attention Network(2017) 구현하고 학습하기 (0) | 2021.03.27 |
[논문 구현] PyTorch로 Xception(2017) 구현하고 학습하기 (0) | 2021.03.23 |
[논문 구현] PyTorch로 DenseNet(2017) 구현하고 학습하기 (1) | 2021.03.22 |
[논문 구현] PyTorch로 WRN, Wide residual Network(2016) 구현하고 학습하기 (0) | 2021.03.22 |