Python/PyTorch 공부

[PyTorch] torch.bernoulli 를 활용한 Stochastic depth 학습

AI 꿈나무 2021. 3. 29. 04:14
반응형

 안녕하세요! torch.bernoulli 함수를 활용해서 Stochastic depth 학습 하는 법을 알아보겠습니다.

 

 이 포스팅은 stochastic depth 학습을 구현하는 법을 잊을까봐 기록합니다.

 

 아래 class는 efficientnet에서 사용하는 bottlenet 입니다.

class BottleNeck(nn.Module):
    expand = 6
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, se_ratio=4, p=0.5):
        super().__init__()
        self.p = torch.tensor(p).float() if stride == 1 else torch.tensor(1).float()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * BottleNeck.expand, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(in_channels * BottleNeck.expand, momentum=0.99, eps=1e-3),
            Swish(),
            nn.Conv2d(in_channels * BottleNeck.expand, in_channels * BottleNeck.expand, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=False, groups=in_channels*BottleNeck.expand),
            nn.BatchNorm2d(in_channels * BottleNeck.expand, momentum=0.99, eps=1e-3),
            Swish()
        )

        self.se = SEBlock(in_channels * BottleNeck.expand)

        self.project = nn.Sequential(
            nn.Conv2d(in_channels*BottleNeck.expand, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels, momentum=0.99, eps=1e-3)
        )

        self.shortcut = (stride == 1) and (in_channels == out_channels)

    def forward(self, x):
        if self.training:
            if not torch.bernoulli(self.p):
                return x

        x_shortcut = x
        x_residual = self.residual(x)
        x_se = self.se(x_residual)

        x = x_se * x_residual
        x = self.project(x)

        if self.shortcut:
            x= x_shortcut + x

        return x

 

 여기에 확률이 0이면, torch.bernoulli 함수가 0을 반환하여 bottleneck은 죽고 입력값을 그대로 출력합니다. 만약 베르누이 함수 결과가 1이면 정상적으로 bottlenet이 작동합니다. 확인을 해보겠습니다.

 

 확률 = 0 일때, torch.bernoulli 결과값 = 0 => 입력 값 그대로 출력

x = torch.randn(3, 16, 24, 24)
model = BottleNeck(x.size(1), x.size(1), 3, stride=1, p=0)
output = model(x)
x = (output == x)
print(output.size(), x[1])

 

 확률 >0 일때, torch.bernoulli 결과값 = 1 => 정상적으로 bottleneck 작동

x = torch.randn(3, 16, 24, 24)
model = BottleNeck(x.size(1), x.size(1), 3, stride=1, p=1)
output = model(x)
x = (output == x)
print(output.size(), x[1])

반응형