基于CIFAR-10的Pytorch深度学习模板构建
基于CIFAR-10的Pytorch深度学习模板构建
构建了一个基于CIFAR-10数据集的pytorch版本深度学习baseline,便于以后更快的迁移到其它深度学习任务中去。代码详情请参看GitHub,如有错误,请指正。
数据集加载
-
首先需要从CIFAR-10官网下载打包好的数据集
CIFAR-10数据集官网:http://www.cs.toronto.edu/~kriz/cifar.html
-
下载对应Python版本的数据集文件并解压
-
根据官网上的Python加载方式加载数据集
-
1 |
|
输出为:
1 |
|
经过以上解码加载之后,对返回值dict
取dict[b'labels']
和dict[b'data']
分别获取batch1中的标签和数据,同理可以依次获得batch2-5以及test_batch的标签和数据。
dict[b'labels']
返回一个长度为10000的列表,每个元素取值范围为0-9,分别对应10个类别的标签。
dict[b'data']
返回一个[10000, 3072]的矩阵,存储每张样本图片的RGB像素值。前1024(32x32)是R通道的像素数据,后面依次是B、G通道。
- 完成训练和测试集的样本的解压加载之后,开始第二步:构造数据集加载器类。这个类有三个主要的构造函数分别是:
__init__()
用于初始化数据集路径和定义一些数据增强Pipline。__getitem__()
用于依据索引获取对应数据增强后的样本(data+label)。__len__()
返回样本集的样本数据量。
具体数据集加载器类构造如下:
1 |
|
至此数据集加载器类构造完成。
网络模型搭建
网络模型为ResNet18,这里直接使用Pytorch官方提供的版本(手动更改最后全连接层的节点数为class_num=10)。
- ResNet18
1 |
|
工具函数及超参数设置
构造了两个工具函数calculation_accuracy()
和training_process_visualization()
分别用于计算准确率和训练过程的可视化。
- calculation_accuracy()
1 |
|
- training_process_visualization()
1 |
|
这个文件一般用来存储模型训练过程中的学习率,数据增强,batch size,损失函数等超参数。
1 |
|
训练和测试
训练步骤
- 初始化超参数实例
- 定义数据集加载器实例
- 定义模型、优化器、损失函数、学习率调整器
- 用于保存训练过程中的损失和准确率(可以不用)
- 通过两层
for
循环开始迭代训练- 取数据(
net.train()
) - 优化器梯度置零
- 前向传播预测
- 计算损失并反向传播
- 学习率调整
- 每个epoch测试一次(
net.eval()
)
- 取数据(
- 根据记录的数据可视化训练过程
完整的train.py
如下所示
1 |
|
Post author: jasonyang
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 3.0 unless stating additionally. 转载请注明出处。