重庆分公司,新征程启航
为企业提供网站建设、域名注册、服务器等服务
这篇文章将为大家详细讲解有关pytorch中如何计算交叉熵损失,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。
麻江网站制作公司哪家好,找创新互联建站!从网页设计、网站建设、微信开发、APP开发、成都响应式网站建设公司等网站项目制作,到程序开发,运营维护。创新互联建站2013年开创至今到现在10年的时间,我们拥有了丰富的建站经验和运维经验,来保证我们的工作的顺利进行。专注于网站建设就选创新互联建站。公式
首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:
其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。
测试代码(一维)
import torch import torch.nn as nn import math criterion = nn.CrossEntropyLoss() output = torch.randn(1, 5, requires_grad=True) label = torch.empty(1, dtype=torch.long).random_(5) loss = criterion(output, label) print("网络输出为5类:") print(output) print("要计算label的类别:") print(label) print("计算loss的结果:") print(loss) first = 0 for i in range(1): first = -output[i][label[i]] second = 0 for i in range(1): for j in range(5): second += math.exp(output[i][j]) res = 0 res = (first + math.log(second)) print("自己的计算结果:") print(res)
测试代码(多维)
import torch import torch.nn as nn import math criterion = nn.CrossEntropyLoss() output = torch.randn(3, 5, requires_grad=True) label = torch.empty(3, dtype=torch.long).random_(5) loss = criterion(output, label) print("网络输出为3个5类:") print(output) print("要计算loss的类别:") print(label) print("计算loss的结果:") print(loss) first = [0, 0, 0] for i in range(3): first[i] = -output[i][label[i]] second = [0, 0, 0] for i in range(3): for j in range(5): second[i] += math.exp(output[i][j]) res = 0 for i in range(3): res += (first[i] + math.log(second[i])) print("自己的计算结果:") print(res/3)
nn.CrossEntropyLoss()中的计算方法
注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。
在pytorch中,CrossEntropyLoss计算公式为:
CrossEntropyLoss带权重的计算公式为(默认weight=None):
关于“pytorch中如何计算交叉熵损失”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,使各位可以学到更多知识,如果觉得文章不错,请把它分享出去让更多的人看到。