pytorch数据处理

定义自己的数据类

import torch.utils.data as data
import os

import PIL.Image as Image

import torch
class myDataset(torch.utils.data.Dataset):
    def __init__(self, root,augment = None):
        self.dataSource = root
        file_is_pic=['.jpg','.png','.JPG'] #这是你要用的数据,有哪些后缀名
        # 这个list存放所有图像的地址
        self.image_files = np.array([x.path for x in os.scandir(root) if
                                     x.name.endswith(file_is_pic[0]) or x.name.endswith(file_is_pic[1]) or x.name.endswith(file_is_pic[2])])
        self.augment = augment  # 是否需要图像增强


    def __getitem__(self, index):# 这里我定义的是同时读取训练和测试数据到一个例子里
        # 读取图像数据并返回
        data_of_pic = self.image_files[index]
        data_of_label = self.image_files[index].replace('rgb_origin', 'rgb_origin_mask')
        if self.augment:
            image = open_image(data_of_pic)
            image = self.augment(image)  # 这里对图像进行了增强
            image_label= open_image(data_of_label)
            image_label = self.augment(image_label)  # 这里是label数据一起读进来
            return image, image_label # 自己 的增强方式要返回tensor格式
        else:
            # 如果不进行增强,直接读取图像数据并返回
            # 这里的open_image是读取图像函数,可以用PIL、opencv等库进行读取
            return open_image(data_of_pic), open_image(data_of_label)

    def __len__(self):
        # 返回图像的数量
        return len(self.image_files)

加载使用

img_path = 'rgb_origin'
finger_datasets=  myDataset(img_path)
# 读取数据,分批次# 数据集
train_loader = torch.utils.data.DataLoader(dataset=finger_datasets,  
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=0)  # 单个线程进行数据读取

for  epoch in range(100):
    for data,label  in train_loader:
		# 同时具有data和对应的label
        outputs = UNet(data)  # 输入数据,进入网络得到输出
        # permute such that number of desired segments would be on 4th dimension
        outputs = outputs.permute(0, 2, 3, 1)  # permute 维度变化操作,问题不大!
        # 类似torch.transpose函数,而permute只适用于tensor.permute
        width_out = 128
        height_out = 128
        m = outputs.shape[0]
        # Resizing the outputs and label to caculate pixel wise softmax loss
        outputs = outputs.resize(m * width_out * height_out, 2)  # 数据格式, 镜像2维??
        labels = labels.resize(m * width_out * height_out)   
      
        criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失函数
        optimizer = torch.optim.SGD(UNet.parameters(), lr=0.01, momentum=0.99)  # sgd梯度下降,对UNets的参数优化
        optimizer.zero_grad()  # 梯度清零
        loss = criterion(outputs, labels) # 输出和标签的数据对比
        print(loss)
        loss.backward() # 反向传播
        optimizer.step() # 进行优化
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值