(Binary / Multi-label / Multi-class) Classification
우리가 보통 Binary Classification(이진 분류) 문제를 풀 때 loss 계산을 위해 torch.nn.BCEWithLogitsLoss를 사용한다. torch.nn.BCELoss도 있는데 BCELoss에 sigmoid 함수를 함께 결합한 것이 BCEWithLogitsLoss다. 보통 이진 분류 모델의 output으로 logit 값이 나오기 때문에 loss 계산 전에 sigmoid를 거쳐야 하므로 해당 과정이 포함된 BCEWithLogitsLoss를 쓰면 편리하다.
- torch.nn.BCEWithLogitsLoss = torch.nn.BCEWithLoss + torch.sigmoid
Binary Classification 문제 말고도 Multi Label Classification(다중 레이블 분류)에서도 loss 계산으로 BCEWithLogitsLoss를 사용할 수 있고 실제로 많이 사용하고 있다. 그 이유는 예를 들어 5개의 label을 분류하는 Multi Label Classification 문제가 있다고 가정해보자. 이때, 모델의 최종 output이 5개의 logit 값이 나오고 각각 sigmoid 함수를 거쳐 일정 threshold 값 이상이 되면 해당 label이 있고 그렇지 않으면 없다고 판단하여 '[1, 0, 1, 0 ,0]' 이런 식으로 나올 것이다. 이때 각 label들은 자기 자신에 해당하는 logit 값만 잘 찾아내면 되므로 5개의 독립적인 이진 분류 문제로 접근할 수 있어 Binary Cross Entropy 함수를 loss 계산에 활용할 수 있는 것이다. 반면에 Multi Class Classification은 최종 output에 softmax를 수행하기 때문에 각 class에 해당하는 logit 값이 서로 영향을 주게 된다. 그래서 Multi Class Classification 문제를 풀 때는 torch.nn.CrossEntropyLoss를 사용한다.
그런데 pytorch를 사용하다보니 torch.nn.MultiLabelSoftMarginLoss 모듈을 발견하게 되었고 문서 설명을 읽으면서 '이거 결국 BCEWithLogitsLoss랑 같은 일 하는 모듈 아니야?'라는 생각을 하게 되었고 좀 더 자세하게 찾아보았다.
BCEWithLogitsLoss / MultiLabelSoftMarginLoss
결론부터 말하자면 둘은 같은 역할을 하는 모듈이다. 즉 Multi label classification 문제에서 loss 계산을 위해 둘 중에 아무거나 사용해도 상관은 없다. 하지만 약간의 출력 값이 다를 수가 있어서 왜 다른지 한 번 확인해보고 사용하면 좋을 것 같다.
위 토론 글에서 ptrblck 형님께서 두 모듈은 같은 일을 하며 두 모듈의 출력 값이 같다는 예시를 들어주셨는데 pytorch 지금 버전으로 돌려보면 값이 달라 오류가 나는 것을 확인할 수 있다.
x = Variable(torch.randn(10, 3))
y = Variable(torch.FloatTensor(10, 3).random_(2))
# double the loss for class 1
class_weight = torch.FloatTensor([1.0, 2.0, 1.0])
# double the loss for last sample
element_weight = torch.FloatTensor([1.0]*9 + [2.0]).view(-1, 1)
element_weight = element_weight.repeat(1, 3)
bce_criterion = nn.BCEWithLogitsLoss(weight=None, reduce=False)
multi_criterion = nn.MultiLabelSoftMarginLoss(weight=None, reduce=False)
bce_criterion_class = nn.BCEWithLogitsLoss(weight=class_weight, reduce=False)
multi_criterion_class = nn.MultiLabelSoftMarginLoss(weight=class_weight, reduce=False)
bce_criterion_element = nn.BCEWithLogitsLoss(weight=element_weight, reduce=False)
multi_criterion_element = nn.MultiLabelSoftMarginLoss(weight=element_weight, reduce=False)
bce_loss = bce_criterion(x, y)
multi_loss = multi_criterion(x, y)
bce_loss_class = bce_criterion_class(x, y)
multi_loss_class = multi_criterion_class(x, y)
bce_loss_element = bce_criterion_element(x, y)
multi_loss_element = multi_criterion_element(x, y)
print(bce_loss - multi_loss)
print(bce_loss_class - multi_loss_class)
print(bce_loss_element - multi_loss_element)
우선 BCEWithLogitsLoss 모듈에 대한 설명을 보면 BCEWithLogitsLoss는 기본적으로 one single class에 대한 loss를 계산하기 위해 나온 모듈이라고 한다. 그리고 이를 확장시켜서 multi label에 대해서 활용이 가능하다고 한다. 그래서 만일 데이터가 5개의 class를 가진 multi label이라면 각 class 별로 BCEWithLogitsLoss 값이 나오게 된다.
그리고 MultiLabelSoftMarginLoss 모듈에 대한 설명을 보면 각 label에 대한 BCE 값을 평균 낸다고 한다. 즉, Multi Label Classification 문제는 애초에 label 별로 계산된 loss 값이 개별적으로 쓰일 필요가 없기 때문에 전체 class 개수로 나눈 값을 반환해주고 있다.
예시 코드로 확인해보자.
>>> target = torch.randint(0,2,(3,5), dtype=torch.float32)
>>> target
tensor([[1., 1., 0., 0., 1.],
[1., 1., 0., 1., 0.],
[1., 0., 1., 0., 1.]])
>>> output = torch.randn((3,5))
>>> output
tensor([[-0.7480, 0.6282, -1.8508, -0.5212, -0.6264],
[ 0.0744, -0.2861, 2.3179, -1.2633, 1.0343],
[ 0.6806, -0.0375, -1.4666, -0.2414, 0.3977]])
우선 5개의 label을 갖는 3개의 데이터를 만들어 보았다. target은 정답 데이터이고 output은 모델에서 나온 output logits으로 생각하면 되겠다.
>>> bce_criterion = nn.BCEWithLogitsLoss(reduction='none')
>>> bce_criterion(output, target)
tensor([[1.1355, 0.4276, 0.1459, 0.4661, 1.0546],
[0.6567, 0.8464, 2.4118, 1.5123, 1.3385],
[0.4097, 0.6746, 1.6742, 0.5797, 0.5139]])
>>> multi_criterion = nn.MultiLabelSoftMarginLoss(reduction='none')
>>> multi_criterion(output, target)
tensor([0.6460, 1.3531, 0.7704])
>>> bce_criterion(output, target).mean(axis=-1)
tensor([0.6460, 1.3531, 0.7704])
BCEWithLogitsLoss, MultiLabelSoftMarginLoss 모듈 둘 다 기본적으로 reduction 파라미터의 default 값이 'mean'으로 되어있어 정확한 출력 값 비교를 위해 'none'으로 두고 수행을 돌려보았다.
(위 토론글에서 ptrblck 형님께서 reduce=False로 설정하신 것과 같은 의미인데 pytorch 상위 버전에서는 reduce 파라미터가 사라지고 reduction 파라미터를 써야 한다.)
결과를 보게 되면 BCEWithLogitsLoss의 출력 값은 3x5의 형식으로 데이터의 각 label별로 계산한 loss 값을 그대로 보여주고 있다. 반면 MultiLabelSoftMarginLoss는 각 데이터마다 label별로 계산한 loss 값을 평균 내어 보여주고 있다. 따라서 BCEWithLogitsLoss의 출력 값을 데이터 별로 평균 내보면 MultiLabelSoftMarginLoss 값과 같은 것을 확인할 수 있다.
결론
BCEWithLogitsLoss와 MultiLabelSoftMarginLoss는 같은 일을 한다.
하지만 둘은 Binary Classification을 위해 만들어졌는가 Multi Label Classification을 위해서 만들어졌는가의 차이가 있다. 그래서 loss 계산 결과에 있어 형태의 차이가 있고 값은 같은 것을 확인할 수 있다. 보통 우리는 reduction을 'mean'으로 두고(default) 사용하기 때문에 두 모듈을 섞어 사용해도 차이가 없을 것으로 예상된다. 하지만 이런 작은 차이를 알아둔다면 나중에 디버깅할 때 도움이 될 수도 있을 것 같다.
>>> bce_criterion = nn.BCEWithLogitsLoss()
>>> multi_criterion = nn.MultiLabelSoftMarginLoss()
>>> print(bce_criterion(output, target), multi_criterion(output, target))
tensor(0.9232) tensor(0.9232)
'Deep Learning > Pytorch' 카테고리의 다른 글
[개발팁] 'MultilabelStrarifiedKFold' : Multi-label classification 에 적용 가능한 strarification cross validator (0) | 2021.04.21 |
---|---|
num_workers & pin_memory in DataLoader (0) | 2021.04.18 |
[개발팁] Multi-label Classification에 쓸만한 전처리 모듈 'MultiLabelBinarizer' (0) | 2021.04.18 |
[개발팁] torch.nn 과 torch.nn.functional 어느 것을 써야 하나? (0) | 2020.08.07 |
[Label Smoothing] 요약 정리 (0) | 2020.07.28 |