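Task: read the last column of a CSV file. My fundamentals are shaky, so I ended up taking the long way round.
What comes back first is a list of strings, and the header row is in it too. At first I did not know why; it turns out csv.reader does not treat the first row specially, so the header comes through like any other row.
So I kick it out with pop(0),
and then convert what is left into a list of integers.
Lesson learned: print the result of every step, and learn to debug!
One more thing: in a list of strings, each string can still be sliced (or split) on its own.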
import csv

def read_label(csv_path):
    # e.g. "D:\\dataset\\GTSRB\\GT-final_test.csv"
    with open(csv_path, "r") as f:
        reader = csv.reader(f)
        # The file is semicolon-separated, so with the default comma delimiter
        # each row arrives as a single string; split it and keep field 7 (the last column).
        labels = [row[0].split(";")[7] for row in reader]
    # print(labels)
    labels.pop(0)  # drop the header entry
    # print(labels)
    numbers = [int(n) for n in labels]
    # print(numbers)
    return numbers
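For comparison, here is a minimal sketch of a more direct route, assuming the file really is semicolon-separated with exactly one header row. The function name `read_last_column` and its `delimiter` parameter are my own placeholders, not part of the code above. It passes the delimiter to `csv.reader` and skips the header with `next()` instead of popping it afterwards.

```python
import csv


def read_last_column(csv_path, delimiter=";"):
    """Return the last column of a delimited text file as a list of ints."""
    with open(csv_path, "r", newline="") as f:
        reader = csv.reader(f, delimiter=delimiter)
        # csv.reader does not skip headers by itself, so consume one row here.
        next(reader, None)
        # row[-1] is the last field; "if row" ignores completely empty lines.
        return [int(row[-1]) for row in reader if row]
```

Using `row[-1]` rather than a fixed index such as 7 keeps the function working if the number of columns changes.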
Reading the German Traffic Sign Recognition (GTSRB) training and test sets
import numpy as np
import torch
import os
import csv
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)

# Map each class folder name ("0" .. "42") to its integer label (43 classes)
tsr_label = {str(i): i for i in range(43)}
# class TSRDataset(Dataset):
class CatDogDataset(Dataset):
    def __init__(self, data_dir, mode="train", transform=None):
        self.label_name = {str(i): i for i in range(43)}
        self.mode = mode
        # data_info is a list of (image path, label) tuples for the jpg files;
        # DataLoader reads samples from it by index
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        # Return the image and label at the given index
        path_img, label = self.data_info[index]
        # data_info is built in __init__ by the methods defined below
        img = Image.open(path_img).convert('RGB')  # PIL image, values 0~255
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here (conversion to tensor, etc.)
        return img, label

    def __len__(self):
        return len(self.data_info)

    # A regular method (no @staticmethod), since it needs self.mode
    def get_img_info(self, data_dir):
        data_info = list()
        if self.mode == "train":
            for root, dirs, _ in os.walk(data_dir):
                # iterate over the class folders
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                    # iterate over the images
                    for img_name in img_names:
                        path_img = os.path.join(root, sub_dir, img_name)  # path of the image on disk
                        label = tsr_label[sub_dir]  # class of the image
                        data_info.append((path_img, int(label)))
        else:
            data_info = self.read_test_dataset(data_dir)
            # data_dir = r"D:\dataset\myGTSRBTest"
            # img_names = os.listdir(data_dir)
            # img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))  # a list
            # path_img_set = [os.path.join(data_dir, n) for n in img_names]
            # img_labels = self.read_label(r"D:\dataset\myTestcsv\GT-final_test.csv")
            # data_info = [(n, l) for n, l in zip(path_img_set, img_labels)]
        return data_info
    def read_test_dataset(self, data_dir):
        # data_dir = r"D:\dataset\myGTSRBTest"
        img_names = os.listdir(data_dir)
        img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
        # print(img_names)
        # Map the integer index in the first CSV column to the label in the last column.
        # This expects a CSV with no header row and a numeric first column.
        label_dict = {}
        with open(r"GT-final_test.csv") as csv_file:  # open the CSV file
            for one_line in csv.reader(csv_file):
                label_dict[int(one_line[0])] = one_line[-1]
        # Map each position in the directory listing to the image path.
        # Note: images and labels are paired by position, so this relies on
        # os.listdir returning the files in the same order as the CSV indices.
        path_dict = {}
        for i, img_name in enumerate(img_names):
            path_dict[i] = os.path.join(data_dir, img_name)
        result_list = []
        for idx in range(len(label_dict)):
            result_list.append((path_dict[idx], int(label_dict[idx])))
        return result_list
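Pairing test images with labels by position is fragile, because the order returned by `os.listdir` is not guaranteed. Below is a rough sketch of pairing by file name instead. It rests on several assumptions that are not in my code above: the CSV is semicolon-separated and has a header with `Filename` and `ClassId` columns, and the images were converted to `.jpg` while keeping the original base names. The function `pair_images_with_labels` is a hypothetical helper.

```python
import csv
import os


def pair_images_with_labels(image_dir, csv_path, delimiter=";"):
    """Build (image_path, label) tuples by looking each file name up in the CSV."""
    with open(csv_path, "r", newline="") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        # Assumed column names: "Filename" and "ClassId".
        # Keys are base names without extension, e.g. "00000".
        labels = {
            os.path.splitext(row["Filename"])[0]: int(row["ClassId"])
            for row in reader
        }

    samples = []
    # sorted() only makes the output order predictable; matching is done by name.
    for name in sorted(os.listdir(image_dir)):
        stem, ext = os.path.splitext(name)
        if ext.lower() == ".jpg" and stem in labels:
            samples.append((os.path.join(image_dir, name), labels[stem]))
    return samples
```

Files without a matching CSV row are silently skipped here; raising an error instead may be the better choice if every image is expected to have a label.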
Splitting into training and validation sets
    # These two methods go inside the dataset class (they use self.mode).
    def get_img_info(self, data_dir):
        import random
        rng_seed = 620
        # data_dir = r"D:\dataset\myGTSRB"
        # mode = "train"
        split_idx = 0.9  # fraction of the data used for training
        random.seed(rng_seed)  # same seed, so both modes see the same shuffle
        data_info = self.read_data_info(data_dir)
        random.shuffle(data_info)
        indx = int(len(data_info) * split_idx)
        if self.mode == "train":
            subset = data_info[:indx]  # first 90%: training set
            # print(len(subset))
        else:
            subset = data_info[indx:]  # remaining 10%: validation set
        return subset

    def read_data_info(self, data_dir):
        import os
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over the class folders
            for sub_dir in dirs:
                # os.listdir returns the names of the files and folders in the given folder
                img_names = os.listdir(os.path.join(root, sub_dir))
                # img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # iterate over the images
                for img_name in img_names:
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = tsr_label[sub_dir]
                    data_info.append((path_img, int(label)))
        # print(data_info)
        return data_info
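The split above is only consistent between the "train" and validation modes if both calls shuffle the same list in the same starting order. A small sketch of one way to make that explicit, assuming the samples are sortable `(path, label)` tuples: sort first, then shuffle with a private `random.Random` instance. The name `split_samples` and its parameters are placeholders of mine.

```python
import random


def split_samples(samples, mode="train", ratio=0.9, seed=620):
    """Return the training or validation part of a reproducible shuffle."""
    # Sorting gives a fixed starting order, whatever order the files were listed in.
    ordered = sorted(samples)
    # A private Random instance leaves the global random state untouched.
    random.Random(seed).shuffle(ordered)
    cut = int(len(ordered) * ratio)
    return ordered[:cut] if mode == "train" else ordered[cut:]
```

Calling it once with `mode="train"` and once with any other mode yields two disjoint lists that together cover all samples.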
Note: please don't mind the class names and variable names; I'm a beginner.