当前位置:网站首页 > 技术博客 > 正文

pytorch版本



要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下

[!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料



MNIST是计算机视觉领域中最为基础的一个数据集,也是很多人第一个神经网络模型

MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。
MNIST手写数字识别中的部分样本

MNIST中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。MNIST使用一个长度为10的one-hot向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。


2.1 任务目的

如本文标题所示,MNIST手写数字识别的主要目为:训练出一个模型,让这个模型能够对手写数字图片进行分类。

2.2 开发环境

为了实现本文的目标,你需要安装如下Python库

 

Pytorch官网上有着详细的安装教程,你可以看着来进行安装 - 传送门

tqdm库是Python的一个动态显示库,我们需要他来进行训练可视化

 

matplotlib库是Python的一个数据可视化库,我们需要他来进行训练结果可视化

 

2.3 实现流程

本代码的实现流程如下所示
本文代码主要流程

3.1 数据预处理部分

3.1.1 初始化全局变量

首先,我们需要导入上述提到的库,为了能够更全面的展示程序中每个函数的具体来源,因此本项目中的库不采用缩写的方式

 

对于Pytorch,我们需要手动去定义它是在CPU还是在GPU中训练;同时,我们需要使用到torchvision中的图片处理库torchvision.transforms来将图片转换为适用于网络的张量。

 
3.1.2 构建数据集

torchvision中的torchvision.datasets库中提供了MNIST数据集的下载地址,因此我们可以直接二调用对应的函数来下载MNIST的训练集和测试集

 

Pytorch中提供了一种叫做DataLoader的方法来让我们进行训练,该方法自动将数据集打包成为迭代器,能够让我们很方便地进行后续的训练处理

 

至此,数据集已经准备完毕。

3.2 训练部分

3.2.1 构建模型

在这里使用的是一个简单的卷积神经网络,其结构如下

 

其中torch.nn.Sequential函数能够自动将层数合并为一个模型,对于新手而言这种方式能够减少非常多的计算过程

随后,我们需要构建一个模型实例

 

to() 方法用于将张量放入到指定的设备(如CPU或GPU中),记住的是:不同设备的张量是无法进行运算的

如果一切正常,那么输出结果如下

 

读者也可以根据自己的兴趣去修改网络结构。

3.2.2 构建迭代器与损失函数
 

模型在构建迭代器的时候需要将所有参数传入到迭代器中,可以通过net.parameters()方法来得到模型的所有参数。

3.2.3 构建训练循环

训练循环是很多新手最头疼的地方,因此将会着重讲解这一部分

我们根据这个流程,构建一个循环框架

 
3.2.3.1 训练部分代码

对于训练部分,我们可以构造的模块为

 
3.2.3.2 测试部分代码

对于测试部分,我们可以构造的模块为

 
3.2.3.3 训练循环代码

将上述两个循环进行结合,就是最终的训练循环代码了

 

假如一切正常,能看到以下的训练过程

 

3.3 数据预后处理部分

数据后处理的部分包括训练结果可视化以及模型保存两个环节

3.3.1 训练结果可视化

我们需要使用到matplotlib来对结果进行可视化

 

结果如下图所示
Loss可视化
准确率可视化

3.3.2 保存模型

对于新手而言,我们选择直接保存整个模型

 

若想对这一方面有进一步的了解,可以参考这篇文章 传送门


版权声明


相关文章:

  • 依赖包括什么2025-05-22 13:30:05
  • 手机拖动滑块获取验证码没有反应2025-05-22 13:30:05
  • vue axios2025-05-22 13:30:05
  • 函数int main(void)已有主体2025-05-22 13:30:05
  • py文件生成exe运行失败2025-05-22 13:30:05
  • java课程设计案例精编2025-05-22 13:30:05
  • ftp下载文件命令2025-05-22 13:30:05
  • c语言中push函数pop函数2025-05-22 13:30:05
  • 深度优先遍历经典例题2025-05-22 13:30:05
  • office离线版2025-05-22 13:30:05