Pytorch深度学习——用卷积神经网络实现MNIST数据集分类_采用卷积神经网络分类mnist数据集_学习CV的研一小白的博客-程序员宅基地

技术标签: cnn  深度学习  pytorch  PyTorch学习笔记  

目录

1 准备数据集

2 建立模型

3 构造损失函数+优化器

4 训练+测试

5 完整代码+运行结果


 选择下图结构的卷积神经网络来进行训练:

步骤:

  1. 选择 5 x 5 的卷积核,输入通道为 1,输出通道为 10:此时图像矩阵经过 5 x 5 的卷积核后会小两圈,也就是4个数位,变成 24 x 24,输出通道为10;
  2. 选择 2 x 2 的最大池化层:此时图像大小缩短一半,变成 12 x 12,通道数不变;
  3. 再次经过 5 x 5 的卷积核,输入通道为 10,输出通道为 20:此时图像再小两圈,变成 8 *8,输出通道为20;
  4. 再次经过 2 x 2 的最大池化层此时图像大小缩短一半,变成 4 x 4,通道数不变;
  5. 最后将图像整型变换成向量,输入到全连接层中:输入一共有 4 x 4 x 20 = 320 个元素,输出为 10.

具体代码如下:

1 准备数据集

# 准备数据集
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)

2 建立模型

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

3 构造损失函数+优化器

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

4 训练+测试

def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs,target=inputs.to(device),target.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d,%.5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))
            running_loss = 0.0

def test():
    correct=0
    total=0
    with torch.no_grad():
        for data in test_loader:
            inputs,target=data
            inputs,target=inputs.to(device),target.to(device)
            outputs=model(inputs)
            _,predicted=torch.max(outputs.data,dim=1)
            total+=target.size(0)
            correct+=(predicted==target).sum().item()
    print('Accuracy on test set:%d %% [%d%d]' %(100*correct/total,correct,total))

if __name__ =='__main__':
    for epoch in range(10):
        train(epoch)
        test()

5 完整代码+运行结果

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# 准备数据集
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs,target=inputs.to(device),target.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d,%.5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))
            running_loss = 0.0

def test():
    correct=0
    total=0
    with torch.no_grad():
        for data in test_loader:
            inputs,target=data
            inputs,target=inputs.to(device),target.to(device)
            outputs=model(inputs)
            _,predicted=torch.max(outputs.data,dim=1)
            total+=target.size(0)
            correct+=(predicted==target).sum().item()
    print('Accuracy on test set:%d %% [%d%d]' %(100*correct/total,correct,total))

if __name__ =='__main__':
    for epoch in range(10):
        train(epoch)
        test()

运行结果如下:

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Learning_AI/article/details/122641455

智能推荐

java mbean进程_java – 如何连接到另一个本地进程中的mBeanServer?_Karena Lu的博客-程序员宅基地

如果在启动JVM时设置“com.sun.management.jmxremote”系统属性,则可以运行jconsole或visualvm并连接到该本地mBeanServer.我想做他们正在做的同样的事情,但无法弄清楚如何做.是否有可用于标识本地运行的JVM的服务URL?我知道我可以通过在特定端口上设置jmxmp或rmi侦听器然后连接到该端口来实现此目的,但我不想这样做,因为这意味着我必须管理端口并..._java 连接java进程修改mbean参数

【DIY】多模式51单片机心形流水灯+呼吸灯+蜂鸣器音乐_p2=p1=p0=oxff-程序员宅基地

[DIY多模式]51单片机心形流水灯+呼吸灯+蜂鸣器音乐总体设计1.基础硬件DIY设计1).整体原理图2).PCB电路3).3D_PCB2.单片机程序设计1)呼吸灯[简易模拟PWM]2)蜂鸣器音乐3)几种简易流水灯方式3.效果展示1).实物电路顶层图2).实物电路底层图3).整体效果图注:本文仅用于学习分享,分享自己DIY制作的多模式51单片机心形流水灯[纯手工制作],若有不妥之处,请指正,感谢..._p2=p1=p0=oxff

MYSQL--基础--8.1--my.cnf--配置说明_my.cnf 配置-程序员宅基地

一、client客户端默认的连接参数配置port = 3307默认连接端口socket = /data/mysqldata/3307/mysql.sock用于本地连接的socket套接字default-character-set = utf8mb4编码二、mysqld服务端基本设置port = 3307MySQL监听端口socket = /data/mysqldata/3307/mysql.sock为MySQL客户端程序和服务器之间的本地通讯指定一个套接字文件pid-file =_my.cnf 配置

asp.net与SQLsever的连接字符串及配置数据库_asp.net连接sql server的连接字符串怎么写_wzf666的博客-程序员宅基地

asp.net与SQLsever的连接字符串及配置数据库第一种:在web.config文件中使用AppSettings<appSettings> <add key="名字" value="server=服务器名;user id=数据库ID;password=数据库密码;database=数据库名;" /></appSettings>第二种:使用connectionStrings<connectionStrings> <add n_asp.net连接sql server的连接字符串怎么写

Java—java中如何实现动态数组的创建与赋值-程序员宅基地

最近,项目中需要实现:提取一组数据,每个数组都有自己的属性,这组数据的长度又未知,还可能变长,变短,我考虑一会,实现如下;1.在oncreate前面,我声明两个数组String data1[ ][ ];//使用的时候,比如和adapter关联String data2[ ][ ];//提取数据,我这里是Poisaerh出来的数据2.实例化数组,提取数据,判断二维数据的维数

云服务器一:云服务器的选购-程序员宅基地

目录前言一、云服务器是什么?二、购买步骤方式一:方式二:总结前言环境:基于腾讯云服务器搭建Linux简述:由于工作需要和个人发展计划,准备入门Linux开发,一开始是使用虚拟机,VirtualBox和VMware都搭建过,但是感觉不够优雅,换个电脑就要重新来,思来想去,觉得云服务器是个很好的选择,听起来就足够高大上,那就决定这么干了。PS:新手建议先申请试用,大概有7-15天的试用时间,尝试着先搭个环境什么的,等上手了再购买也不迟,现在云服务器的优惠力度都挺大的,可以放心购买。提示:以下是基

随便推点

在QT中对label增加单击事件_qt label 增加单击跳转事件-程序员宅基地

1.加入头文件#include <QMouseEvent>2.在构造函数数添加该部件的事件过滤器,ui->label->installEventFilter(this);3.事件过滤器bool MainWindow::eventFilter(QObject *watched, QEvent *event){ static in..._qt label 增加单击跳转事件

黑马程序员-----程序员之路_____JDK1.5新特性之泛型-程序员宅基地

----------android培训、Java培训、期待与您交流!----------

CFT-ctf.show-信息收集闯关_c1c905c5-050d-4004-8201-75c5e6c6c67b.challenge.ctf_max-sec (辣子鸡丁)的博客-程序员宅基地

CFT-ftf.show闯关游戏_c1c905c5-050d-4004-8201-75c5e6c6c67b.challenge.ctf.show//robots.txt

f_lseek_用stm32移植FATFS的过程中,我弄了几天始终搞不懂2个问题。关于底层引脚和f_read()的问题。...-程序员宅基地

我发现网上大部分的例程都是在"sdcard.c"中直接/*ConfigurePC.08,PC.09,PC.10,PC.11,PC.12pin:D0,D1,D2,D3,CLKpin*/GPIO_InitStructure.GPIO_Pin=GPIO_Pin_8|GPIO_Pin_12;GPIO_InitSt...我发现网上大部分的例程都是在"sdcard.c"中直接 /* Configure PC..._stm32 f_seek

vue 实践中的一些小技巧_vue 实用技巧-程序员宅基地

1 数据过滤:经常会在项目中用到数据过滤,如何让代码更优雅?以下是在项目中使用到的一些小技巧:场景。codeArr:[{projectNum:'123123123312',name:'coco',age:20,gender:'formale',applayNum:'24124126417'},{projectNum:'123123123345',name:'joco',age:..._vue 实用技巧

performance_schema实战-程序员宅基地

// 查看监控的角色// 看到所有的host,所有用户,所有角色都会被监控mysql> select * from performance_schema.setup_actors;+------+------+------+---------+---------+| HOST | USER | ROLE | ENABLED | HISTORY |+------+------+------+---------+---------+| % | % | % | YES .