[논문 구현] MoCov2(2020) PyTorch 구현
안녕하세요, 이번 포스팅에서는 MoCov2를 Google Colab 환경에서 PyTorch로 구현해보도록 하겠습니다.
논문 리뷰와 전체 코드는 아래 주소에서 확인하실 수 있습니다.
MoCov2 PyTorch tutorial
우선 필요한 라이브러리를 import 합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
from PIL import Image
import os
import time
from collections import OrderedDict
import copy
import random
import matplotlib.pyplot as plt
%matplotlib inline
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
1. 데이터셋 불러오기
PyTorch torchvision에서 제공하는 STL-10 dataset을 사용하겠습니다. STL-10 dataset의 train split은 5000개의 labeled 이미지로 구성되어 있습니다.
데이터셋을 불러오기 전에 transformation을 정의합니다. 동일한 이미지에 transformation을 두 번 적용하여 query와 key를 생성하도록 Split class도 함께 정의합니다.
# Augmentation pipeline used to produce BOTH the query and the key view.
# Every call draws fresh random parameters, so applying it twice to the same
# image yields two different views.
train_transform = transforms.Compose([
    transforms.Resize(size=(250, 250)),
    transforms.RandomResizedCrop(size=224),
    # colour jitter is applied to 80% of the images
    transforms.RandomApply(
        transforms=[transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # map [0, 1] pixel values to [-1, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Transform wrapper that turns one image into a [query, key] pair of views.
class Split:
    """Apply ``base_transform`` twice, independently, to the same input.

    Args:
        base_transform: callable applied once per view.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[query_view, key_view]`` (query is produced first)."""
        views = [self.base_transform(x) for _ in range(2)]
        return views
데이터셋을 저장할 경로를 지정하고, STL-10 dataset을 불러옵니다.
# Directory where torchvision stores the STL-10 files (created if missing).
path2data = './data'
os.makedirs(path2data, exist_ok=True)

# STL-10 train split; Split makes every sample a [query, key] pair of
# independently augmented views.
train_ds = datasets.STL10(
    root=path2data,
    split='train',
    download=True,
    transform=Split(train_transform),
)
샘플 이미지를 확인합니다.
# Inspect the first sample: both views come from the same image.
img, label = train_ds[0]
query, key = img[0], img[1]
print('img size: ',query.shape)

# Show query and key side by side; 0.5*x + 0.5 undoes Normalize((0.5,)*3, (0.5,)*3).
plt.figure(figsize=(10,10))
for panel, (view, caption) in enumerate(((query, 'Query'), (key, 'Key')), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(to_pil_image(0.5 * view + 0.5))
    plt.title(caption)
    plt.axis('off')
동일한 이미지에 서로 다른 random transformation을 적용해 얻은 두 이미지(query, key)는 positive(similar) pair입니다.
dataloader를 정의합니다.
# Shuffled loader for contrastive pre-training; each batch holds 256 [query, key] pairs.
train_dl = DataLoader(dataset=train_ds, batch_size=256, shuffle=True)
2. encoder 구현하기
encoder는 PyTorch에서 제공하는 ResNet-18을 사용하겠습니다. ResNet-18을 불러오고 마지막 fc layer를 MLP 구조(512→100→50→25)로 대체합니다. 그리고 key encoder는 query encoder를 복사하여 구현합니다.
# Query encoder: ResNet-18 backbone, randomly initialised (pretrained=False).
q_encoder = resnet18(pretrained=False)

# Projection head that replaces ResNet-18's final fc layer:
# in_features -> 100 -> 50 -> 25, giving a 25-dimensional representation.
projection_layers = OrderedDict()
projection_layers['fc1'] = nn.Linear(q_encoder.fc.in_features, 100)
projection_layers['added_relu1'] = nn.ReLU()
projection_layers['fc2'] = nn.Linear(100, 50)
projection_layers['added_relu2'] = nn.ReLU()
projection_layers['fc3'] = nn.Linear(50, 25)
classifier = nn.Sequential(projection_layers)
q_encoder.fc = classifier

# Key encoder starts as an exact copy of the query encoder.
k_encoder = copy.deepcopy(q_encoder)

q_encoder, k_encoder = q_encoder.to(device), k_encoder.to(device)
모델 summary를 출력해보겠습니다.
# Print a layer-by-layer summary for a single 3x224x224 input.
summary(q_encoder, input_size=(3, 224, 224), device=device.type)
3. Unsupervised training
loss function을 정의합니다.
# define loss function
def loss_func(q, k, queue, t=0.05):
    """Contrastive (InfoNCE) loss between queries, their positive keys and a queue of negatives.

    Computes mean over the batch of
    ``-log(exp(q.k/t) / (exp(q.k/t) + sum_j exp(q.queue_j/t)))``
    via ``F.cross_entropy`` on the concatenated logits, which uses a
    log-sum-exp formulation and therefore does not overflow for large logits.

    Args:
        q: (N, C) query representations.
        k: (N, C) positive key representations, row i pairs with q[i].
        queue: (K, C) negative key representations.
        t: temperature.

    Returns:
        Scalar tensor: the loss averaged over the N queries.
    """
    N = q.shape[0]  # batch size
    C = q.shape[1]  # representation dimension
    # positive logits: per-row dot product q[i] . k[i] -> (N, 1)
    l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1)
    # negative logits: every query against every queued key -> (N, K)
    l_neg = torch.mm(q.view(N, C), torch.t(queue))
    logits = torch.cat((l_pos, l_neg), dim=1) / t
    # the positive sits in column 0 for every query
    labels = torch.zeros(N, dtype=torch.long, device=q.device)
    return F.cross_entropy(logits, labels)
optimizer를 정의합니다.
# Adam with default hyper-parameters; only the query encoder is optimised by
# gradients (the key encoder is updated by momentum during training).
opt = optim.Adam(params=q_encoder.parameters())
queue를 정의하고, k_encoder의 출력(key)을 negative sample로 사용하기 위해 학습 전에 queue를 미리 채워 둡니다.
# initialize the queue
queue = None
K = 8192  # K: number of negative keys kept in the queue

# Fill the queue with key-encoder outputs before training.
# Keys are L2-normalised here, exactly like the keys that training enqueues,
# so the queue holds vectors on the same scale as the normalised queries.
with torch.no_grad():
    while queue is None or queue.shape[0] < K:
        for img, _ in train_dl:
            xk = img[1].to(device)  # key view
            k = F.normalize(k_encoder(xk), dim=1)
            queue = k if queue is None else torch.cat((queue, k), 0)
            if queue.shape[0] >= K:
                break
        if queue is None:
            # without this guard an empty loader would loop forever
            raise RuntimeError('train_dl yielded no batches; cannot fill the queue')
queue = queue[:K]
queue내의 데이터 개수를 확인해보겠습니다.
# Report how many negative keys were stored (number of rows of the queue).
print('number of negative samples in queue : ', queue.shape[0])
학습을 위한 함수를 정의합니다.
# define function for training
def Training(q_encoder, k_encoder, num_epochs, queue=queue, loss_func=loss_func, opt=opt, data_dl=train_dl, sanity_check=False, momentum=0.999):
    """Pre-train ``q_encoder`` with MoCo-style contrastive learning.

    For each batch the query view ``img[0]`` goes through ``q_encoder`` and
    the key view ``img[1]`` through ``k_encoder``. Both outputs are
    L2-normalised. ``q_encoder`` is updated by ``opt``, and ``k_encoder`` is
    updated as an exponential moving average of ``q_encoder``. The batch's
    keys are enqueued and the oldest entries dropped so the queue never
    exceeds the module-level ``K`` rows.

    Args:
        q_encoder: query encoder, trained by back-propagation.
        k_encoder: key encoder, updated only by the momentum rule.
        num_epochs: number of passes over ``data_dl``.
        queue: (num_keys, C) tensor of negative keys; expected to be
            L2-normalised like the keys this function enqueues.
        loss_func: callable ``(q, k, queue) -> scalar loss tensor``.
        opt: optimizer over ``q_encoder``'s parameters.
        data_dl: DataLoader yielding ``([query_batch, key_batch], labels)``.
        sanity_check: if True, stop after the first epoch.
        momentum: EMA coefficient for the key-encoder update.

    Returns:
        ``(q_encoder, k_encoder, loss_history)``, where ``loss_history`` is a
        list with the mean batch loss of each epoch as a Python float.
    """
    loss_history = []
    start_time = time.time()
    path2weights = './models/q_weights.pt'  # used only by the optional save below

    q_encoder.train()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs-1))
        q_encoder.train()
        running_loss = 0.0
        num_batches = 0

        for img, _ in data_dl:
            # retrieve query and key views
            xq = img[0].to(device)
            xk = img[1].to(device)

            # query path needs gradients; key path does not
            q = q_encoder(xq)
            q = torch.div(q, torch.norm(q, dim=1).reshape(-1, 1))
            with torch.no_grad():
                k = k_encoder(xk)
                k = torch.div(k, torch.norm(k, dim=1).reshape(-1, 1))

            loss = loss_func(q, k, queue)
            opt.zero_grad()
            loss.backward()
            opt.step()

            # .item() detaches the value: accumulating the tensor itself
            # would keep every batch's autograd graph alive.
            running_loss += loss.item()
            num_batches += 1

            # enqueue the newest keys, keep only the last K rows
            # (works for any batch size, including a smaller final batch)
            queue = torch.cat((queue, k), 0)[-K:]

            # momentum (EMA) update of the key encoder
            with torch.no_grad():
                for q_params, k_params in zip(q_encoder.parameters(), k_encoder.parameters()):
                    k_params.mul_(momentum).add_(q_params, alpha=1.0 - momentum)

        # each loss is already a batch mean, so average over batches
        epoch_loss = running_loss / max(num_batches, 1)
        loss_history.append(epoch_loss)
        print('train loss: %.6f, time: %.4f min' %(epoch_loss,(time.time()-start_time)/60))

        if sanity_check:
            break

    # save weights
    # torch.save(q_encoder.state_dict(), path2weights);
    return q_encoder, k_encoder, loss_history
학습을 시작합니다. 저는 300 epoch 학습하겠습니다.
# Run contrastive pre-training for 300 epochs; the returned key encoder is discarded.
num_epochs = 300
q_encoder, _, loss_history = Training(
    q_encoder,
    k_encoder,
    num_epochs=num_epochs,
    sanity_check=False,
)
4시간 30분이 소요되었네요 ㅎㅎ
loss history를 출력하겠습니다.
# plot loss history
# Convert entries to plain floats: matplotlib cannot plot tensors that live on
# the GPU or still require grad. Using the history's own length keeps the
# x-axis valid when training stopped early (sanity_check=True).
train_curve = [float(v) for v in loss_history]
plt.title('Loss History')
plt.plot(range(1, len(train_curve) + 1), train_curve, label='train')
plt.ylabel('Loss')
plt.xlabel('Training Epochs')
plt.legend()
plt.show()
아직 수렴이 되진 않았네요
4. Transfer Learning
transfer learning을 위한 transfer dataset을 정의합니다. transfer dataset은 STL-10 train dataset의 10%로 구성되어 있습니다.
transfer dataset을 정의하기 위해 train dataset을 정의하고 여기서 10% 분할하여 사용하겠습니다.
# define transformation for transfer learning / evaluation (no augmentation)
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# load STL-10 dataset
# train split: source of the 10% labelled transfer subset
train_ds = datasets.STL10(path2data, split='train', download=True, transform=data_transform)
# Validate on the held-out test split. The train split contains the transfer
# images the linear classifier is fitted on (and the images q_encoder was
# pre-trained on), so evaluating on it would overstate accuracy.
val_ds = datasets.STL10(path2data, split='test', download=True, transform=data_transform)
class imbalance 문제가 발생할 수 있으므로 동일한 class 비율을 갖도록 데이터셋을 분할하겠습니다.
# count the number of images per class in train_ds
import collections
# Read labels from the dataset's label array instead of iterating train_ds,
# which would load and transform every image just to obtain its label.
y_train = [int(label) for label in train_ds.labels]
counter_train = collections.Counter(y_train)
print(counter_train)
sklearn 모듈의 StratifiedShuffleSplit과 PyTorch의 Subset 클래스를 사용해서 데이터셋을 분할하겠습니다.
# Split the indices of train_ds into two groups. StratifiedShuffleSplit keeps
# the label proportions equal in both; the 10% "test" part becomes the
# transfer set.
from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=0)
indices = list(range(len(train_ds)))
# n_splits=1, so the generator yields exactly one (train, test) pair
_, transfer_index = next(sss.split(indices, y_train))
print('transfer_index:', transfer_index[:10])
print(len(transfer_index))
# create the transfer dataset from train_ds
from torch.utils.data import Subset
transfer_ds = Subset(train_ds, transfer_index)

# check transfer_ds: its label distribution should mirror train_ds.
# Subset item i is train_ds[transfer_index[i]], so its label is
# y_train[transfer_index[i]]; looking it up avoids decoding all the images.
y_transfer = [y_train[i] for i in transfer_index]
counter_transfer = collections.Counter(y_transfer)
print(counter_transfer)
500개의 image를 가진 transfer dataset을 생성했습니다.
이제 dataloader를 정의합니다.
# define dataloaders (transfer_ds holds 500 images)
transfer_dl = DataLoader(transfer_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=True)
transfer learning을 위해 q_encoder의 MLP projection head 중 마지막 세 layer(fc2, ReLU, fc3)를 제거하고(fc1과 ReLU는 남겨 100차원 feature를 출력), 그 위에 Linear layer를 추가합니다.
# Trim the projection head down to fc1 + ReLU so q_encoder emits 100-d
# features. The length check makes re-running this cell a no-op.
head_layers = list(q_encoder.fc.children())
if len(head_layers) == 5:
    q_encoder.fc = nn.Sequential(*head_layers[:-3])
# define Linear Classifier for transfer learning
class LinearClassifier(nn.Module):
    """Single linear layer mapping frozen encoder features to class logits.

    Args:
        in_features: size of the incoming feature vector. Defaults to 100,
            the width of q_encoder's truncated projection head (fc1).
        num_classes: number of output classes. Defaults to 10 (STL-10).
    """

    def __init__(self, in_features=100, num_classes=10):
        super().__init__()
        # attribute name kept as ``fc1`` so existing state_dicts still load
        self.fc1 = torch.nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        return self.fc1(x)
# Instantiate the linear probe and move it to the training device.
linear_classifier = LinearClassifier()
linear_classifier = linear_classifier.to(device)
transfer dataset으로 학습을 진행합니다. q_encoder는 freeze하고 LinearClassifier만 학습합니다.
20 epoch 학습을 진행하겠습니다.
# training LinearClassifier (linear evaluation on top of the frozen q_encoder)
linear_epoch = 20
linear_loss_func = nn.CrossEntropyLoss()
linear_opt = optim.SGD(linear_classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6)

loss_hist = {'train': [], 'val': []}
start_time = time.time()

# The encoder is frozen: eval mode makes BatchNorm use its running statistics
# instead of batch statistics, and stops it from updating them.
q_encoder.eval()

# start training
for epoch in range(linear_epoch):
    print('Epoch {}/{}'.format(epoch, linear_epoch-1))
    running_train_loss = 0.0
    running_val_loss = 0.0
    running_metric = 0

    # transfer dataloader
    linear_classifier.train()
    for x, y in transfer_dl:
        x = x.to(device)
        y = y.to(device)
        # extract features with the frozen q_encoder
        with torch.no_grad():
            output_encoder = q_encoder(x)
        linear_opt.zero_grad()
        pred = linear_classifier(output_encoder)
        loss = linear_loss_func(pred, y)
        loss.backward()
        linear_opt.step()
        # loss is a batch mean; weight by batch size so dividing by the
        # dataset size below gives a per-sample average
        running_train_loss += loss.item() * x.size(0)
    train_loss = running_train_loss / len(transfer_dl.dataset)
    loss_hist['train'].append(train_loss)

    # validation dataloader
    linear_classifier.eval()
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device)
            y = y.to(device)
            pred = linear_classifier(q_encoder(x))
            loss = linear_loss_func(pred, y)
            running_val_loss += loss.item() * x.size(0)  # counted once per batch
            running_metric += pred.argmax(1).eq(y).sum().item()
    val_loss = running_val_loss / len(val_dl.dataset)
    loss_hist['val'].append(val_loss)
    val_metric = running_metric / len(val_dl.dataset)

    print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min' %(train_loss, val_loss, 100*val_metric, (time.time()-start_time)/60))
    print('-'*10)
성능이 38%밖에 나오지 않네요. 가장 큰 이유는 train dataset이 5000개밖에 되지 않고, unsupervised learning의 epoch 수가 적었기 때문입니다. unsupervised learning은 supervised learning보다 더 많은 epoch가 필요합니다.
즉, q_encoder가 고차원 data로부터 유의미한 latent representation을 추출하지 못하고 있습니다.
loss history를 출력합니다.
num_epochs = linear_epoch
# Plot train-val loss.
# float() turns any tensor entries into plain numbers matplotlib can plot;
# each curve uses its own length for the x-axis.
train_curve = [float(v) for v in loss_hist['train']]
val_curve = [float(v) for v in loss_hist['val']]
plt.title('Train-Val Loss')
plt.plot(range(1, len(train_curve) + 1), train_curve, label='train')
plt.plot(range(1, len(val_curve) + 1), val_curve, label='val')
plt.ylabel('Loss')
plt.xlabel('Training Epochs')
plt.legend()
plt.show()
결과값을 t-SNE로 시각화하겠습니다.
from sklearn.manifold import TSNE
import seaborn as sns

tsne = TSNE()

# STL-10 class names, indexed by integer label 0-9.
STL10_CLASSES = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']


def plot_vecs_n_labels(v, labels, fname):
    """Scatter-plot 2-D points coloured by class and save the figure.

    Args:
        v: array-like of shape (n, 2), e.g. a t-SNE embedding.
        labels: length-n sequence of integer class labels (0-9).
        fname: output path passed to ``plt.savefig``.
    """
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    sns.set_style('darkgrid')
    # Map labels to names and fix hue_order, so legend entries stay correct
    # even when some classes are missing from ``labels`` (overriding the
    # legend with a fixed name list would mislabel in that case).
    names = [STL10_CLASSES[int(label)] for label in labels]
    # seaborn >= 0.12 only accepts x/y as keyword arguments
    sns.scatterplot(x=v[:, 0], y=v[:, 1], hue=names, hue_order=STL10_CLASSES,
                    legend='full', palette=sns.color_palette("bright", 10))
    plt.savefig(fname)
# change batch_size: one larger batch gives t-SNE more points to embed
val_dl = DataLoader(val_ds, 1024, True)

# inference only: BatchNorm must use running statistics
q_encoder.eval()
linear_classifier.eval()

# embed the classifier logits of a single batch
x, y = next(iter(val_dl))
x = x.to(device)
with torch.no_grad():
    pred = linear_classifier(q_encoder(x))
pred_tsne = tsne.fit_transform(pred.cpu().numpy())
plot_vecs_n_labels(pred_tsne, y, 'tsen.png')
모델이 분류를 잘 못하고 있네요. 추후에 GPU 자원이 생기게 되면 더 큰 dataset과 longer epoch를 사용하여 학습해보도록 하겠습니다 ㅎㅎ 감사합니다.