为 sketch 数据集生成 txt 列表文件，并自定义 Dataset

线稿上色数据集：
数据集链接：https://pan.baidu.com/s/1Abm7V6J2uNOy5U6nvsRSlg
提取码：eepv

txt文件生成
（此处原有一张示意图，已缺失）

import os
import glob

def Create_Txt(data_name, data_path, data_class, txt_path,
               ratio=0.01):
    """Write train/val/test list files for a paired image dataset.

    Images are discovered as ``<data_path>/<data_name>/<data_class[0]>/*.png``
    and sorted so the split is reproducible. Only the first
    ``int(len(images) * ratio)`` of them are used. They are split 70% / 5% /
    25% into ``train.txt``, ``val.txt`` and ``test.txt`` under
    ``<txt_path>/<data_name>``, which is created if missing. Any rounding
    remainder goes to ``test.txt``.

    Each output line lists, separated by single spaces, the path of the same
    file name inside every directory of ``data_class``, e.g.
    ``<root>/img/0001.png <root>/label/0001.png``. Paths that themselves
    contain whitespace would break readers that split lines on whitespace.

    Args:
        data_name: Dataset sub-directory name (e.g. ``"sketch"``).
        data_path: Root directory that contains ``data_name``.
        data_class: Sub-directories whose files are paired by file name; the
            first one is scanned to enumerate the samples.
        txt_path: Root directory for the generated list files.
        ratio: Fraction of the discovered images to use (default 0.01).
    """
    # absolute path
    data_path = os.path.join(data_path, data_name)
    txt_path = os.path.join(txt_path, data_name)

    # Portable pattern: the previous "\*.png" only worked with Windows path
    # separators. glob order is arbitrary, so sort for a reproducible split.
    imgs_path = sorted(glob.glob(os.path.join(data_path, data_class[0], "*.png")))
    num_data = int(len(imgs_path) * ratio)

    txt_class = ["train.txt", "val.txt", "test.txt"]
    txt_class_ratio = [0.7, 0.05, 0.25]

    os.makedirs(txt_path, exist_ok=True)

    start = 0
    last = len(txt_class) - 1
    for i, (txt_name, split_ratio) in enumerate(zip(txt_class, txt_class_ratio)):
        # The last split takes everything left so no sample is lost to rounding.
        end = num_data if i == last else start + int(num_data * split_ratio)
        # `with` guarantees each list file is flushed and closed.
        with open(os.path.join(txt_path, txt_name), mode="w") as txt:
            for img_path in imgs_path[start:end]:
                name = os.path.basename(img_path)
                paths = [data_path + "/" + cls + "/" + name for cls in data_class]
                txt.write(" ".join(paths) + "\n")
        start = end

if __name__ == '__main__':
    # Generate ./list/sketch/{train,val,test}.txt from the images in
    # ./data/sketch/img (paired with ./data/sketch/label).
    current_path = os.getcwd()
    Create_Txt(
        data_name="sketch",
        data_path=f"{current_path}/data",
        data_class=["img", "label"],
        txt_path=f"{current_path}/list",
    )

先用 txt 文件记录各样本的路径，在需要时再按路径读取图片，可以避免一次性把整个数据集载入内存，从而减少内存占用；因此有时自定义数据集的加载方式是很有必要的。
（此处原有一张示意图，已缺失）

自定义Dataset

from torch.utils.data import Dataset
import os
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
import torchvision


def cuda(*args):
    """Lazily move each positional argument to the GPU via ``.cuda()``.

    Returns a generator, so results are produced on iteration in argument
    order; the usual pattern is unpacking: ``a, b = cuda(a, b)``.
    """
    for tensor in args:
        yield tensor.cuda()

class Sketch(Dataset):
    """Paired image dataset read from a whitespace-separated list file.

    ``list_path`` is a directory holding ``train.txt``, ``test.txt`` and
    ``val.txt``. ``mode="train"`` reads ``train.txt``, ``mode="test"`` reads
    ``test.txt``, and any other value reads ``val.txt``.

    Each non-empty line holds an image path and, for train/val, a label path.
    In test mode the label column is optional.
    """

    def __init__(self, list_path, mode="train"):
        self.mode = mode
        if mode == "train":
            list_path = os.path.join(list_path, "train.txt")
        elif mode == "test":
            list_path = os.path.join(list_path, "test.txt")
        else:
            list_path = os.path.join(list_path, "val.txt")
        # Split each line on whitespace. Blank lines (e.g. a trailing newline
        # run) are skipped instead of crashing the unpacking in read_files.
        # `with` closes the file.
        with open(list_path) as list_file:
            self.img_list = [line.split() for line in list_file if line.strip()]
        # One dict per sample: {"img", ["label"], "name"}
        self.files = self.read_files()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        return self.load_item(index)

    def read_files(self):
        """Turn the parsed list rows into sample dicts.

        ``name`` is the file name without extension. In test mode only the
        first column is required; a second column, if present, is kept as
        ``"label"``. In train/val mode every row must have exactly two
        columns, otherwise unpacking raises ``ValueError``.
        """
        files = []
        if self.mode == "test":
            for item in self.img_list:
                name = os.path.splitext(os.path.basename(item[0]))[0]
                sample = {
                    "img": item[0],
                    "name": name,
                }
                if len(item) > 1:
                    sample["label"] = item[1]
                files.append(sample)
        else:
            for item in self.img_list:
                image_path, label_path = item
                name = os.path.splitext(os.path.basename(label_path))[0]
                files.append({
                    "img": image_path,
                    "label": label_path,
                    "name": name,
                })
        return files

    def load_item(self, index):
        """Load sample ``index`` as tensors.

        Returns ``(image, label, name)`` in train/val mode. In test mode it
        returns ``(label, name)`` when the list provides a label column, and
        ``(image, name)`` otherwise. Previously test mode always raised
        ``KeyError`` because no ``"label"`` key was stored.

        Note: ``F.to_tensor`` only rescales uint8 arrays, so these float32
        inputs yield CHW tensors whose values stay in [0, 255].
        """
        item = self.files[index]
        name = item["name"]

        if self.mode == "test":
            if "label" in item:
                label = self.read_image(item["label"], cv2.IMREAD_GRAYSCALE)
                return F.to_tensor(label), name
            image = self.read_image(item["img"], cv2.IMREAD_COLOR)
            return F.to_tensor(image), name

        image = self.read_image(item["img"], cv2.IMREAD_COLOR)
        label = self.read_image(item["label"], cv2.IMREAD_GRAYSCALE)
        return F.to_tensor(image), F.to_tensor(label), name

    def read_image(self, img_path, read_mode):
        """Read an image as float32: HxWx3 RGB for colour, HxWx1 for grayscale.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        raw = cv2.imread(img_path, read_mode)
        if raw is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("could not read image: %s" % img_path)
        image = raw.astype(np.float32)
        if read_mode == cv2.IMREAD_COLOR:
            # OpenCV loads BGR; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Add a channel axis so grayscale images are HxWx1.
            image = image[:, :, np.newaxis]
        return image


if __name__ == '__main__':
    # Preview the training samples listed in ./list/sketch/train.txt.
    current_path = os.getcwd()
    txt_path_train = current_path + "/list/sketch"
    # Lower-case name so the imported torch `Dataset` class is not shadowed.
    dataset = Sketch(txt_path_train, mode="train")
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False)

    to_pil = torchvision.transforms.ToPILImage()
    for images, labels, name in dataloader:
        # Exercise the GPU transfer only when a CUDA device exists; the
        # preview itself runs on CPU either way.
        if torch.cuda.is_available():
            images, labels = cuda(images, labels)
        # The dataset yields floats in [0, 255] (to_tensor does not rescale
        # float32 arrays), while ToPILImage maps float tensors from [0, 1].
        to_pil(images[0].cpu() / 255.0).show()
        to_pil(labels[0].cpu() / 255.0).show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码小白的成长

计算机网络PPT下载

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值