pythorch处理数据集常用操作句句注释

#用于解出测试集解析的代码
import glob
import numpy as ny
import cv2
import os #os库用来搞文件存放相关的

train_list = glob.glob("/../data_*") # *号是通配符获取一系列文件放到列表里面
print(train_list) #打印看看
save_path ="/home/kuan/.."#定义一个数据存放的路径

for l in train_list
    print(l) #打印文件
    l_dict = unpickle(l) #这里是使用数据集官网的调用的解析
    print(l_dict)
    print(l_dict.keys()) #打印key看看数据都有哪些标签
   
for im_idx, im_data in enumerate(l_dict[b'data']): #这里是遍历数据这个维度
    print(im_idx) #索引值
    print(im_data) #图片
    im_label = l_dict[b'labels'][im_idx] #获得label_name 和数据 
    im_name = l_dict[b'filenames'][im_idx]

    print(im_label,im_name,im_data)

    im_label_name =label_name[im_label]
    im_data = np.reshape( im_data,[3,32,32]) #通过numpy将数据转化为 3*32*32
    im_data = np.transpose([ im_data,(1,2,0)) #将数据转化为 把原先的3大小的尺寸放到第三个位置         
                                   #32大小的尺寸放到第1和2个位置 
    cv2.imshow("im_data",im_data)   #opencv库数据可视化使得im_data获得了名字
    cv2.imshow("im_data",cv2.resize(im_data,(200,200))
    cv2.waitkey(0) #这个函数可以暂停一下防止直接刷下去

    if not os.path.exists("{}/{}",format(save_path,im_label_name)):
         os.mkdir("{}/{}",format(save_path,im_label_name))     #如果不存在这个路径就创建

    cv2.imwrite("{}/{}/{}",format(save_path,
                            im_label_name,
                            im_name.decode("utf-8")))
#decode函数解码成字符串型

完成自定义数据加载 网络训练前预处理

from torchvision import transforms #进行数据增强
from torch.utils.data import  DataLoader,Dataset #导入数据加载相关的类
import  os
from PIL import  Image #PIL类似于OpenCV
import numpy as np
import glob

label_name=["plane","dog"]#定义列表名字
label_dict={} #定义一个字典

for idx,name in enmuerate(label_name): #将类别转化成数字
    label_dict[name] =idx

##class Mydataset(Dataset): #至少完成三个类的定义
##  def __init__(self):
## def __getitem__(self, item):
## def __len__(self):
def default_loader(path): #需要一个路径
    return Image.open(path).convert("RGB") #将他转化为RGB
#通过conpose拼接多个数据增强的方法 flip随机水平和垂直 翻转
train_transform =transforms.Compose([
    transforms.RandomResizedCrop((28,28)),#resize
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),#角度在-90到90度之间进行翻转
    transforms.RandomGrayscale(0.1),#0.1概率将图片转为灰度图
    transforms.ColorJitter(0.3,0.3,0.3,0.3),#
    transforms.ToTensor() #转化为网络输入数据
]
                              )
class MyDataset(Dataset):
    def __init__(self, im_list, transfrom=None ,
                 loader=default_loader ): #加个参数获得当前文件夹下的列表 传入transform即进行数据增强的函数 加入loader
        super(MyDataset,self).__init__() #初始化这个类
        imgs=[]

        for im_item in im_list: #"/home/CIFAR10/TRAIN/airplane/aeroplane_1.png" 原路径是这个
            im_label_name =im_item.split("/")[-2] #这里倒数第二个就是获取类编号
            imgs.append([im_item, im_label_name[im_label_name]]) #增加 传入路径和列表名对应的id

        # 将三个量变成类内的变量
        self.imgs = imgs #元素
        self.transform =transfrom #方法
        self.loader =loader  #方法


    def __getitem__(self, index): #定义数据的读取和增强 返回图片的数据和label index是训练时反复传的索引值
        im_path ,im_label =self.imgs[index] #读取相应的数据
        im_data =self.loader(im_path) #图片数据采用self.loader读取图片

        #下面是数据
        if self.transform is not  None


        return im_data, im_label

    #返回样本的长度
    def __len__(self):
        return len(self.imgs)

#获得训练集和测试集的列表
im_train_list =glob.glob("/home/CIFAR10/TRAIN/*/*.png")
im_test_list =glob.glob("/home/CIFAR10/TRAIN/*/*.png")
#数据加载
train_loader = MyDataset(im_train_list,transfrom = train_transform) #参数1是文件列表,第二个是transform
train_loader = MyDataset(im_test_list,transfrom = transforms.ToTensor() )

train_data_loader = DataLoader( dataset=train_dataset,#传入data
                                batch_size=6, #分成n=6大小的块
                                shuffle=True, #打乱顺序
                                num_workers=4 #选用四个进行加载
                                )

test_data_loader = DataLoader( dataset=test_dataset,#传入data
                                batch_size=6,
                                shuffle=False,
                                num_workers=4 #选用四个进行加载
                                )
print("num_of_train",len(train_dataset))
print("num_of_test",len(test_dataset))

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值