Deep Learning/Pytorch

[개발팁] torch.nn 과 torch.nn.functional 어느 것을 써야 하나?

족제비다아 2020. 8. 7. 14:08

파이토치를 쓰다보니 같은 기능에 대해 두 방식(torch.nn, torch.nn.functional)으로 구현 된 것들이 있다.

관련된 글들을 찾아보니 결론은 두 방식 다 같은 결과를 제공해주며 편한 것으로 선택해서 개발하면 된다고 한다.

 

What is the difference between torch.nn and torch.nn.functional?

 

What is the difference between torch.nn and torch.nn.functional?

They look like a little same… so, is there any difference between them?

discuss.pytorch.org

 

torch.nn.CrossEntropyLoss / torch.nn.functional.cross_entropy


naming 에서도 알 수 있듯이 torch.nn.functional은 함수고 torch.nn은 클래스로 정의되어 있다.

그렇기 때문에 torch.nn으로 구현한 클래스의 경우에는 attribute를 활용해 state를 저장하고 활용할 수 있고 torch.nn.functional로 구현한 함수의 경우에는 인스턴스화 시킬 필요 없이 사용이 가능하다.

 

각자 개발하는 스타일이나 편의에 맞춰서 사용하면 되며 둘 중에 어느게 좋고 나쁘고는 없는 것 같다.

그저 model class를 만들 때 init 부분에 torch.nn 클래스를 이용하여 모델을 정의해 버리거나 forward 진행할 때 직접 torch.nn.functional 함수를 이용하여 계산해주거나의 차이일 뿐!

 

torch.nn

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

 

torch.nn.functional

import torch
import torch.nn.functional as F

input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
loss = F.cross_entropy(input, target)
loss.backward()

 

위의 두 예제 처럼 대체적으로 두 방법은 같은 결과를 내며 차이가 없다. 그럼 cross entropy 에서 state로 들어가는 것은 무엇이 있을까?

 

pytorch 공식 문서에 torch.nn.CrossEntropyLoss 설명을 보면 파라미터로 weight가 있는데 이것이 state 변수라고 할 수 있다. 여기서 사용되는 weight는 각 class에 대한 가중치 정보다. 예를들어 10개의 class로 분류하는 문제에 손실 함수로 cross entropy loss를 사용하고 특정 클래스(7번)를 더 잘 찾고 싶다면, 7번 class 부분의 loss에 큰 값의 가중치를 곱해서 더 잘 학습시키게 만들 수 있다. 이를 위해서 처음 CrossEntropyLoss 클래스를 인스턴스화 할 때 weight 값을 인자로 보내 한 번만 설정하면 그 뒤로 학습시킬 때 계속 적용이 된다. 반면 torch.nn.functional.cross_entropy 함수에도 파라미터로 weight가 동일하게 있지만 이 weight 를 적용시켜서 loss 계산을 하고싶을 때는 매 번 함수를 호출할 때마다 인자로 weight 값을 넣어줘야 한다.