【pytorch】训练集的读取_cifarloader-程序员宅基地

技术标签: pytorch  deep learning-paper  

pytorch读取训练集是非常便捷的,只需要使用到2个类:

(1)torch.utils.data.Dataset

(2)torch.utils.data.DataLoader


常用数据集的读取

1、torchvision.datasets的使用

对于常用数据集,可以使用torchvision.datasets直接进行读取。torchvision.dataset是torch.utils.data.Dataset的实现

该包提供了以下数据集的读取

  • MNIST
  • COCO (Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

下面以cifar10为例:

[python] view plain copy
print ?
  1. import torch  
  2. import torchvision  
  3. from PIL import Image  
  4.   
  5. cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)  
  6. print(cifarSet[0])  
  7. img, label = cifarSet[0]  
  8. print (img)  
  9. print (label)  
  10. print (img.format, img.size, img.mode)  
  11. img.show()  
import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)
print(cifarSet[0])
img, label = cifarSet[0]
print (img)
print (label)
print (img.format, img.size, img.mode)
img.show()

2、实例化torch.utils.data.DataLoader

[python] view plain copy
print ?
  1. mytransform = transforms.Compose([  
  2.     transforms.ToTensor()  
  3.     ]  
  4. )  
  5.   
  6. # torch.utils.data.DataLoader  
  7. cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )  
  8. cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)  
mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)

下面就可以进行读取数据的显示,以进行简单测试是否读取成功:

[python] view plain copy
print ?
  1. for i, data in enumerate(cifarLoader, 0):  
  2.     print(data[i][0])  
  3.     # PIL  
  4.     img = transforms.ToPILImage()(data[i][0])  
  5.     img.show()  
  6.     break  
for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break


自定义标签数据集的读取

1、实现torch.utils.data.Dataset

假设我们有一个标签test_images.txt,内容如下:


对应的图像位于images目录下。

首先要继承torch.utils.data.Dataset类,完成图像及标签的读取。

[python] view plain copy
print ?
  1. import os  
  2. import torch  
  3. import torch.utils.data as data  
  4. from PIL import Image  
  5.   
  6. def default_loader(path):  
  7.     return Image.open(path).convert('RGB')  
  8.   
  9. class myImageFloder(data.Dataset):  
  10.     def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):  
  11.         fh = open(label)  
  12.         c=0  
  13.         imgs=[]  
  14.         class_names=[]  
  15.         for line in  fh.readlines():  
  16.             if c==0:  
  17.                 class_names=[n.strip() for n in line.rstrip().split('   ')]  
  18.             else:  
  19.                 cls = line.split()   
  20.                 fn = cls.pop(0)  
  21.                 if os.path.isfile(os.path.join(root, fn)):  
  22.                     imgs.append((fn, tuple([float(v) for v in cls])))  
  23.             c=c+1  
  24.         self.root = root  
  25.         self.imgs = imgs  
  26.         self.classes = class_names  
  27.         self.transform = transform  
  28.         self.target_transform = target_transform  
  29.         self.loader = loader  
  30.   
  31.     def __getitem__(self, index):  
  32.         fn, label = self.imgs[index]  
  33.         img = self.loader(os.path.join(self.root, fn))  
  34.         if self.transform is not None:  
  35.             img = self.transform(img)  
  36.         return img, torch.Tensor(label)  
  37.   
  38.     def __len__(self):  
  39.         return len(self.imgs)  
  40.       
  41.     def getName(self):  
  42.         return self.classes  
import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
    return Image.open(path).convert('RGB')

class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
        fh = open(label)
        c=0
        imgs=[]
        class_names=[]
        for line in  fh.readlines():
            if c==0:
                class_names=[n.strip() for n in line.rstrip().split('	')]
            else:
                cls = line.split() 
                fn = cls.pop(0)
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple([float(v) for v in cls])))
            c=c+1
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)
    
    def getName(self):
        return self.classes

2、实例化torch.utils.data.DataLoader

[python] view plain copy
print ?
  1. mytransform = transforms.Compose([  
  2.     transforms.ToTensor()  
  3.     ]  
  4. )  
  5.   
  6. # torch.utils.data.DataLoader  
  7. imgLoader = torch.utils.data.DataLoader(  
  8.          myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ),   
  9.          batch_size= 2, shuffle= False, num_workers= 2)  
  10.   
  11. for i, data in enumerate(imgLoader, 0):  
  12.     print(data[i][0])  
  13.     # opencv  
  14.     img2 = data[i][0].numpy()*255  
  15.     img2 = img2.astype('uint8')  
  16.     img2 = np.transpose(img2, (1,2,0))  
  17.     img2=img2[:,:,::-1]#RGB->BGR  
  18.     cv2.imshow('img2', img2)  
  19.     cv2.waitKey()  
  20.     break  
mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

# torch.utils.data.DataLoader
imgLoader = torch.utils.data.DataLoader(
         myFloder.myImageFloder(root = "../data/testImages/images", label = "../data/testImages/test_images.txt", transform = mytransform ), 
         batch_size= 2, shuffle= False, num_workers= 2)

for i, data in enumerate(imgLoader, 0):
    print(data[i][0])
    # opencv
    img2 = data[i][0].numpy()*255
    img2 = img2.astype('uint8')
    img2 = np.transpose(img2, (1,2,0))
    img2=img2[:,:,::-1]#RGB->BGR
    cv2.imshow('img2', img2)
    cv2.waitKey()
    break

相关代码可以查看:tfygg/pytorch-tutorials


---------------------------------------------------------------------------------------------------

在各方小伙伴的努力和支持下,pytorch中文文档 第一版终于上线啦!!!(鼓掌)文档还有很多小瑕疵,但是大体可以放心使用了~我们遵循快速迭代的原则,所以赶紧上线第一版来接受广大开源社区的意见和建议。欢迎加入我们

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/zhuiqiuk/article/details/80297587

智能推荐

学习 Rust 的第十四天:如何使用HashMap

学习 Rust 的第十四天:如何使用HashMap

安卓接入wwise

System.out.print("-------NativeHelper--c++返回值------Looper.getMainLooper():"+Looper.getMainLooper()+"\n");第五步:子线程回到 java 里面UI线程。第四步:postEventFun回调。

玩转SpringBoot整合Mybatis连接访问MySQL数据库_springboot整合mybatis连接mysql数据库配置-程序员宅基地

文章浏览阅读318次。步骤1.导入相关的依赖<?xml version="1.0" encoding="UTF-8"?><project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-._springboot整合mybatis连接mysql数据库配置

JavaScript解析JSON数据的方法和技巧-程序员宅基地

文章浏览阅读842次。js读取解析JSON类型数据JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,采用完全独立于语言的文本格式,是理想的数据交换格式。同时,JSON是 JavaScript 原生格式,这意...js解析json字符串为json对象,js解析json的6种方法1.一种为使用eval()函数。方式如下:var dataObj=eval("("+data+")")..._js解析json数组

解决Rstudio打开空白_rstudio打开后一片空白-程序员宅基地

文章浏览阅读1.9w次,点赞13次,收藏31次。近期更新了下Rstudio,突然出现打不开的情况,看了两天,终于让我解决了,解决方法如下:首先检查:1、R语言安装指定为64位(其实安装时按默认方式就行,现在电脑大多数都是64的);2、R语言和Rstudio的安装父目录要为同一个,你可以选择一个盘直接在里面创建一个R文件夹,把R语言和Rstudio两个安装路径选择我们创建的R文件夹就可以了;3、R语言和Rstudio的安装路径不要有汉字,这个是必须知道的。4、有些人还是会显示空白,就可以选择右键选择以管理员身份运行,如果可以进了,可以用以下方式_rstudio打开后一片空白

理解 Ruby Symbol (Ruby中的冒号)_ruby 冒号-程序员宅基地

文章浏览阅读3w次,点赞5次,收藏16次。Symbol 是什么Ruby 是一个强大的面向对象脚本语言(本文所用 Ruby 版本为1.8.6),在 Ruby 中 Symbol 表示“名字”,比如字符串的名字,标识符的名字。创建一个 Symbol 对象的方法是在名字或者字符串前面加上冒号:创建 symbol 对象 :foo:test_ruby 冒号

随便推点

JS 正则表达式百分比_js正则表达式百分数-程序员宅基地

文章浏览阅读712次。第一个数字后,匹配至少2个数字。可以为0-9. 即最小3位 最小100%开头第一个数字需要为1-9,即不能为0。匹配0个或1个 . ,可以没有小数点。以点之后一个或者两个数字结尾。_js正则表达式百分数

iOS上的UI是如何渲染出来的? 深入浅出UIKit渲染

我们在代码中写的View、Image等组件,最终是如何一步步渲染到屏幕上的呢?触摸、动画等是如何实现的?我们可以利用这些知识做哪些优化呢?本文先从屏幕物理层原理出发,一步步介绍渲染流程,然后介绍iOS的UIKit框架设计,最后介绍如何利用这些知识做优化先看第一步,屏幕是如何非常细腻的展示图片的。

分享我的电子藏书:C++系列(共32本)_accelerated c++ pdf-程序员宅基地

文章浏览阅读1.1k次。这些书籍是网上经常讨论的书籍,在此做一个总结和归类,免去大家找资料的奔波之苦。因书籍是自己学习所用,上传的所有书我收藏时都浏览过,保证书籍质量。 基础:××××××××××《C++编程思想》Thinking in C++Bruce Eckel著,刘宗田等译卷1 2E 卷2 1E机械工业出版社中文版,PDF格式卷1 506页 卷2 534页内附_accelerated c++ pdf

Vue2-TodoList案例(初级 后面会进行完善)_vue2 todolist网页-程序员宅基地

文章浏览阅读240次。Vue2-TodoList案例_vue2 todolist网页

java.lang.NullPointerException出现的几种原因及解决方案_exception in thread "main" java.lang.nullpointerex-程序员宅基地

文章浏览阅读2.9k次。如果你的对象的引用等于 null , NullPointerException 则会抛出,使用静态 String.valueOf 方法,该方法不会抛出任。主要介绍了 java.lang.NullPointerException 出现的几种原因及解决方案 , 本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下。7 、返回 null ,方法的返回值不要定义成为一般的类型,而是用数组。4 、字符串与文字的比较,文字可以是一个字符串或 Enum 的元素,如下会出现异常。_exception in thread "main" java.lang.nullpointerexception: cannot invoke "ja