파이토치를 쓰다보니 같은 기능에 대해 두 방식(torch.nn, torch.nn.functional)으로 구현 된 것들이 있다.
관련된 글들을 찾아보니 결론은 두 방식 다 같은 결과를 제공해주며 편한 것으로 선택해서 개발하면 된다고 한다.
What is the difference between torch.nn and torch.nn.functional?
torch.nn.CrossEntropyLoss / torch.nn.functional.cross_entropy
naming 에서도 알 수 있듯이 torch.nn.functional은 함수고 torch.nn은 클래스로 정의되어 있다.
그렇기 때문에 torch.nn으로 구현한 클래스의 경우에는 attribute를 활용해 state를 저장하고 활용할 수 있고 torch.nn.functional로 구현한 함수의 경우에는 인스턴스화 시킬 필요 없이 사용이 가능하다.
각자 개발하는 스타일이나 편의에 맞춰서 사용하면 되며 둘 중에 어느게 좋고 나쁘고는 없는 것 같다.
그저 model class를 만들 때 init 부분에 torch.nn 클래스를 이용하여 모델을 정의해 버리거나 forward 진행할 때 직접 torch.nn.functional 함수를 이용하여 계산해주거나의 차이일 뿐!
torch.nn
import torch
import torch.nn as nn
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
torch.nn.functional
import torch
import torch.nn.functional as F
input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
loss = F.cross_entropy(input, target)
loss.backward()
위의 두 예제 처럼 대체적으로 두 방법은 같은 결과를 내며 차이가 없다. 그럼 cross entropy 에서 state로 들어가는 것은 무엇이 있을까?
pytorch 공식 문서에 torch.nn.CrossEntropyLoss 설명을 보면 파라미터로 weight가 있는데 이것이 state 변수라고 할 수 있다. 여기서 사용되는 weight는 각 class에 대한 가중치 정보다. 예를들어 10개의 class로 분류하는 문제에 손실 함수로 cross entropy loss를 사용하고 특정 클래스(7번)를 더 잘 찾고 싶다면, 7번 class 부분의 loss에 큰 값의 가중치를 곱해서 더 잘 학습시키게 만들 수 있다. 이를 위해서 처음 CrossEntropyLoss 클래스를 인스턴스화 할 때 weight 값을 인자로 보내 한 번만 설정하면 그 뒤로 학습시킬 때 계속 적용이 된다. 반면 torch.nn.functional.cross_entropy 함수에도 파라미터로 weight가 동일하게 있지만 이 weight 를 적용시켜서 loss 계산을 하고싶을 때는 매 번 함수를 호출할 때마다 인자로 weight 값을 넣어줘야 한다.
'Deep Learning > Pytorch' 카테고리의 다른 글
[개발팁] 'MultilabelStrarifiedKFold' : Multi-label classification 에 적용 가능한 strarification cross validator (0) | 2021.04.21 |
---|---|
num_workers & pin_memory in DataLoader (0) | 2021.04.18 |
[개발팁] Multi-label Classification에 쓸만한 전처리 모듈 'MultiLabelBinarizer' (0) | 2021.04.18 |
[Label Smoothing] 요약 정리 (0) | 2020.07.28 |
[NVIDIA APEX] Amp에 대해 알아보자 (Automatic Mixed Precision) (6) | 2020.07.14 |