Python/PyTorch 공부

[PyTorch] nn.Sequential 을 상속받아 Class 정의하기

AI 꿈나무 2021. 8. 4. 15:21
반응형

 안녕하세요, 이번 포스팅에서는 class를 정의하는 경우에 nn.Module이 아닌 nn.Sequential을 상속하여 사용하는 것에 대해 알아보겠습니다.

 

 nn.Module 대신에 nn.Sequential을 subclass하면 어떤 이점이 있을까요??

 

 바로 forward method를 작성하지 않아도 됩니다 ㅎㅎ

 

 예시 코드를 살펴보겠습니다.

 

# Subclassing nn.Sequential to avoid writing the forward method.
class FeedFowardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion=4, drop_p=0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

 

 위 코드는 FC-GELU-FC로 이루어진 layer를 구현한 코드입니다. 

 nn.Sequential을 상속하여 구현했기 때문에 forward method를 작성하지 않아도 되는데요, 잘 작동하는지 확인해보겠습니다.

 

# check
x = torch.randn(16,1,128).to(device)
model = FeedFowardBlock(128).to(device)
output = model(x)
print(output.shape)

 

 

잘 작동하네요 ㅎㅎ

반응형