pytorch建立自己的数据集Dataset

源码下载,分别在label.py、test.py
将图片集(下图class)存到在当前工程目录下,并按类别分开(下图的1、2、3):
在这里插入图片描述
为图片生成label:

import os

path = "class"  # 图片集路径
classes = [i for i in os.listdir(path)]
files = os.listdir(path)
train = open("train.txt", 'w')
val = open("val.txt", 'w')
for i in classes:
    s = 0
    for imgname in os.listdir(os.path.join(path, i)):

        if s % 7 != 0:  # 7:1划分训练集测试集
            name = os.path.join(path, i) + '\\' + imgname + ' ' + str(classes.index(i)) + '\n'  # 我是win10,是\\,ubuntu注意!
            train.write(name)
        else:
            name = os.path.join(path, i) + '\\' + imgname + ' ' + str(classes.index(i)) + '\n'
            val.write(name)
        s += 1

val.close()
train.close()

结果:
在这里插入图片描述
其中,txt内容:路径+类别,如下:

class\1\metal100.jpg 0
class\1\metal101.jpg 0
class\1\metal102.jpg 0
class\1\metal103.jpg 0
class\1\metal104.jpg 0

使用:

from PIL import Image
import torch
import torchvision.transforms as transforms


class MyDataset(torch.utils.data.Dataset):  # 创类:MyDataset,继承torch.utils.data.Dataset
    def __init__(self, datatxt, transform=None):
        super(MyDataset, self).__init__()
        fh = open(datatxt, 'r')  # 打开txt,读取内容
        imgs = []
        for line in fh:  # 按行循环txt文本中的内容
            line = line.rstrip()  # 删除本行string字符串末尾的指定字符
            words = line.split()  # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((words[0], int(words[1])))  # 把txt里的内容读入imgs列表保存,words[0]是图片信息,words[1]是label

        self.imgs = imgs
        self.transform = transform

    def __getitem__(self, index):  # 按照索引读取每个元素的具体内容
        fn, label = self.imgs[index]  # fn是图片path
        img = Image.open(fn).convert('RGB')  # from PIL import Image

        if self.transform is not None:  # 是否进行transform
            img = self.transform(img)
        return img, label  # return回哪些内容,在训练时循环读取每个batch,就能获得哪些内容

    def __len__(self):  # 它返回的是数据集的长度,必须有
        return len(self.imgs)


'''标准化、图片变换'''
mean = [0.5071, 0.4867, 0.4408]
stdv = [0.2675, 0.2565, 0.2761]
train_transforms = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=stdv)])

train_data = MyDataset(datatxt='train.txt', transform=train_transforms)

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

""" 训练时:"""
for data, label in train_loader:
    pass

外加一点识别时,读取图片的代码

from PIL import Image
from torchvision import transforms

img = Image.open('1.jpg')  # [H,W,C] [0,255] RGB
# img.show()
# tf=transforms.ToTensor()
# pic=tf(img)   # 单个操作

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])])		# 组合操作
img = transform(img)  # [C,H,W] [0,1] RGB
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值