Python/PyTorch 공부

[PyTorch] Dice coefficient 을 PyTorch로 구현하기

AI 꿈나무 2021. 6. 25. 18:53
반응형

 안녕하세요, 이번 포스팅에서는 image segmentation 분야에서 자주 사용되는 metric인 Dice coefficient를 PyTorch로 구현해보겠습니다. 또한 이 dice coefficient를 loss로 활용하는 법도 살펴봅니다.

 

Dice coefficient

 dice coefficient는 주로 medical image analysis에서 사용됩니다. 그리고 예측값과 gt 사이의 overlap area에 2를 곱하고 예측값과 gt 영역을 합한 것으로 나눠줍니다. 이는 IoU와 매우 유사합니다.

 

 

 Dice를 boolean data(binary segmentation map)에 적용할 때, Dice coefficient는 F1 score와 동일합니다.

 

 

PyTorch 코드

아래 코드는 pred와 target을 입력받아, dice_loss와 dice_metric을 동시에 출력합니다. 

def dice_loss(pred, target, smooth = 1e-5):
    # binary cross entropy loss
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='sum')
    
    pred = torch.sigmoid(pred)
    intersection = (pred * target).sum(dim=(2,3))
    union = pred.sum(dim=(2,3)) + target.sum(dim=(2,3))
    
    # dice coefficient
    dice = 2.0 * (intersection + smooth) / (union + smooth)
    
    # dice loss
    dice_loss = 1.0 - dice
    
    # total loss
    loss = bce + dice_loss
    
    return loss.sum(), dice.sum()

 

 추가로 segmentaion에서 deep learning 모델을 한눈에 보야주는 그림도 함께 첨부합니다 ㅎㅎ

출처:https://arxiv.org/pdf/1906.11172.pdf


참고자료

[1] https://go-hard.tistory.com/117

[2] https://arxiv.org/pdf/1906.11172.pdf

반응형