논문 읽기/Segmentation

[논문 읽기] 구현 코드로 살펴보는 SegNet(2016), A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

AI 꿈나무 2021. 6. 10. 14:44
반응형

 안녕하세요, 오늘 읽은 논문은 SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 입니다.

 

 SegNet은 semantic pixel-wise segmentation을 위한 fully convolutional neural network architecture 입니다. encoder network와 decoder network로 구성되며 이후에 pixel-wise classification layer가 따라옵니다. encoder은 VGG16 network와 동일하게 13 conv layer로 구성되며, decoder network의 역할은 저해상도의 pixel-wise classification을 위해encoer feature map을 입력 해상도로 map 합니다. 이 과정은 encoder의 max pool 연산을 decoder에서 non-linear upsampling으로 sparse한 피쳐맵을 생성하고 해당 피쳐맵에 conv 연산을 통해 dense한 feature map을 생성합니다.

 

SegNet

 

 SegNet은 encoder network와 encoder network에 해당하는 decoder network를 갖고있으며 마지막으로 pixelwise classification layer가 존재합니다. encoder network는 VGG16 에서 마지막 3개의 fc layer를 제거한 13개의 conv layer로 구성됩니다. 각 encoder layer에 해당하는 decoder layer를 갖고 있으므로 decoder network도 13개의 conv로 구성됩니다. 마지막으로 각 픽셀을 독립적으로 분류하는 multi-class soft-max 레이어가 존재합니다.

 

 

 PyTorch Code로 SegNet을 살펴보겠습니다. 제가 직접 구현해보았습니다. 또한 13개의 conv layer가 아닌 5개의 conv layer로만 구성했습니다. 논문에서는 encoder에서 maxpooling을 적용하기 전에 max value의 위치를 저장하여 decoder에서 up sampling을 할때 사용합니다. 하지만 편의상 pytorch의 upsample 연산으로 구현했습니다.

 

class SegNet(nn.Module):
    def __init__(self, params):
        super().__init__()
        C_in, H_in, W_in = 1, 128, 192
        init_f = 16
        num_output = 1
        
        # encoder
        self.conv1 = nn.Conv2d(C_in, init_f, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(init_f, 2*init_f, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(2*init_f, 4*init_f, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(4*init_f, 8*init_f, 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(8*init_f, 16*init_f, 3, stride=1, padding=1)
        
        # decoder
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.conv_up1 = nn.Conv2d(16*init_f, 8*init_f, 3, stride=1, padding=1)
        self.conv_up2 = nn.Conv2d(8*init_f, 4*init_f, 3, stride=1, padding=1)
        self.conv_up3 = nn.Conv2d(4*init_f, 2*init_f, 3, stride=1, padding=1)
        self.conv_up4 = nn.Conv2d(2*init_f, 1*init_f, 3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(init_f, num_output, 3, stride=1, padding=1)
        
    
    def forward(self, x):
      # encoder
      x = F.relu(self.conv1(x))
      x = F.max_pool2d(x,2,2)

      x = F.relu(self.conv2(x))
      x = F.max_pool2d(x,2,2)

      x = F.relu(self.conv3(x))
      x = F.max_pool2d(x,2,2)

      x = F.relu(self.conv4(x))
      x = F.max_pool2d(x,2,2) 

      x = F.relu(self.conv5(x))

      # decoder
      x = self.upsample(x)
      x = F.relu(self.conv_up1(x))

      x = self.upsample(x)
      x = F.relu(self.conv_up2(x)) 

      x = self.upsample(x)
      x = F.relu(self.conv_up3(x))

      x = self.upsample(x)
      x = F.relu(self.conv_up4(x))

      x = self.conv_out(x)
      return x

 

 손실함수는 다음과 같이 구현합니다.

def dice_loss(pred, target, smooth = 1e-5):
    intersection = (pred * target).sum(dim=(2,3))
    union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
    dice = 2.0 * (intersection + smooth) / (union + smooth)
    loss = 1.0 - dice
    return loss.sum(), dice.sum()
    

def loss_func(pred, target):
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='sum')
    div, _ =dice_loss(pred, target)
    loss = bce + div
    return loss

 

 추후에 SegNet을 PyTorch로 구현하고 학습까지 해보는 포스팅을 업로드 하도록 하겠습니다.

 

Performance

 

 


참고자료

[1] https://arxiv.org/pdf/1511.00561.pdf

반응형