왜 nn.CrossEntropyLoss 사용시
F.softmax를 output activation으로 사용하지 않아도 되는가?
Cross-Entropy Loss (CELoss)와 Softmax를 결합한 형태로 간단하게 정리할 수 있기 때문이다.
PyTorch에서 제공하는 nn.CrossEntropyLoss( )는 아래의 간소화된 형태의 식을 사용한다.
$CELoss = -x_{i} + log\sum_{c=1}^{N}e^{x_{c}}$
위의 식은 Softmax와 CELoss를 각각 계산하는 것보다 한번에 계산하기에 computationally efficient 하기 때문이다.
그렇기에 PyTorch의 nn.CrossEntropyLoss( )는 내부적으로 Softmax 연산을 수행한다고 볼 수 있다.
구체적인 전개 과정은 다음과 같다.
$CELoss = -\sum_{c=1}^{N} y_{c}log P_{c}$
$y$는 label/ground truth 이며 (Target value, label, Class) One-Hot Encoding 되어 있어 1 또는 0의 값을 갖는다.
$P$는 0과 1사이의 값을 갖는 확률 값이며, $N$은 총 Class의 개수이다.
$y_{c}$의 Label 중에서 오직 하나의 $y_{c}$만 1이고 나머지가 0이다. (Due to One-Hot Encoding)
$CELoss = - logP_{c}$
따라서 위의 식이 성립한다. $P_{c}$는 Label이 1인 neuron의 probability output이다.
모든 $P_{c}$ 값들은 Neural Network의 Output Layer에 Softmax Activation을 적용하여 구한 것이다.
$P_{c} = \sigma(x_{i}) = \frac{e^{x_{i}}}{\sum_{c=1}^{N} e^{x_{c}}} $
$\sigma(x_{i})$는 한 Neuron의 Softmax output value이다. 즉, $y_{c}$가 1인 한 Neuro의 $P_{c}$ 값이 $\sigma(x_{i})$이다.
$CELoss = -log \frac{e^{x_{i}}}{\sum_{c=1}^{N} e^{x_{c}}} $
$CELoss = -x_{i} + log\sum_{c=1}^{N}e^{x_{c}}$
다음과 같이 정리가 가능하다.
즉, PyTorch에서 제공하는 nn.CrossEntropyLoss( ) 모듈은 CELoss와 Softmax를 하나의 식에 적용한 결과이다.
MNIST Fashion을 예시로 들어보면 위의 코드처럼 Output Layer에
x = F.softmax(self.fc3(x))
으로 작성하는 것이 아니라
x = self.fc3(x)
Linear Activation을 적용한 후 Loss Function을 nn.CrossEntropyLoss( ) 으로 작성해주는 것이 효율적이다.
- Bad Example
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
# input data(image) = x
# x.shape = torch.Size([64, 1, 28, 28])
x = x.view(-1, 784)
# x.shape = torch.Size([64, 784])
x = F.relu(self.fc1(x))
# x.shape = torch.Size([64, 1, 28, 28])
x = F.relu(self.fc2(x))
# x.shape = torch.Size([64, 1, 28, 28])
x = F.softmax(self.fc3(x))
# x.shape = torch.Size([64, 1, 28, 28])
return x
- Good Example
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
# input data(image) = x
# x.shape = torch.Size([64, 1, 28, 28])
x = x.view(-1, 784)
# x.shape = torch.Size([64, 784])
x = F.relu(self.fc1(x))
# x.shape = torch.Size([64, 1, 28, 28])
x = F.relu(self.fc2(x))
# x.shape = torch.Size([64, 1, 28, 28])
x = self.fc3(x)
# x.shape = torch.Size([64, 1, 28, 28])
return x