pytorch交叉熵损失函数的weight参数的使用
首先
必须将权重也转为Tensor的cuda格式;
然后
将该class_weight作为交叉熵函数对应参数的输入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
补充:关于pytorch的CrossEntropyLoss的weight参数
首先这个weight参数比想象中的要考虑的多
你可以试试下面代码
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)
这里的手动计算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加权呢?
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)
手算发现,并不是单纯的那权重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
发现了么,加权后,除以的是权重的和,不是数目的和。
我们再验证一遍:
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明
(资源库 www.zyku.net)
原文链接:https://niecongchong.blog.csdn.net/article/details/86594621
上一篇:解决python中os.system调用exe文件的问题
栏 目:Python教程
下一篇:matplotlib画混淆矩阵与正确率曲线的实例代码
本文标题:pytorch交叉熵损失函数的weight参数的使用
本文地址:https://www.zyku.net/python/9829.html
您可能感兴趣的文章
- 02-10pytorch 使用半精度模型部署的操作
- 02-10pytorch 中nn.Dropout的使用说明
- 02-10浅谈pytorch中的dropout的概率p
- 02-10基于PyTorch实现一个简单的CNN图像分类器
- 02-10pytorch中.to(device) 和.cuda()的区别说明
- 02-10Pytorch 中net.train 和 net.eval的使用说明
- 02-10Pytorch 如何训练网络时调整学习率
- 02-10pytorch model.cuda()花费时间很长的解决
- 02-10Pytorch GPU内存占用很高,但是利用率很低如何解决
- 02-09PyTorch 如何自动计算梯度
- 02-09pytorch 实现计算 kl散度 F.kl_div()
- 02-09pytorch中LN(LayerNorm)及Relu和其变相的输出操作
- 02-09pytorch 实现多个Dataloader同时训练
- 02-09解决pytorch trainloader遇到的多进程问题
- 02-09Pytorch使用shuffle打乱数据的操作
- 02-08PyTorch梯度裁剪避免训练loss nan的操作
- 02-08我对PyTorch dataloader里的shuffle=True的理解
- 02-08浅谈pytorch中为什么要用 zero_grad() 将梯度清零
- 02-08pytorch DataLoader的num_workers参数与设置大小详解
- 02-08pytorch 带batch的tensor类型图像显示操作
最近更新
阅读排行
猜你喜欢
- 02-08Python基础学习之条件控制语句小结
- 07-08阿里云自定义RAM策略之【对象存储服务
- 01-11宝贝画画涂鸦-宝贝画画涂鸦应用软件功
- 04-24Python 条件判断的缩写方法
- 01-11天立阅卷-天立阅卷应用软件功能介绍
- 10-19iphone13屏幕发黄发暗怎么调节
- 09-12华为荣耀手机怎么设置虚拟按键导航
- 12-25轻松家居遥控器-轻松家居遥控器应用软
- 01-04ucloud云计算实现cdn文件上传
- 03-17Swiper修改轮播图箭头的大小和颜色