[논문 구현] ViT(2020) PyTorch 구현 및 학습
공부 목적으로 ViT를 구현하고 학습한 내용을 공유합니다 ㅎㅎ. 작업은 Google Colab 환경에서 진행했습니다.
필요한 라이브러리를 설치 및 임포트합니다.
!pip install einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from torchvision import utils
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import numpy as np
import time
import copy
import random
from tqdm.notebook import tqdm
import math
# Device configuration: use CUDA when torch reports it as available, else the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
1. 데이터셋 불러오기
데이터셋은 STL-10을 사용했습니다. train data가 5000개밖에 되지 않지만, 접근성이 좋아 모델 구현 시에 주로 사용하는 데이터셋입니다.
# specify path to data
path2data = '/content/data'
# Create the data directory together with any missing parent directories.
# exist_ok=True makes this a no-op when the directory is already there, so no
# separate existence check is needed.
os.makedirs(path2data, exist_ok=True)
# load dataset (downloaded into path2data on first use)
train_ds = datasets.STL10(path2data, split='train', download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split='test', download=True, transform=transforms.ToTensor())
# report the number of samples in each split
print(len(train_ds))
print(len(val_ds))
transformation을 구현하고 적용합니다. 지금 발견한 건데, 깜빡하고 normalization을 적용하지 않았네요 ㅎㅎ
# define transformation: convert to a tensor, then resize to 224
# NOTE: no normalization is applied in this pipeline.
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224)
])
# apply transformation to dataset (replaces the transform given at construction)
train_ds.transform = transformation
val_ds.transform = transformation
# make dataloaders
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
# Validation only accumulates loss/accuracy over the whole split, so the sample
# order is irrelevant; keep it fixed instead of reshuffling every epoch.
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False)
샘플 이미지를 확인합니다.
# check sample images
def show(img, y=None):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: tensor whose axes are reordered (0, 1, 2) -> (1, 2, 0) before
            plotting, i.e. the first axis is moved to the end.
        y: optional label information; when given it is rendered in the
            plot title.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    if y is None:
        return
    plt.title('labels:' + str(y))
# fix the seeds so the same sample indices are drawn on every run
np.random.seed(10)
torch.manual_seed(0)
grid_size=4
# draw random indices into the training set
rnd_ind = np.random.randint(0, len(train_ds), grid_size)
# Fetch image and label from the SAME dataset so the titles match the images
# (the labels were previously looked up in val_ds with train_ds indices).
samples = [train_ds[i] for i in rnd_ind]
x_grid = [sample[0] for sample in samples]
y_grid = [sample[1] for sample in samples]
# tile the images into a single row
x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
plt.figure(figsize=(10,10))
show(x_grid, y_grid)
2. ViT 구현하기
구현 코드는 https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632 를 참고했습니다. 설명과 함께 ViT 모델을 구현하기 때문에 위 포스팅을 함께 보는 것이 더 유익합니다 ㅎㅎ
(1) Patch Embedding
# To handle 2D images, reshape the image into a sequence of flattened 2D patches.
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution with kernel size and stride both equal to ``patch_size``
    projects every non-overlapping patch to ``emb_size`` channels (used here
    in place of flattening each patch and applying a linear layer). The
    resulting grid is flattened into a sequence, a learnable class token is
    prepended and learnable position embeddings are added.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch.
        emb_size: embedding dimension of each token.
        img_size: input image side length; determines the number of
            position embeddings, ``(img_size // patch_size) ** 2 + 1``.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # conv projection followed by flattening the (h, w) grid into one axis
        patch_conv = nn.Conv2d(in_channels, emb_size, patch_size, stride=patch_size)
        flatten_grid = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, flatten_grid)

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # one position vector per patch plus one for the class token
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x):
        """Return the token sequence (class token first) for image batch ``x``."""
        batch_size = x.shape[0]
        patches = self.projection(x)
        # one copy of the class token per sample in the batch
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # prepend the cls token to the projected patches
        tokens = torch.cat([cls_tokens, patches], dim=1)
        # add position embedding to projected patches
        return tokens + self.positions
클래스를 정의하고 잘 작동하는지 check 합니다.
# Check PatchEmbedding: run a random batch through the layer and print the output shape.
x = torch.randn(16, 3, 224, 224).to(device)
patch_embedding = PatchEmbedding().to(device)
# patch_output is reused as the input of the later attention/encoder checks.
patch_output = patch_embedding(x)
print('[batch, 1+num of patches, emb_size] = ', patch_output.shape)
(2) MultiHeadAttention
# MultiHeadAttention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are linear projections of the same input. Each
    projection is split into ``num_heads`` heads of size
    ``emb_size // num_heads``; attention weights are
    ``softmax(Q K^T / sqrt(head_dim))``.

    Args:
        emb_size: embedding dimension of the input and output tokens.
        num_heads: number of attention heads; must divide ``emb_size``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size=768, num_heads=8, dropout=0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f'emb_size ({emb_size}) must be divisible by num_heads ({num_heads})'
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t):
        """Reshape (b, n, h * d) into (b, h, n, d)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape (batch, tokens, emb_size).

        Args:
            x: input token sequence.
            mask: optional boolean tensor broadcastable to
                (batch, heads, query_len, key_len). ``True`` marks positions
                that may be attended to; ``False`` positions are excluded.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # split keys, queries and values in num_heads
        # e.g. (b, 197, 768) -> (b, 8, 197, 96) with the default arguments
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))
        # dot product over the head dimension -> (batch, num_heads, query_len, key_len)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be re-assigned;
            # use the most negative value of the actual dtype (fp16/bf16 safe)
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale the logits BEFORE the softmax by sqrt(d_k), the per-head dimension
        scaling = self.head_dim ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values over the key axis -> (b, h, n, d)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # merge the heads back: (b, h, n, d) -> (b, n, h * d)
        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, self.emb_size)
        out = self.projection(out)
        return out
잘 구현되었는지 확인합니다.
# Check MultiHeadAttention: feed the patch embeddings from the PatchEmbedding
# check through a single attention layer and print the output shape.
MHA = MultiHeadAttention().to(device)
MHA_output = MHA(patch_output)
print(MHA_output.shape)
(3) Residual
# perform the residual addition.
class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a skip connection: ``fn(x) + x``.

    Args:
        fn: callable/module applied to the input; its output must be
            addable to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying ``x``."""
        # Out-of-place addition: an in-place ``+=`` on the output of ``fn``
        # would overwrite the caller's tensor whenever ``fn`` returns its
        # input unchanged, and fails on leaf tensors that require grad.
        return self.fn(x, **kwargs) + x
(4) FeedForwardBlock
# Subclassing nn.Sequential to avoid writing the forward method.
class FeedForwardBlock(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: input and output feature size.
        expansion: the hidden layer has ``expansion * emb_size`` features.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion=4, drop_p=0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
Check
# check FeedForwardBlock: a random (16, 1, 128) batch goes through a block
# built with emb_size=128, and the output shape is printed.
x = torch.randn(16,1,128).to(device)
model = FeedForwardBlock(128).to(device)
output = model(x)
print(output.shape)
(5) TransformerEncoderBlock
# Now create the Transformer Encoder Block
class TransformerEncoderBlock(nn.Sequential):
    """One encoder layer: two residual sub-blocks applied in sequence.

    The first sub-block is LayerNorm -> MultiHeadAttention -> Dropout, the
    second is LayerNorm -> FeedForwardBlock -> Dropout; each is wrapped in
    ``ResidualAdd``.

    Args:
        emb_size: token embedding dimension.
        drop_p: dropout probability at the end of both sub-blocks.
        forward_expansion: ``expansion`` passed to the feed-forward block.
        forward_drop_p: ``drop_p`` passed to the feed-forward block.
        **kwargs: forwarded to ``MultiHeadAttention``.
    """

    def __init__(self, emb_size=768, drop_p=0., forward_expansion=4, forward_drop_p=0., **kwargs):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
Check
# Check TransformerEncoderBlock: pass the patch embeddings through one block
# with default arguments and print the output shape.
model = TransformerEncoderBlock().to(device)
output = model(patch_output).to(device)
print(output.shape)
(6) TransformerEncoder
# TransformerEncoder consists of L blocks of TransformerBlock
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers applied in order.

    Args:
        depth: number of encoder blocks.
        **kwargs: forwarded unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth=12, **kwargs):
        blocks = tuple(TransformerEncoderBlock(**kwargs) for _ in range(depth))
        super().__init__(*blocks)
Check
# Check TransformerEncoder: run the patch embeddings through the default
# encoder stack (depth=12) and print the output shape.
model = TransformerEncoder().to(device)
output = model(patch_output)
print(output.shape)
(7) ClassificationHead
# define ClassificationHead which maps the token sequence to class scores
class ClassificationHead(nn.Sequential):
    """Mean-pool the tokens, normalize, and project to ``n_classes`` scores.

    The layers are: mean over the token axis ('b n e -> b e'), LayerNorm,
    and a Linear layer. No softmax is applied here.

    Args:
        emb_size: token embedding dimension.
        n_classes: number of output classes.
    """

    def __init__(self, emb_size=768, n_classes = 10):
        token_mean = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(token_mean, norm, classifier)
Check
# check ClassificationHead: a random (16, 1, 768) batch goes through the head
# with default arguments, and the output shape is printed.
x = torch.randn(16, 1, 768).to(device)
model = ClassificationHead().to(device)
output = model(x)
print(output.shape)
(8) ViT
# Define the ViT architecture
class ViT(nn.Sequential):
    """Vision Transformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        in_channels: number of image channels.
        patch_size: side length of each square patch.
        emb_size: token embedding dimension shared by all three stages.
        img_size: input image side length.
        depth: number of encoder blocks.
        n_classes: number of output classes.
        **kwargs: forwarded to ``TransformerEncoder``.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224, depth=12, n_classes=10, **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
Check
# Check ViT end to end: a random batch of 16 images of size 3x224x224 goes
# through the default model, and the output shape is printed.
x = torch.randn(16,3,224,224).to(device)
model = ViT().to(device)
output = model(x)
print(output.shape)
구현한 ViT의 summary를 출력해보겠습니다.
# Build a fresh ViT; this `model` instance is the one whose parameters are
# handed to the optimizer and that is passed to train_val further down.
model = ViT().to(device)
# print a layer-by-layer summary for a single 3x224x224 input
summary(model, (3,224,224), device=device.type)
모델 파라미터가 엄청나네요.
STL-10 dataset은 train data가 적어 모델 파라미터가 너무 많으면 학습이 잘 안되는 것 같습니다.
정확도가 잘 안 나올 것 같아서 불안하네요.
3. 학습하기
학습을 위한 함수를 정의하고 학습하겠습니다.
일반적으로 분류 task에서 학습 코드는 동일하므로, 코드를 한번 짜두면 계속해서 재사용 할 수 있습니다.
# define the loss function, optimizer and lr_scheduler
# reduction='sum' makes the loss the sum over a batch; loss_epoch accumulates
# these sums and divides by a sample count to get a per-sample average.
loss_func = nn.CrossEntropyLoss(reduction='sum')
# NOTE(review): lr=0.01 looks high for Adam on a transformer — confirm this is intended.
opt = optim.Adam(model.parameters(), lr=0.01)
from torch.optim.lr_scheduler import ReduceLROnPlateau
# mode='min': train_val steps this scheduler with the validation loss;
# the lr is multiplied by `factor` according to `patience`.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10)
# get current lr
def get_lr(opt):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        opt: object exposing ``param_groups``, an iterable of dicts with an
            ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first group, or ``None`` when there are
        no parameter groups.
    """
    first_group = next(iter(opt.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
# calculate the metric per mini-batch
def metric_batch(output, target):
    """Count the correct predictions in one mini-batch.

    Args:
        output: score tensor; the prediction is the argmax along dim 1.
        target: tensor of class indices, viewable to the shape of the
            predictions (one entry per row of ``output``).

    Returns:
        Number of rows whose predicted class equals the target, as a
        Python number.
    """
    predicted = torch.argmax(output, dim=1, keepdim=True)
    hits = torch.eq(predicted, target.view_as(predicted))
    return hits.sum().item()
# calculate the loss per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    """Compute loss and correct-prediction count for one mini-batch.

    When an optimizer is supplied, a full update step is performed
    (zero_grad, backward, step); otherwise nothing is updated.

    Args:
        loss_func: callable returning the batch loss tensor.
        output: model output for the batch.
        target: ground-truth labels for the batch.
        opt: optional optimizer; ``None`` means evaluation only.

    Returns:
        Tuple ``(loss value as a Python float, number of correct predictions)``.
    """
    batch_loss = loss_func(output, target)
    num_correct = metric_batch(output, target)

    if opt is None:
        return batch_loss.item(), num_correct

    # training mode: backpropagate and update the weights
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), num_correct
# calculate the loss per epochs
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over ``dataset_dl`` and return average loss and accuracy.

    The per-batch loss values are accumulated and divided by the number of
    samples that were actually processed, so the averages stay meaningful
    when ``sanity_check`` stops the pass after the first batch. This assumes
    ``loss_func`` returns a loss summed over the batch, as configured in
    this file with ``reduction='sum'``.

    Args:
        model: network called on each input batch.
        loss_func: loss callable forwarded to ``loss_batch``.
        dataset_dl: iterable yielding ``(inputs, targets)`` batches.
        sanity_check: when truthy, only the first batch is processed.
        opt: optional optimizer; when given, weights are updated per batch.

    Returns:
        Tuple ``(loss per sample, fraction of correct predictions)``.

    Raises:
        ValueError: if ``dataset_dl`` yields no samples.
    """
    running_loss = 0.0
    running_metric = 0.0
    num_seen = 0  # samples processed so far; the divisor for the averages

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)

        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)

        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        num_seen += len(yb)

        if sanity_check:
            break

    if num_seen == 0:
        raise ValueError('dataset_dl yielded no samples')

    loss = running_loss / num_seen
    metric = running_metric / num_seen
    return loss, metric
# function to start training
def train_val(model, params):
    """Train ``model`` and validate after every epoch.

    After each epoch the validation loss is compared with the best value so
    far; on improvement the weights are copied in memory and saved to
    ``path2weights``. The scheduler is stepped with the validation loss, and
    if that changes the learning rate the best weights are reloaded before
    continuing. The best weights are loaded once more before returning.

    Args:
        model: network to train.
        params: dict with the keys ``num_epochs``, ``loss_func``,
            ``optimizer``, ``train_dl``, ``val_dl``, ``sanity_check``,
            ``lr_scheduler`` and ``path2weights``.

    Returns:
        Tuple ``(model, loss_history, metric_history)``; both histories are
        dicts with per-epoch lists under ``'train'`` and ``'val'``.
    """
    total_epochs = params['num_epochs']
    criterion = params['loss_func']
    optimizer = params['optimizer']
    train_loader = params['train_dl']
    val_loader = params['val_dl']
    quick_run = params['sanity_check']
    scheduler = params['lr_scheduler']
    weights_path = params['path2weights']

    loss_history = {phase: [] for phase in ('train', 'val')}
    metric_history = {phase: [] for phase in ('train', 'val')}

    lowest_val_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    started_at = time.time()

    for epoch in range(total_epochs):
        lr_before = get_lr(optimizer)
        print('Epoch {}/{}, current lr= {}'.format(epoch, total_epochs-1, lr_before))

        # training pass (weights are updated)
        model.train()
        train_loss, train_metric = loss_epoch(model, criterion, train_loader, quick_run, optimizer)
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        # validation pass (no gradients, no optimizer)
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, criterion, val_loader, quick_run)
        loss_history['val'].append(val_loss)
        metric_history['val'].append(val_metric)

        # keep a snapshot whenever the validation loss improves
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            best_weights = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), weights_path)
            print('Copied best model weights!')

        # a changed learning rate means the scheduler reacted: restart from the best weights
        scheduler.step(val_loss)
        if lr_before != get_lr(optimizer):
            print('Loading best model weights!')
            model.load_state_dict(best_weights)

        elapsed_min = (time.time()-started_at)/60
        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, elapsed_min))
        print('-'*10)

    model.load_state_dict(best_weights)
    return model, loss_history, metric_history
# define the training parameters consumed by train_val
params_train = dict(
    num_epochs=100,
    optimizer=opt,
    loss_func=loss_func,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=False,
    lr_scheduler=lr_scheduler,
    path2weights='./models/weights.pt',
)
# check the directory to save weights.pt
def createFolder(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    This is best-effort: when the directory cannot be created, 'Error' is
    printed and no exception is propagated.

    Args:
        directory: path of the directory to create.
    """
    try:
        # exist_ok=True makes the call a no-op for an existing directory
        os.makedirs(directory, exist_ok=True)
    except OSError:
        print('Error')
createFolder('./models')
100 epoch 학습하겠습니다.
# Start training: train_val returns the model together with the per-epoch
# loss and metric histories that are plotted afterwards.
model, loss_hist, metric_hist = train_val(model, params_train)
정확도가 35%밖에 안 나오네요. STL-10 dataset에서 MobileNet+SENet이 70% 정도 나왔던 걸로 기억하는데, ViT의 성능이 잘 안 나오는 이유는 여러 가지가 있습니다.
가장 큰 이유는 데이터셋이 너무 작습니다. ViT는 글로벌한 정보를 학습하기 때문에 CNN보다 더 많은 데이터와 학습 시간이 필요합니다. 이외에도 여러 가지 이유가 있겠지만, 모델 구현과 학습이 목적이었기 때문에 여기까지 언급하겠습니다..ㅎㅎ
Loss history를 출력합니다.
num_epochs = params_train['num_epochs']
epoch_axis = range(1, num_epochs+1)
# Draw one figure per history: first train/val loss, then train/val accuracy.
for plot_title, history, y_label in (
    ('Train-Val Loss', loss_hist, 'Loss'),
    ('Train-Val Accuracy', metric_hist, 'Accuracy'),
):
    plt.title(plot_title)
    for split in ('train', 'val'):
        plt.plot(epoch_axis, history[split], label=split)
    plt.ylabel(y_label)
    plt.xlabel('Training Epochs')
    plt.legend()
    plt.show()
감사합니다. 전체 코드는 아래 깃허브에서 확인하실 수 있습니다.
https://github.com/Seonghoon-Yu/Paper_Review_and_Implementation_in_PyTorch