Python/PyTorch 공부

[PyTorch] pretrained VGG 불러오고, 파라미터 freeze 하기

AI 꿈나무 2021. 6. 11. 16:50
반응형

 안녕하세요! 이번 포스팅에서는 pretrained VGG net을 불러오고, 모델의 파라미터 freeze를 하는 방법을 알아보겠습니다 ㅎㅎ!!

 

우선 pretrained VGG19를 불러옵니다.

# pretrained VGG19를 불러옵니다
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_vgg = models.vgg19(pretrained=True).features.to(device).eval()

 

파라미터를 freeze 합니다.

# 파라미터를 freeze 합니다.
for param in model_vgg.parameters():
    param.requires_grad_(False)

 

 이제 pre-trained VGG19을 원하시는 작업에 사용하시면 됩니다!

반응형