Python/PyTorch 공부

[PyTorch] PyTorch에서 제공하는 ResNet을 불러와 마지막 FC layer 수정하기

AI 꿈나무 2021. 6. 12. 22:59
반응형

 안녕하세요! 이번 포스팅에서는 PyTorch에서 제공하는 ResNet을 불러오고, 마지막 FC layer를 수정하는 방법을 살펴보겠습니다.

 

import torch.nn as nn
from torch.nn import functional as F
from .utils_resnet import resnet18

class FaceNet_ResNet18(nn.Module):
    def __init__(self, embedding_dimension=128, pretrained=False):
        super().__init__()
        self.model = resnet18(pretrained=pretrained)
        
        # embedding
        input_features_fc_layer = self.model.fc.in_features # fc layer 채널 수 얻기
        self.model.fc = nn.Linear(input_features_fc_layer, embedding_dimension, bias=False) # fc layer 수정
        
    def forward(self, images):
        embedding = self.model(images) # embedding 생성
        embedding = F.normalize(embedding, p=2, dim=1) # normalize
        return embedding

 

 resnet18 모델을 self.model에 저장합니다.,

 self.model.fc.in_features로 fc layer 입력 채널 수를 얻습니다.

 self.model.fc = nn.Linear ~ 로 수정하여 model의 fc layer를 원하는 layer로 변경할 수 있습니다.

 

 감사합니다.

반응형