반응형
안녕하세요! 이번 포스팅에서는 PyTorch에서 제공하는 ResNet을 불러오고, 마지막 FC layer를 수정하는 방법을 살펴보겠습니다.
import torch.nn as nn
from torch.nn import functional as F
from .utils_resnet import resnet18
class FaceNet_ResNet18(nn.Module):
def __init__(self, embedding_dimension=128, pretrained=False):
super().__init__()
self.model = resnet18(pretrained=pretrained)
# embedding
input_features_fc_layer = self.model.fc.in_features # fc layer 채널 수 얻기
self.model.fc = nn.Linear(input_features_fc_layer, embedding_dimension, bias=False) # fc layer 수정
def forward(self, images):
embedding = self.model(images) # embedding 생성
embedding = F.normalize(embedding, p=2, dim=1) # normalize
return embedding
resnet18 모델을 self.model에 저장합니다.,
self.model.fc.in_features로 fc layer 입력 채널 수를 얻습니다.
self.model.fc = nn.Linear ~ 로 수정하여 model의 fc layer를 원하는 layer로 변경할 수 있습니다.
감사합니다.
반응형
'Python > PyTorch 공부' 카테고리의 다른 글
[PyTorch] Dice coefficient 을 PyTorch로 구현하기 (3) | 2021.06.25 |
---|---|
[PyTorch] to_pil_image 명령어로 tensor를 pil image로 변경하기 (1) | 2021.06.15 |
[PyTorch] 모델 중간 레이어에서 특징 추출하기(get the intermediate features from the model) (3) | 2021.06.11 |
[PyTorch] pretrained VGG 불러오고, 파라미터 freeze 하기 (0) | 2021.06.11 |
[PyTorch] ShuffleSplit와 subset 함수를 사용하여 dataset 분할하기 (2) | 2021.06.11 |