반응형
안녕하세요, 이번 포스팅에서는 class를 정의하는 경우에 nn.Module이 아닌 nn.Sequential을 상속하여 사용하는 것에 대해 알아보겠습니다.
nn.Module 대신에 nn.Sequential을 subclass하면 어떤 이점이 있을까요??
바로 forward method를 작성하지 않아도 됩니다 ㅎㅎ
예시 코드를 살펴보겠습니다.
# Subclassing nn.Sequential to avoid writing the forward method.
class FeedFowardBlock(nn.Sequential):
def __init__(self, emb_size, expansion=4, drop_p=0.):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size),
)
위 코드는 FC-GELU-FC로 이루어진 layer를 구현한 코드입니다.
nn.Sequential을 상속하여 구현했기 때문에 forward method를 작성하지 않아도 되는데요, 잘 작동하는지 확인해보겠습니다.
# check
x = torch.randn(16,1,128).to(device)
model = FeedFowardBlock(128).to(device)
output = model(x)
print(output.shape)
잘 작동하네요 ㅎㅎ
반응형
'Python > PyTorch 공부' 카테고리의 다른 글
[PyTorch] Boolean value of Tensor with more than one value is ambiguous (0) | 2021.10.31 |
---|---|
[에러 해결] CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasCreate(handle) (3) | 2021.09.14 |
이미지 분류 신경망의 결과를 t-SNE 시각화하기 (0) | 2021.07.11 |
[PyTorch] ShuffleSplit와 subset 함수를 사용하여 dataset 분할하기 (0) | 2021.07.10 |
[PyTorch] VOC Segmentation 데이터셋 사용하기 (0) | 2021.06.25 |