Python/PyTorch 공부

[PyTorch] 가중치 초기화 함수 정의하고 모델에 적용하기

AI 꿈나무 2021. 5. 29. 23:50
반응형

 안녕하세요! 가중치 초기화 함수를 정의하고 모델에 적용해보도록 하겠습니다.

 

 자꾸 까먹어서 작성합니다ㅎㅎ!

 

 가중치 초기화 함수를 정의하는 방법은 (1) 모델 구현 코드 내에 가중치 초기화 함수 정의하여 사용하기, (2) 모델을 생성한 뒤에 가중치 초기화 함수 정의하여 사용하기. 두 가지 방법이 있습니다.

 

 개인적으로 (2) 번 방법이 편하여 2번 방법을 작성하겠습니다.

 

 우선 구현한 모델을 생성해야 합니다. 저는 현재 gan을 공부하는 중이므로 discriminator, generator 두 개를 구현했습니다 ㅎㅎ

model_dis = Discriminator().to(device)
model_gen = Generator(params).to(device)

 

 가중치 초기화 함수를 정의합니다. 값은 제가 임의로 설정한 값입니다.

# 가중치 초기화
def initialize_weights(model):
    classname = model.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data,0)
    elif classname.find('Linear') != -1:
        nn.init.constant_(model.bias.data,1)

 

 정의한 함수를 모델에 적용합니다.

# 가중치 초기화 적용
model_gen.apply(initialize_weights);
model_dis.apply(initialize_weights);

 

 끝. 감사합니다

반응형