
[Paper Implementation] Implementing AlexNet (2012) with PyTorch

AI 꿈나무 2021. 3. 13. 22:12

You can find my review of the AlexNet (2012) paper here.

 

[Paper Review] AlexNet (2012) Review and PyTorch Implementation

deep-learning-study.tistory.com

 

The full code is available here.

It's my GitHub repository! I'd appreciate a star, haha.

The content is still pretty thin, but I'll keep updating it to flesh it out!

 

1. Loading the Dataset and Applying Transformations

We use the STL-10 dataset provided by torchvision. STL-10 only comes with train and test splits, so we'll carve a validation dataset out of the test dataset. For transformations we apply resize, horizontal flip, and normalization.

 

First, mount Google Drive in Colab.

from google.colab import drive
drive.mount('/content/drive')  # mount Google Drive at the standard Colab mount point

 

Load the STL-10 dataset; torchvision provides it in its datasets package.

# loading training dataset
from torchvision import datasets
import torchvision.transforms as transforms
import os

# specify a data path
path2data = '/data'

# create the data directory if it does not exist
if not os.path.exists(path2data):
    os.mkdir(path2data)

# load STL10 train dataset, and check
data_transformer = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.STL10(path2data, split='train', download=True, transform=data_transformer)
print(train_ds.data.shape)

 

# load STL10 test dataset
test0_ds = datasets.STL10(path2data, split='test', download=True, transform=data_transformer)
print(test0_ds.data.shape)

 

To normalize the data, compute the per-channel mean and standard deviation of train_ds.

# calculate the mean and standard deviation of train_ds
import numpy as np

meanRGB = [np.mean(x.numpy(), axis=(1,2)) for x, _ in train_ds]
stdRGB = [np.std(x.numpy(), axis=(1,2)) for x, _ in train_ds]

meanR = np.mean([m[0] for m in meanRGB])
meanG = np.mean([m[1] for m in meanRGB])
meanB = np.mean([m[2] for m in meanRGB])

stdR = np.mean([s[0] for s in stdRGB])
stdG = np.mean([s[1] for s in stdRGB])
stdB = np.mean([s[2] for s in stdRGB])

print(meanR, meanG, meanB)
print(stdR, stdG, stdB)

 

Define the image transformations.

# define the image transformation for train_ds
# the paper uses crops, horizontal reflections, and normalization
train_transformer = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(227),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB]),
])

# define the image transformation for test0_ds
test_transformer = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB]),
                transforms.Resize(227)
])

 

Apply the transformations to the loaded datasets.

# apply transformation to train_ds and test0_ds
train_ds.transform = train_transformer
test0_ds.transform = test_transformer

 

Check a sample image with the transformation applied.

import torch
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# display the transformed sample images from train_ds

# define helper function to show images
def show(img, y=None, color=True):

    npimg = img.numpy()
    npimg_tr = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg_tr)
    
    if y is not None:
        plt.title('labels: ' + str(y))

np.random.seed(0)
torch.manual_seed(0)

# pick a random sample image 
rnd_inds = int(np.random.randint(0, len(train_ds), 1))
print('image index: ', rnd_inds)
img, label = train_ds[rnd_inds]

plt.figure(figsize=(10, 10))
show(img)

The colors look a bit bizarre, haha.

The cute dog... has disappeared.

If you skip the color normalization, the cute dog comes back, haha.
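If you just want to see the original colors again, you can also undo the normalization for display only. A minimal sketch, reusing the channel statistics computed above:

# undo the normalization for display only (sketch; reuses meanR/stdR etc. from above)
mean = torch.tensor([meanR, meanG, meanB]).view(3, 1, 1)
std = torch.tensor([stdR, stdG, stdB]).view(3, 1, 1)
img_denorm = (img * std + mean).clamp(0, 1)  # reverse (x - mean) / std and clip for imshow
show(img_denorm)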

 

Check the number of images per category in train_ds.

# count the number of images per category in train_ds
import collections
y_train = [y for _, y in train_ds]
counter_train = collections.Counter(y_train)
print(counter_train)

There are 10 labels, with 500 images per label.

 

STL-10 does not provide a validation split, so we'll split the test dataset to create one.

StratifiedShuffleSplit draws image indices from each label in equal proportions.

Then we just use Subset to split the dataset with the extracted indices!

# split the indices of test0_ds into two groups
# STL10 does not come with a validation split, so we create one
# by splitting the test0 dataset
from sklearn.model_selection import StratifiedShuffleSplit

# StratifiedShuffleSplit splits indices of test0 in same proportion of labels
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

indices = list(range(len(test0_ds)))
y_test0 = [y for _,y in test0_ds]

for test_index, val_index in sss.split(indices, y_test0):
    print('test :', len(test_index) , 'val :', len(val_index))

 

Split test0_ds into the test and val datasets with Subset.

Note that if test0_ds is updated, both the test and val datasets are updated as well, because they only reference test0_ds (see the sketch after the code below).

# create two datasets from test0_ds
from torch.utils.data import Subset

# if test0_ds is updated, val_ds and test_ds are updated
# because val_ds and test_ds are a subset of test0_ds
val_ds = Subset(test0_ds, val_index)
test_ds = Subset(test0_ds, test_index)
# count the number of images per class in val_ds and test_ds
import collections
import numpy as np

y_test = [y for _, y in test_ds]
y_val = [y for _, y in val_ds]

counter_test = collections.Counter(y_test)
counter_val = collections.Counter(y_val)
print(counter_test)
print(counter_val)
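Because Subset only stores indices into test0_ds, both subsets share whatever transform test0_ds currently has. A quick sketch to confirm that behavior (nothing is copied):

# val_ds and test_ds reference test0_ds rather than copying it
print(val_ds.dataset is test0_ds)      # True
test0_ds.transform = test_transformer  # any change here is seen by both subsets
img, _ = val_ds[0]
print(img.shape)                       # e.g. torch.Size([3, 227, 227]) after the transform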

 

Create the DataLoaders.

# create dataloaders from train_ds and val_ds
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)

# check dataloader
for x,y in train_dl:
    print(x.shape)
    print(y.shape)
    break

for x,y in val_dl:
    print(x.shape)
    print(y.shape)
    break

 

2. Building the Model

First, define the device.

# define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
import torch.nn as nn
import torch.nn.functional as F

class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet,self).__init__()
        # input size : (b x 3 x 227 x 227)
        # the paper states an input size of 224x224, but that is a typo;
        # 227x227 is used here.
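        # sanity check on the arithmetic: with a 227x227 input, (227 - 11) / 4 + 1 = 55,
        # matching the 55x55 feature map below; with 224, (224 - 11) / 4 + 1 is not an integer.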

        # Conv layer
        self.net = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0), # (b x 96 x 55 x 55)
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2), # (b x 96 x 27 x 27)

            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2), # (b x 256 x 27 x 27)
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2), # (b x 256 x 13 x 13)

            nn.Conv2d(256, 384, 3, 1, 1), # (b x 384 x 13 x 13)
            nn.ReLU(),

            nn.Conv2d(384, 384, 3, 1, 1), # (b x 384 x 13 x 13)
            nn.ReLU(),

            nn.Conv2d(384, 256, 3, 1, 1), # (b x 256 x 13 x 13)
            nn.ReLU(),
            nn.MaxPool2d(3, 2), # (b x 256 x 6 x 6)
        )

        # fc layer
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=(256 * 6 * 6), out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

        # weight initialization
        self.init_weight()

    # define weight initialization function
    def init_weight(self):
        for layer in self.net:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                nn.init.constant_(layer.bias, 0)
        # the paper initializes the bias to 1 for conv layers 2, 4, and 5
        nn.init.constant_(self.net[4].bias, 1)
        nn.init.constant_(self.net[10].bias, 1)
        nn.init.constant_(self.net[12].bias, 1)
    
    def forward(self,x):
        x = self.net(x)
        x = x.view(-1, 256 * 6* 6)
        x = self.classifier(x)
        return x
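As a quick check on the layer arithmetic, a dummy batch can be pushed through the convolutional part to confirm the (b x 256 x 6 x 6) output noted in the comments. A small sketch using a throwaway CPU instance:

# verify the conv stack output shape with a dummy batch (sketch)
with torch.no_grad():
    dummy = torch.randn(2, 3, 227, 227)
    print(AlexNet().net(dummy).shape)  # expected: torch.Size([2, 256, 6, 6])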

 

Create the model and check it.

# check the model
model = AlexNet().to(device)
print(model)

 

Print the model summary.

# get the model summary
from torchsummary import summary
summary(model, input_size=(3, 227, 227), device=device.type)

 

The model was built correctly, haha.

 

Let's check that the weight initialization was applied properly.

# check weight initialization
for p in model.parameters():
    print(p)
    break
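To also confirm the bias initialization for conv layers 2, 4, and 5, you can inspect those layers directly. A small sketch (net[0] is conv1, net[4] is conv2 in the Sequential above):

# conv1 biases should be 0, conv2 biases should be 1
print(model.net[0].bias[:5])
print(model.net[4].bias[:5])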

 

3. Training the Model

# define the loss function
loss_func = nn.CrossEntropyLoss(reduction='sum')

# define the optimizer
from torch import optim
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
# opt = optim.Adam(model.parameters(), lr=0.01)

# read the current value of the learning rate using the following function
def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']
        
# define lr_scheduler : decay the lr by 1/10 every 30 epochs
from torch.optim.lr_scheduler import StepLR
lr_scheduler = StepLR(opt, step_size=30, gamma=0.1)

# define a function to count the number of correct predictions per mini-batch
def metrics_batch(output, target):
    # get output class
    pred = output.argmax(dim=1, keepdim=True)
    # compare output class with target class
    corrects = pred.eq(target.view_as(pred)).sum().item()
    return corrects

# define a function to compute the loss value per mini-batch
def loss_batch(loss_func, output, target, opt=None):
    loss = loss_func(output, target)

    metric_b = metrics_batch(output, target)
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item(), metric_b

# develop a function to compute the loss value and the performance metric for the epoch
def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    running_loss = 0
    running_metric = 0
    len_data = len(dataset_dl.dataset)

    for xb, yb in dataset_dl:
        # move batch to device
        xb = xb.to(device)
        yb = yb.to(device)
        # get model output
        output = model(xb)

        # get loss per batch
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)

        # update running loss
        running_loss += loss_b
        # update running metric
        if metric_b is not None:
            running_metric += metric_b
        
        # break the loop in case of sanity check
        if sanity_check is True:
            break

    # average loss value and metric value
    loss = running_loss / float(len_data)
    metric = running_metric / float(len_data)
    return loss, metric

import time
import copy

# develop train_val function
def train_val(model, params):
    # extract model parameters
    num_epochs = params['num_epochs']
    loss_func = params['loss_func']
    opt = params['optimizer']
    train_dl = params['train_dl']
    val_dl = params['val_dl']
    sanity_check = params['sanity_check']
    lr_scheduler = params['lr_scheduler']
    path2weights = params['path2weights']

    # keep a history of the loss and the metric value
    loss_history = {
        'train': [],
        'val': [],
    }

    metric_history = {
        'train': [],
        'val': [],
    }

    # save the best performing model
    best_model_wts = copy.deepcopy(model.state_dict())

    # initialize the best loss to an infinite value
    best_loss = float('inf')

    for epoch in range(num_epochs):
        start_time = time.time()

        # get current learning rate
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs-1, current_lr))

        # train model on the training dataset
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)

        # collect loss and metric for the training dataset
        loss_history['train'].append(train_loss)
        metric_history['train'].append(train_metric)

        # evaluate model on validation dataset
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        
        # store the best model
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            # store weights into a local file
            torch.save(model.state_dict(), path2weights)
            print('Copied best model weights')

        # collect loss and metric for validation dataset
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        # update the learning rate
        lr_scheduler.step()
        if current_lr != get_lr(opt):
            print('Loading best model weights!')
            model.load_state_dict(best_model_wts)

        # print the loss and accuracy values and return the trained model
        print('train loss: %.6f, dev loss: %.6f, accuracy: %.2f, time: %.4f s' %(train_loss, val_loss, 100*val_metric, time.time()-start_time))
        print('-'*10)

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history

# define the training parameters
params_train = {
    'num_epochs':3,
    'optimizer':opt,
    'loss_func':loss_func,
    'train_dl':train_dl,
    'val_dl':val_dl,
    'sanity_check':True,
    'lr_scheduler':lr_scheduler,
    'path2weights':'./models/weights.pt',
}

# check the directory to save weights.pt
def createFolder(directory):
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
    except OSError:
        print('Error')
createFolder('./models')

# train model
# with sanity_check set to True, first verify that training runs at all
model, loss_hist, metric_hist = train_val(model, params_train)

The training functions got pretty long, haha.

With sanity_check = True, each epoch runs only one mini-batch, and I trained for 3 epochs in total.

Training proceeds, so the functions above work fine!

 

With sanity_check = False I trained up to 80 epochs, and then Colab disconnected!!!

From what I remember of the results at that point, the model had not converged yet.

I suspect that's because the dataset is quite small and AlexNet is a shallow network, haha.
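If Colab disconnects again, the best weights saved to path2weights can at least be reloaded so training does not have to restart from scratch. A sketch, assuming './models/weights.pt' still exists (for example because it was written to a mounted Drive folder):

# resume from the last saved best weights (sketch; assumes the file survived the disconnect)
model.load_state_dict(torch.load('./models/weights.pt', map_location=device))
model, loss_hist, metric_hist = train_val(model, params_train)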

 

 

The loss looks low because the run was set to compute only one mini-batch per epoch; since the accumulated loss is divided by the size of the whole dataset, the number is bound to come out small. The accuracy is 0.31%, so the model is barely predicting anything... I'm guessing it would converge after something like 400 epochs of training!
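A rough sanity check on those numbers: with batch_size = 32, the sanity-check run sums the loss over just 32 samples but divides by all 5,000 training images, so the reported loss is only about 32/5000 ≈ 0.6% of a true per-sample average. The accuracy is diluted the same way: 0.31% of the 1,600 validation images is roughly 5 correct predictions out of the 32 actually evaluated, i.e. about 15%, only a little above the 10% chance level for 10 classes.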

 

The code below plots the saved train_loss and val_loss values so you can check the training progress visually!

 

# Train-Validation Progress
num_epochs=params_train["num_epochs"]

# plot loss progress
plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

# plot accuracy progress
plt.title("Train-Val Accuracy")
plt.plot(range(1,num_epochs+1),metric_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),metric_hist["val"],label="val")
plt.ylabel("Accuracy")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

 

Next up is a VGGNet paper implementation post. Thanks for reading!
