线稿上色的数据集:
dataset link:https://pan.baidu.com/s/1Abm7V6J2uNOy5U6nvsRSlg
key:eepv
txt文件生成
import os
import glob
def Create_Txt(data_name, data_path, data_class, txt_path,
               ratio=0.01):
    """Write train/val/test list files for a dataset of paired PNG images.

    The ``*.png`` files in ``<data_path>/<data_name>/<data_class[0]>`` are
    collected and sorted, the first ``int(count * ratio)`` of them are kept,
    and those are split into ``train.txt`` (70%), ``val.txt`` (5%) and
    ``test.txt`` (everything left) under ``<txt_path>/<data_name>``.

    Every line of a list file holds one path per entry of ``data_class``
    (same file name inside each class directory), separated by a single
    space and terminated by a newline.

    Args:
        data_name: Name of the dataset sub-directory (also used for the
            output sub-directory).
        data_path: Root directory that contains ``data_name``.
        data_class: Class directory names; the first one is scanned for
            ``*.png`` files and decides which file names are listed.
        txt_path: Root directory for the generated list files.
        ratio: Fraction of the discovered files to include in the lists.
    """
    data_path = os.path.join(data_path, data_name)
    txt_path = os.path.join(txt_path, data_name)

    # Build the pattern with os.path.join: a hard-coded "\*.png" suffix is a
    # path separator only on Windows and matches nothing elsewhere.
    pattern = os.path.join(data_path, data_class[0], "*.png")
    # glob returns files in arbitrary order; sort so the splits are reproducible.
    imgs_path = sorted(glob.glob(pattern))
    # Clamp so a ratio outside [0, 1] cannot index past the available files.
    num_data = max(0, min(int(len(imgs_path) * ratio), len(imgs_path)))

    txt_class = ["train.txt", "val.txt", "test.txt"]
    txt_class_ratio = [0.7, 0.05, 0.25]
    os.makedirs(txt_path, exist_ok=True)

    start = 0
    last = len(txt_class) - 1
    for i, txt_name in enumerate(txt_class):
        if i != last:
            end = start + int(num_data * txt_class_ratio[i])
        else:
            # The last split absorbs the rounding remainder of the others.
            end = num_data
        # The context manager guarantees the file is flushed and closed.
        with open(os.path.join(txt_path, txt_name), mode='w') as txt:
            for img_path in imgs_path[start:end]:
                name = os.path.basename(img_path)
                line = " ".join(
                    data_path + "/" + cls + "/" + name for cls in data_class
                )
                txt.write(line + "\n")
        start = end
if __name__ == '__main__':
    # Generate the list files for the "sketch" dataset, reading images from
    # ./data and writing the lists to ./list (both relative to the current
    # working directory).
    base_dir = os.getcwd()
    Create_Txt(
        "sketch",
        base_dir + "/data",
        ["img", "label"],
        base_dir + "/list",
    )
使用txt文本读入数据可以减少内存的需要:txt文件中只保存图片路径,图片在按索引取样时才从磁盘读取,而不是一次性全部载入内存。因此有时候自定义加载数据集是非常必要的。
自定义Dataset
from torch.utils.data import Dataset
import os
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
import torchvision
def cuda(*args):
    """Lazily yield ``item.cuda()`` for every positional argument, in order.

    Calling this function returns a generator; each ``.cuda()`` call happens
    only when the corresponding value is consumed (e.g. by tuple unpacking).
    """
    for tensor in args:
        yield tensor.cuda()
class Sketch(Dataset):
    """Dataset that loads image/label pairs named in a text list file.

    The list file (``train.txt``, ``test.txt`` or ``val.txt`` inside
    ``list_path``) holds one sample per line: whitespace-separated paths with
    the image path first and the label path second.  Paths therefore must not
    contain whitespace.

    Samples are returned by ``load_item``:
      * ``"test"`` mode: ``(label_tensor, name)``
      * any other mode:  ``(image_tensor, label_tensor, name)``

    Images are read as float32 arrays without rescaling, so pixel values keep
    the 0-255 range (``to_tensor`` only rescales uint8 input).
    """

    def __init__(self, list_path, mode="train"):
        """Read the list file selected by ``mode`` from directory ``list_path``.

        Args:
            list_path: Directory containing the ``*.txt`` list files.
            mode: ``"train"`` or ``"test"``; any other value selects
                ``val.txt``.
        """
        self.mode = mode
        if mode == "train":
            list_file = "train.txt"
        elif mode == "test":
            list_file = "test.txt"
        else:
            list_file = "val.txt"
        list_path = os.path.join(list_path, list_file)

        # Split every line on whitespace; blank lines are skipped so a
        # trailing empty line no longer breaks the unpacking in read_files.
        with open(list_path) as handle:
            self.img_list = [line.split() for line in handle if line.strip()]

        # Per-sample dicts: {"img", "label", "name"}.
        self.files = self.read_files()

    def __len__(self):
        """Return the number of samples in the list file."""
        return len(self.files)

    def __getitem__(self, index):
        """Return the sample at ``index`` (see ``load_item``)."""
        return self.load_item(index)

    def read_files(self):
        """Turn the parsed list lines into per-sample dictionaries.

        Returns:
            A list of dicts with the keys ``"img"``, ``"name"`` and
            ``"label"``.  In ``"test"`` mode ``"label"`` is present only when
            the line has a second column, and ``"name"`` is the image file
            stem; in the other modes ``"name"`` is the label file stem.

        Raises:
            ValueError: In non-test modes, if a line does not hold exactly two
                paths.
        """
        files = []
        if self.mode == "test":
            for item in self.img_list:
                image_path = item[0]
                sample = {
                    "img": image_path,
                    "name": os.path.splitext(os.path.basename(image_path))[0],
                }
                # Keep the label column when it exists: load_item reads it in
                # test mode, and dropping it made every test sample fail with
                # KeyError.
                if len(item) > 1:
                    sample["label"] = item[1]
                files.append(sample)
        else:
            for item in self.img_list:
                if len(item) != 2:
                    raise ValueError(
                        "expected '<image path> <label path>' but got: {!r}".format(
                            " ".join(item)
                        )
                    )
                image_path, label_path = item
                files.append({
                    "img": image_path,
                    "label": label_path,
                    "name": os.path.splitext(os.path.basename(label_path))[0],
                })
        return files

    def load_item(self, index):
        """Load the sample at ``index`` from disk and convert it to tensors.

        Returns:
            ``(label, name)`` in ``"test"`` mode, otherwise
            ``(image, label, name)``.  ``image`` is read in colour (converted
            to RGB) and ``label`` in grayscale with a trailing channel axis,
            both passed through ``F.to_tensor``.
        """
        item = self.files[index]
        name = item["name"]
        if self.mode == "test":
            # NOTE(review): a single-column test list has no label path, so
            # the only listed file is read as the grayscale input instead --
            # confirm this matches the intended test-list format.
            source = item.get("label", item["img"])
            label = self.read_image(source, cv2.IMREAD_GRAYSCALE)
            return F.to_tensor(label), name
        image = self.read_image(item["img"], cv2.IMREAD_COLOR)
        label = self.read_image(item["label"], cv2.IMREAD_GRAYSCALE)
        return F.to_tensor(image), F.to_tensor(label), name

    def read_image(self, img_path, read_mode):
        """Read ``img_path`` with OpenCV as a float32 ``H x W x C`` array.

        Args:
            img_path: Path of the image file.
            read_mode: ``cv2.IMREAD_COLOR`` (result is converted BGR -> RGB)
                or another OpenCV read flag, treated as single-channel.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(img_path, read_mode)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("cannot read image: {}".format(img_path))
        image = image.astype(np.float32)
        if read_mode == cv2.IMREAD_COLOR:
            # BGR -> RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Add a channel axis so grayscale images are H x W x 1.
            image = image[:, :, np.newaxis]
        return image
if __name__ == '__main__':
    # Demo: iterate the training split one sample at a time and display each
    # image/label pair.
    current_path = os.getcwd()
    txt_path_train = current_path + "/list/sketch"

    # Lowercase name so the imported ``Dataset`` base class is not shadowed.
    dataset = Sketch(txt_path_train, mode="train")
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False)

    to_pil = torchvision.transforms.ToPILImage()
    for images, labels, name in dataloader:
        # Only move to the GPU when one exists; ``.cuda()`` raises otherwise.
        if torch.cuda.is_available():
            images, labels = cuda(images, labels)
        # The dataset yields float tensors in the 0-255 range, while
        # ToPILImage scales float input by 255 before casting to uint8;
        # bring the values into [0, 1] first so the preview is not corrupted.
        to_pil(images[0].cpu() / 255.0).show()
        to_pil(labels[0].cpu() / 255.0).show()