pytorch 6

BCEWithLogitsLoss 와 MultiLabelSoftMarginLoss 차이

(Binary / Multi-label / Multi-class) Classification 우리가 보통 Binary Classification(이진 분류) 문제를 풀 때 loss 계산을 위해 torch.nn.BCEWithLogitsLoss를 사용한다. torch.nn.BCELoss도 있는데 BCELoss에 sigmoid 함수를 함께 결합한 것이 BCEWithLogitsLoss다. 보통 이진 분류 모델의 output으로 logit 값이 나오기 때문에 loss 계산 전에 sigmoid를 거쳐야 하므로 해당 과정이 포함된 BCEWithLogitsLoss를 쓰면 편리하다. torch.nn.BCEWithLogitsLoss = torch.nn.BCEWithLoss + torch.sigmoid Binary Clas..

num_workers & pin_memory in DataLoader

pytorch를 이용해 딥러닝 모델을 학습시킬 때 custom dataset을 이용할 경우 torch.utils.data.Dataset으로 데이터셋을 정의하고(input data type, augmentation 등) torch.utils.data.DataLoader로 어떻게 데이터셋을 불러올 지(batch size, sampling 등) 정의한다. 학습을 시키다 보면 병목이 생기는 부분이 있는데 특히 데이터를 읽어서 가져올 때 시간이 오래 걸린다. 모델 학습을 하는데 시간을 써도 모자랄 판에 학습하기도 전에 불러오는 데이터에서 시간이 걸린다니... 즉, CPU를 이용해 데이터를 저장된 SSD나 HDD에서 읽어와 호스트의 메모리에 올리고 학습을 위해 GPU 메모리로 전달하는 과정에서 병목이 발생한다. 이..

[개발팁] Multi-label Classification에 쓸만한 전처리 모듈 'MultiLabelBinarizer'

Multi-label Classification problem 캐글에서 진행하고 있는 'Plant Pathology 2021 - FGVC8' competition을 진행하다가 찾게 된 유용한 모듈. Plant Pathology 2021 - FGVC8 Identify the category of foliar diseases in apple trees www.kaggle.com 해당 대회는 Multi-label Image Classification 문제로 사과나무 잎에 어떤 병에 걸렸는지 판별해야 한다. 제공되는 데이터셋의 label을 보면 문자열로 병명을 제공하고 있으며 건강할 경우 'healthy', 병에 걸려있는 경우 병 이름들을 공백으로 구분하여 보여주고 있다. 즉, 두 개 이상의 병에 걸려있는 경우..

[개발팁] torch.nn 과 torch.nn.functional 어느 것을 써야 하나?

파이토치를 쓰다보니 같은 기능에 대해 두 방식(torch.nn, torch.nn.functional)으로 구현 된 것들이 있다. 관련된 글들을 찾아보니 결론은 두 방식 다 같은 결과를 제공해주며 편한 것으로 선택해서 개발하면 된다고 한다. What is the difference between torch.nn and torch.nn.functional? What is the difference between torch.nn and torch.nn.functional? They look like a little same… so, is there any difference between them? discuss.pytorch.org torch.nn.CrossEntropyLoss / torch.nn.funct..

[Label Smoothing] 요약 정리

Label Smoothing논문 1. Label Smoothing? 모델이 Ground Truth(GT)를 정확하게 예측하지 않아도 되게 만들어 주는 것. 모델이 정확하지 않은 학습 데이터셋에 치중되는 경향(overconfident)을 막아 calibration 및 regularization 효과를 가질 수 있다. 2. Why? 보통 학습에 사용되는 데이터셋은 사람이 직접 annotation 하기 때문에 실수의 가능성이 존재하며 100% 정확한 GT 데이터로 생각하면 안된다. 즉, GT데이터가 잘 정제되어 있지 않다면 오분류된 데이터(mislabeled data)가 있을 수 있어 모델이 이를 유하게 학습시키도록 하면 더 효과적이기 때문이다. 3. 정말 좋아? Label smoothing은 mislabel..

[NVIDIA APEX] Amp에 대해 알아보자 (Automatic Mixed Precision)

version update 20-07-25 : amp 모듈이 pytorch 1.5.0 버전부터 기본 라이브러리에 추가되고 있음! pytorch 를 이용해 모델을 학습하다 보면 더 많은 batch size를 학습시키고 싶고 더 빠르게 학습시키고 싶은 생각이 굴뚝같아진다... 하지만 우리가 가지고 있는 데스크탑이나 서버 환경을 물리적으로 확장시키는 방법은 돈이 많이 든다. 돈을 들이지 말고 코드 몇 줄 만으로 모델을 최적화 시키고 batch size를 늘릴 수 없을까? https://github.com/NVIDIA/apex NVIDIA/apex A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch - NV..