반응형
안녕하세요! torch.bernoulli 함수를 활용해서 Stochastic depth 학습 하는 법을 알아보겠습니다.
이 포스팅은 stochastic depth 학습을 구현하는 법을 잊을까봐 기록합니다.
아래 class는 efficientnet에서 사용하는 bottlenet 입니다.
class BottleNeck(nn.Module):
expand = 6
def __init__(self, in_channels, out_channels, kernel_size, stride=1, se_ratio=4, p=0.5):
super().__init__()
self.p = torch.tensor(p).float() if stride == 1 else torch.tensor(1).float()
self.residual = nn.Sequential(
nn.Conv2d(in_channels, in_channels * BottleNeck.expand, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(in_channels * BottleNeck.expand, momentum=0.99, eps=1e-3),
Swish(),
nn.Conv2d(in_channels * BottleNeck.expand, in_channels * BottleNeck.expand, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=False, groups=in_channels*BottleNeck.expand),
nn.BatchNorm2d(in_channels * BottleNeck.expand, momentum=0.99, eps=1e-3),
Swish()
)
self.se = SEBlock(in_channels * BottleNeck.expand)
self.project = nn.Sequential(
nn.Conv2d(in_channels*BottleNeck.expand, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels, momentum=0.99, eps=1e-3)
)
self.shortcut = (stride == 1) and (in_channels == out_channels)
def forward(self, x):
if self.training:
if not torch.bernoulli(self.p):
return x
x_shortcut = x
x_residual = self.residual(x)
x_se = self.se(x_residual)
x = x_se * x_residual
x = self.project(x)
if self.shortcut:
x= x_shortcut + x
return x
여기에 확률이 0이면, torch.bernoulli 함수가 0을 반환하여 bottleneck은 죽고 입력값을 그대로 출력합니다. 만약 베르누이 함수 결과가 1이면 정상적으로 bottlenet이 작동합니다. 확인을 해보겠습니다.
확률 = 0 일때, torch.bernoulli 결과값 = 0 => 입력 값 그대로 출력
x = torch.randn(3, 16, 24, 24)
model = BottleNeck(x.size(1), x.size(1), 3, stride=1, p=0)
output = model(x)
x = (output == x)
print(output.size(), x[1])
확률 >0 일때, torch.bernoulli 결과값 = 1 => 정상적으로 bottleneck 작동
x = torch.randn(3, 16, 24, 24)
model = BottleNeck(x.size(1), x.size(1), 3, stride=1, p=1)
output = model(x)
x = (output == x)
print(output.size(), x[1])
반응형
'Python > PyTorch 공부' 카테고리의 다른 글
[PyTorch] 가중치 초기화 함수 정의하고 모델에 적용하기 (0) | 2021.05.29 |
---|---|
[PyTorch] PyTorch에서 제공하는 VOC dataset 불러와서 사용하기 (0) | 2021.05.05 |
[PyTorch] Swish 활성화 함수 정의해서 사용하기 (0) | 2021.03.27 |
[PyTorch] YOLOv3 학습을 위한 VOC2007 커스텀 데이터셋 생성하기 (2) | 2021.03.15 |
[PyTorch] 러닝 레이트 스케쥴러(Learning Rate Scheduler) ReducedLROnPlateau 함수 (2) | 2021.03.06 |