반응형
안녕하세요! 가중치 초기화 함수를 정의하고 모델에 적용해보도록 하겠습니다.
자꾸 까먹어서 작성합니다ㅎㅎ!
가중치 초기화 함수를 정의하는 방법은 (1) 모델 구현 코드 내에 가중치 초기화 함수 정의하여 사용하기, (2) 모델을 생성한 뒤에 가중치 초기화 함수 정의하여 사용하기. 두 가지 방법이 있습니다.
개인적으로 (2) 번 방법이 편하여 2번 방법을 작성하겠습니다.
우선 구현한 모델을 생성해야 합니다. 저는 현재 gan을 공부하는 중이므로 discriminator, generator 두 개를 구현했습니다 ㅎㅎ
model_dis = Discriminator().to(device)
model_gen = Generator(params).to(device)
가중치 초기화 함수를 정의합니다. 값은 제가 임의로 설정한 값입니다.
# 가중치 초기화
def initialize_weights(model):
classname = model.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(model.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(model.weight.data, 1.0, 0.02)
nn.init.constant_(model.bias.data,0)
elif classname.find('Linear') != -1:
nn.init.constant_(model.bias.data,1)
정의한 함수를 모델에 적용합니다.
# 가중치 초기화 적용
model_gen.apply(initialize_weights);
model_dis.apply(initialize_weights);
끝. 감사합니다
반응형
'Python > PyTorch 공부' 카테고리의 다른 글
Google Colab에 파일 업로드하기 (0) | 2021.06.10 |
---|---|
[PyTorch] Albumentations 모듈 사용해서 이미지 transformation 적용하기. (0) | 2021.06.10 |
[PyTorch] PyTorch에서 제공하는 VOC dataset 불러와서 사용하기 (0) | 2021.05.05 |
[PyTorch] torch.bernoulli 를 활용한 Stochastic depth 학습 (0) | 2021.03.29 |
[PyTorch] Swish 활성화 함수 정의해서 사용하기 (0) | 2021.03.27 |