논문 구현
[논문 구현] PyTorch로 LeNet-5(1998) 구현하기
AI 꿈나무
2021. 3. 8. 19:36
반응형
안녕하세요! 공부 목적으로 LeNet-5를 파이토치로 구현해보도록 하겠습니다!
논문 리뷰는 여기에서 확인하실 수 있습니다.
전체 코드는 여기에서 확인하실 수 있습니다.
제 깃허브 주소입니다...ㅎㅎ 많이 빈약하지만 꾸준히 업데이트하여 풍부한 내용을 담아보도록 하겠습니다...!!!
스타도 눌러주시면 감사하겠습니다!!
1. MNIST dataset 불러오기
LeNet-5를 학습하기 위한 MNIST dataset을 불러옵니다. MNIST dataset은 torchvision 패키지에서 제공됩니다.
# 우선, MNIST dataset에 적용할 transformation 객체를 생성합니다.
from torchvision import transforms
# transformation 정의하기
data_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
])
# MNIST training dataset 불러오기
from torchvision import datasets
# 데이터를 저장할 경로 설정
path2data = '/content/data'
# training data 불러오기
train_data = datasets.MNIST(path2data, train=True, download=True, transform=data_transform)
# MNIST test dataset 불러오기
val_data = datasets.MNIST(path2data, train=False, download=True, transform=data_transform)
불러온 MNIST dataset을 Data loader로 wrap 해줍니다.
# data loader 를 생성합니다.
from torch.utils.data import DataLoader
train_dl = DataLoader(train_data, batch_size=32, shuffle=True)
val_dl = DataLoader(val_data, batch_size=32)
MNIST dataset에서 sample image를 확인하겠습니다.
# sample images를 확인합니다.
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# training data를 추출합니다.
x_train, y_train = train_data.data, train_data.targets
# val data를 추출합니다.
x_val, y_val = val_data.data, val_data.targets
# 차원을 추가하여 B*C*H*W 가 되도록 합니다.
if len(x_train.shape) == 3:
x_train = x_train.unsqueeze(1)
if len(x_val.shape) == 3:
x_val = x_val.unsqueeze(1)
# tensor를 image로 변경하는 함수를 정의합니다.
def show(img):
# tensor를 numpy array로 변경합니다.
npimg = img.numpy()
# C*H*W를 H*W*C로 변경합니다.
npimg_tr = npimg.transpose((1,2,0))
plt.imshow(npimg_tr, interpolation='nearest')
# images grid를 생성하고 출력합니다.
# 총 40개 이미지, 행당 8개 이미지를 출력합니다.
x_grid = utils.make_grid(x_train[:40], nrow=8, padding=2)
show(x_grid)
2. LeNet-5 모델 구축하기
data 전처리는 끝났습니다. 이제 LeNet-5 모델을 구현하겠습니다.
from torch import nn
import torch.nn.functional as F
class LeNet_5(nn.Module):
def __init__(self):
super(LeNet_5,self).__init__()
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
self.conv3 = nn.Conv2d(16, 120, kernel_size=5, stride=1)
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 10)
def forward(self, x):
x = F.tanh(self.conv1(x))
x = F.avg_pool2d(x, 2, 2)
x = F.tanh(self.conv2(x))
x = F.avg_pool2d(x, 2, 2)
x = F.tanh(self.conv3(x))
x = x.view(-1, 120)
x = F.tanh(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=1)
model = LeNet_5()
print(model)
모델을 CUDA device로 전달합니다.
# 모델을 CUDA로 전달합니다.
model.to(device)
print(next(model.parameters()).device)
모델 summary를 확인합니다.
# 모델 summary를 확인합니다.
from torchsummary import summary
summary(model, input_size=(1, 32, 32))
3. loss function, optimizer 정의하기
loss function을 정의합니다.
# loss function 정의합니다.
loss_func = nn.CrossEntropyLoss(reduction='sum')
optimizer를 정의합니다.
# optimizer 정의합니다.
from torch import optim
opt = optim.Adam(model.parameters(), lr=0.001)
# 현재 lr을 계산하는 함수를 정의합니다.
def get_lr(opt):
for param_group in opt.param_groups:
return param_group['lr']
# 러닝레이트 스케쥴러를 정의합니다.
from torch.optim.lr_scheduler import CosineAnnealingLR
lr_scheduler = CosineAnnealingLR(opt, T_max=2, eta_min=1e-05)
4. 학습을 위한 helper function 정의
# 배치당 performance metric 을 계산하는 함수 정의
def metrics_batch(output, target):
pred = output.argmax(dim=1, keepdim=True)
corrects = pred.eq(target.view_as(pred)).sum().item()
return corrects
# 배치당 loss를 계산하는 함수를 정의
def loss_batch(loss_func, output, target, opt=None):
loss = loss_func(output, target)
metric_b = metrics_batch(output, target)
if opt is not None:
opt.zero_grad()
loss.backward()
opt.step()
return loss.item(), metric_b
# epoch당 loss와 performance metric을 계산하는 함수 정의
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
running_loss = 0.0
running_metric = 0.0
len_data = len(dataset_dl.dataset)
for xb, yb in dataset_dl:
xb = xb.type(torch.float).to(device)
yb = yb.to(device)
output = model(xb)
loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
running_loss += loss_b
if metric_b is not None:
running_metric += metric_b
if sanity_check is True: # sanity_check가 True이면 1epoch만 학습합니다.
break
loss = running_loss / float(len_data)
metric = running_metric / float(len_data)
return loss, metric
# train_val 함수 정의
def train_val(model, params):
num_epochs = params['num_epochs']
loss_func = params['loss_func']
opt = params['optimizer']
train_dl = params['train_dl']
val_dl = params['val_dl']
sanity_check = params['sanity_check']
lr_scheduler = params['lr_scheduler']
path2weights = params['path2weights']
loss_history = {
'train': [],
'val': [],
}
metric_history = {
'train': [],
'val': [],
}
# best model parameter를 저장합니다.
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = float('inf')
for epoch in range(num_epochs):
current_lr = get_lr(opt)
print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs-1, current_lr))
model.train()
train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
loss_history['train'].append(train_loss)
metric_history['train'].append(train_metric)
model.eval()
with torch.no_grad():
val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
loss_history['val'].append(val_loss)
metric_history['val'].append(val_metric)
if val_loss < best_loss:
best_loss = val_loss
best_model_wts = copy.deepcopy(model.state_dict())
torch.save(model.state_dict(), path2weights)
print('Copied best model weights')
lr_scheduler.step()
print('train loss: %.6f, dev loss: %.6f, accuracy: %.2f' %(train_loss, val_loss, 100*val_metric))
print('-'*10)
# best model을 반환합니다.
model.load_state_dict(best_model_wts)
return model, loss_history, metric_history
5. 모델 학습하기
위에서 정의한 함수를 사용하여 모델을 학습시킵니다.
# 모델을 학습합니다.
model,loss_hist,metric_hist=train_val(model,params_train)
3epoch만 학습했는데도 정확도가 높게 나오네요
loss와 accuracy 그래프를 plot 해봅니다.
num_epochs=params_train["num_epochs"]
plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
# plot accuracy progress
plt.title("Train-Val Accuracy")
plt.plot(range(1,num_epochs+1),metric_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),metric_hist["val"],label="val")
plt.ylabel("Accuracy")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()
반응형