栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

不平衡数据分类网络-Pytorch试验

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

不平衡数据分类网络-Pytorch试验

不平衡数据分类网络-Pytorch试验

注意:本试验在参考此代码的基础上。为方便起见,之后简称A

1. 准备数据(CIFAR-10数据集)

1.1 制作不平衡数据集 (下载的为平衡数据集)

脚本:cifar10_to_png.py脚本:image2train_test.py

直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。
步骤:
1)在A提取图片的基础上 ;

2)将数据集分成训练集和测试集 ;

3)在训练集中根据自定义的类别占比,采样不同数量的类别,得到不平衡训练集;

4)在测试集中,采样相同小数量的类别,得到平衡测试集。

PS:为了尽可能近似实际项目中的情况,故训练集中的样本数量设置的比较少。
且第二步的意义是为了防止数据泄露。

2. 数据加载 (参考A) 3. 搭建网络 (参考A)

采用的VGG16网络 参考此博客介绍

4. 训练网络

4.1 训练普通交叉熵损失函数的网络

loss = celoss(outputs, labels)  # 计算损失值

4.2 训练Class-Balanced Loss 的网络

Class-Balanced Loss based on Effective Number of Samples论文解读参考此博客

β beta β为常数,论文中设置为 ( N − 1 ) / N (N-1)/N (N−1)/N, N N N 为总样本数目。 n y n_y ny​ 为第 y y y 类的样本数目。

训练时遇到bugUserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at …c10/core/TensorImpl.h:1156.) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

解决办法:
这是pytorch1.9的bug,下个版本将修复,我将pytorch降级成1.8就不报这个错了。

5. 训练结果

5.1 第一组试验

数据集:
1)训练集:10类不平衡样本按如下比例分配

trainnum = 1000
class_ratio = [19, 17, 15, 13, 11, 9, 7, 5, 3, 1]

2)测试集:10类平衡样本每类数量为:

testnum = 50

混淆矩阵如有不懂参考此博客,具体代码实现。

e p o c h = 500 epoch = 500 epoch=500 时,在测试集上得到的混淆矩阵如下:

e p o c h = 500 epoch = 500 epoch=500 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:

图1 交叉熵损失函数
图2 类平衡损失函数

e p o c h = 2500 epoch = 2500 epoch=2500 时,在测试集上得到的混淆矩阵如下:
e p o c h = 2000 epoch = 2000 epoch=2000 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:

图1 交叉熵损失函数
图2 类平衡损失函数

结论:类平衡损失函数效果不明显。

可能有如下原因:

1)整体样本数量不是特别多,同类样本之间的特征不是特别统一。后续补做试验

2)没根据Loss去判断网络是否收敛。后续修改程序

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/445274.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号