논문 읽기/long tail

[논문 읽기] Class-Balanced Loss(2019), Class-Balanced Loss Based on Effective Number of Samples

AI 꿈나무 2021. 6. 10. 01:18
반응형

 안녕하세요, 오늘 읽은 논문은 Class-Balanced Loss Based on Effective Number of Sample 입니다.

 

 Class-Balanced Loss는 long-tailed data set에서 class imbalance 문제를 해결하기 위해 제안되었습니다. long-tailed data set은 몇 개의 클래스가 데이터의 대부분을 차지하고, 나머지 클래스가 차지하는 양은 적은 데이터셋을 의미합니다. 이 경우에 기존의 해결 방법은 re-sampling 및 re-weighting 같은 re-balancing 방법을 사용하는 것이었습니다. 해당 논문에서는 samples수가 증가함에 따라 새롭게 추가되는 정보의 양은 줄어든 다는 점에 집중하여 문제를 해결하려 하며, 이를 위해 각 데이터가 점이 아닌 영역으로 연관지어 데이터가 overlab 되는 정도를 측정합니다. 

 

 데이터가 적은 class에 해당하는 데이터가 추가되면 새롭게 추가되는 정보는 많을 것입니다. 각 데이터가 영역을 갖고 있다고 가정할 시에 다른 데이터와 overlab 될 확률이 적기 때문입니다. 반대로 데이터가 많은 class에 해당하는 데이터가 추가되면 해당 데이터는 다른 데이터와 overlab 될 확률이 많습니다. 이는 새로운 데이터가 새로운 정보를 거의 포함하고 있지 않다는 것을 의미합니다.

 

long tailed dataset 문제를 해결하는 기존 방법

 

 long tailed dataset은 위그림 처럼 적은 수의 클래스가 데이터의 대부분을 차지하고, 나머지 클래스가 데이터의 소수를 차지하고 있는 경우를 의미합니다. 이경우에 일반적으로 re-sampling과 cost-sensitive re-weighting 방법을 사용할 수 있습니다.

 

(1) re-sampling 방법

 re-sampling 방법은 over-sampling 방법과 under-sampling 방법이 있습니다. over-sampling 방법은 minor한 class로부터 샘플을 반복하여 추가하는 것입니다. 이는 모델의 과적합을 유발합니다. 또는 minor한 클래스에 대한 샘플들을 합성하거나 data augmentation으로 채워넣을 수 있습니다. 하지만 이러한 샘플들은 노이즈를 갖고 있어, 모델이 오류를 발생할 수 있습니다. 또한 새로운 데이터가 추가되어 학습 속도를 지연시킵니다.

 

 under-sampling 방법은 major 클래스에 대한 샘플을 적게 추출하는 것입니다. 이는 중요한 정보를 담고 있는 샘플을 버리는 위험이 발생할 수 있습니다. 하지만 over-sampling보다 under-sampling이 더 선호됩니다.

 

(2) Cost-Sensitive re-weighting

 cost-sensitive re-weighiting 방법은 주어진 데이터 분포를 맞추기 위해서 샘플에 가중치를 부가하는 것입니다. 주로 inverse class frequency 또는 smoothed version of inverse square root of class frequency 방법을 사용합니다. 또다른 방법은 RetinaNet이 사용하는 focal loss 방법입니다. focal loss는 hard example에 대하여 높은 가중치를 부여하고, easy sample에 대하여 낮은 가중치를 부여합니다. 또한 minor class에 높은 가중치, major class에 낮은 가중치를 부여할 수도 있습니다. 하지만 이런 방법은 샘플의 수와 샘플의 난이도 사이에 직접적인 연관이 존재하지 않습니다.

 

Effective Number of Samples

 해당 논문에서는 random covering의 단소화된 버전으로서 data sampling 과정을 공식화 합니다. 핵심 아이디어는 각 샘플이 한 점이 아니라 작은 근접 영역으로 연관짓는 것입니다.

 

 

 모든 데이터 수가 N이고 공간 S에 속해있다고 하겠습니다. 하나의 샘플의 영역 크기는 1입니다. 공간 S를 채우기 위해서는 무수히 많은 샘플을 뽑아야 하며, 각 샘플은 영역을 갖고 있으므로 샘플이 overlab 될 수 있습니다. overlab 되는 경우에 유의미한 정보를 갖고 있지 않다고 간주합니다. 즉 overlab되지 않으면 해당 데이터는 새로운 정보를 갖고 있는 것입니다.

 

 overlab을 어떻게 판단할까요? 논문에서는 확률을 사용합니다. p확률로 overlab되고, (1-p)확률로 overlab 되지 않습니다. p는 다음과 같이 정의합니다.

 

 

 E_(n-1)은 effective number이며 데이터의 평균 부피을 의미합니다. 즉, E_(n-1)은 이전 데이터가 공간 S에서 차지하고 있는 부피를 의미하며, 이 값이 커질수록 overlab할 확률이 높아집니다.

 

 effective number는 다음과 같이 정의됩니다.

 

 

 베타는 [0,1) 사이에 속하는 하이퍼 파라미터 이며, n이 증가함에 따른 effective number의 증가 속도를 조절합니다.

 

Class-Balanced Loss

 Class-Balanced Loss는 softmax cross-entropy, sigmoid cross-entropy, focal loss에 effective number의 역수를 곱한 것입니다. n이 증가함에 따라 effective number 값은 증가합니다. n이 적은 클래스에 높은 가중치를 가하고, n이 큰 클래스에 낮은 가중치를 가하기 위해 역수를 곱합니다.

 

 

(1) Cross-Balanced Softmax Cross-Entropy Loss

 

(2) Class-Balanced Sigmoid Cross-Entropy Loss

 

(3) Class-Balanced Focal Loss

 

PyTorch Code

코드 출처: https://github.com/vandit15/Class-balanced-loss-pytorch

import numpy as np
import torch
import torch.nn.functional as F



def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.
    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).
    Args:
      labels: A float tensor of size [batch, num_classes].
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch_size]
        specifying per-example weight for balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.
    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """    
    BCLoss = F.binary_cross_entropy_with_logits(input = logits, target = labels,reduction = "none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + 
            torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)

    focal_loss /= torch.sum(labels)
    return focal_loss



def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.
    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.
    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim = 1)
        cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
    return cb_loss

 

Performance

 


참고자료

[1] https://yjchoi-95.gitbook.io/paper-review/cvpr-2019-class-balanced-loss-based-on-effective-number-of-samples

[2] https://github.com/vandit15/Class-balanced-loss-pytorch

[3] https://arxiv.org/abs/1901.05555

반응형