这个学期考的不错,本来想放松一下,但是不知道如何放松,索性就写了这么一个类,来更好的完成我暑假做的这个图片分割与识别的项目。
大家如果有什么问题可以随时问我,还想要什么功能也非常欢迎评论,如果我的代码哪里有错误请一定要指出(先谢谢大家了),希望我们一起进步,一起学习。
"""
准备整个适用于神经网络的能够处理csv文件的类
这个类可以把写新生成的图片以像素的形式存储在csv文件中还有标签
"""
import csv
import numpy as np
class ProcessCSV:
def __init__(self,filename, mode, pixels=784):
self.pixels = int(pixels)
self.filename = str(filename)
self.mode = str(mode)
self.file = open(self.filename, self.mode, newline='')
self.reader = None
self.writer = None
if self.mode == 'w' or self.mode == 'a':
self.writer = csv.writer(self.file)
elif mode == 'r':
self.writer = csv.reader(self.file)
def open_file(self, filename, mode):
self.file = open(self.filename, self.mode, newline='')
print(type(self.file))
def close_file(self):
self.file.close()
def write_header(self):
"""
打开一个空的csv对象自动生成表头格式是:
'label','pixel0','pexel1','pixel2'...
"""
header = ['lable']
i = 0
while i < self.pixels:
pixel = 'pixel{}'.format(i)
header.append(pixel)
i += 1
self.writer.writerow(header)
def append_data(self, data):
"""data 应是一个列表(矩阵也可以),包含一条数据"""
if isinstance(data, np.ndarray):
data = data.tolist()
self.writer.writerow(data)
def append_datas(self, datas):
"""
datas应是一个包含列表的列表(numpy类型的矩阵也可以),
列表中的每个列表代表了一条数据
"""
if isinstance(datas, np.ndarray):
datas = datas.tolist()
self.writer.writerows(datas)
def integrate_data(self, pixel_values, labels):
"""
考虑到识别数字大部分的项目,标签(及正确解的集合)和元素都是分开的
这个方法把标签和像数值进行整合得到符合此csv格式的数据列表
parameters
------
pixel_values: list of lists or np.ndarray
labels: list or np.ndarray, [1,2,3,9,0]
Return
------
datas: 如果是一条数据就是一个列表否则就是一个包含列表的列表
"""
# 先把参数全部转换为list
if isinstance(pixel_values, np.ndarray):
new_pixel_values = pixel_values.tolist()
elif isinstance(pixel_values[0], np.ndarray):
new_pixel_values = []
for value in pixel_values:
new_pixel_values.append(value.tolist())
else:
new_pixel_values = pixel_values
if isinstance(labels, np.ndarray):
labels = labels.tolist()
if len(labels) == 1:
datas = labels + pixel_values
else:
datas = []
i = 0
while i < len(labels):
datas.append([labels[i]] + new_pixel_values[i])
i += 1
return datas
def get_labels_and_images(self):
"""
这个方法读取一个csv文件,返回图片和标签
Return
-----------
labels: 一个含有整数0到9的一维数组
images: 一个含有 照片数x像素数 的矩阵
"""
reader = csv.reader(self.file)
next(reader) # header
label_ls = []
images_ls = []
for row in reader:
label_ls.append(row[0])
images_ls.append(row[1:])
labels = np.array(label_ls)
images = np.array(images_ls)
return labels, images
最后得到的csv文件应该如下所示(用excel打开)
或是这样用vscode打开