Python/PyTorch 공부

[PyTorch] Swish 활성화 함수 정의해서 사용하기

AI 꿈나무 2021. 3. 27. 23:40
반응형

 안녕하세요! PyTorch로 Swish 함수를 정의해서 사용하는 법을 알아보겠습니다ㅎㅎ

 

 Swish 함수는 깊은 신경망에서 ReLU보다 좋은 성능을 나타내는데요, 실제로 EfficientNet은 Swish 활성화 함수를 사용하고 MobileNetV3은 Swish 함수를 수정해서 h-Swish 함수를 사용하고 있습니다. 이 Swish 함수는 파이토치 공식 문서에서 명령어를 제공하고 있지 않아 직접 정의해서 사용해야 합니다.

 

그림 출처 : https://eehoeskrap.tistory.com/440

 

 아래와 같이 Swish 함수 클래스를 정의할 수 있습니다.

# Swish activation function
class Swish(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return x * self.sigmoid(x)

 

반응형