第一步学习使用pytorch框架
会使用框架写出改进算法的模块很重要,熟悉框架才能将网络结构转为代码,进而可以通过实验来验证想法的正确性!!!
一、数据导入部分的基本思维导图
二、代码源码
from torch.utils.data import Dataset
import torch
import os
import cv2
import numpy as np
#Torch的很多三方库的开头字母都是大写的,需要注意,自己定义类库时可采用大驼峰方式命名,这里采用的是class son(father):的定义方式,儿子可以使用父亲的函数
class MyDataset(Dataset):
def __init__(self,root,is_train=True):
#导入数据,如果是图片,保存图片路径,可以减少对图片的持续读入,避免内存爆炸
self.dataset=[]
#通过改变MyDataset的第二个参数选择进行训练或测试文件夹,a=123 if True else 234
dir='train' if is_train else 'test'
#将根路径与训练文件夹拼接或与测试文件夹拼接
sub_dir=os.path.join(root,dir)
img_lists=os.listdir(sub_dir)
#i=1"数量"
for img_name in img_lists:
#print(img_name)
#print('{}'.format(i))"数量"
#i=i+1"数量"
img_dir=os.path.join(sub_dir,img_name)
#print(img_dir)
self.dataset.append(img_dir)
def __len__(self):
#返回数据集的长度
return self.len(self.dataset)
def __getitem__(self, index):
data=self.dataset[index]
#imread读的图片的属性,c-通道数,h-高,w-宽
img=cv2.imread(data)/255
#print(img.shape)
#c-0,H-1,W-2,CHW,numpy的transpose是转换图片属性顺序
#new_img=np.transpose(img,(2,0,1))
#print(new_img.shape)
#torch的tensor的permuate转换图片的属性顺序
new_img=torch.tensor(img).permute(2,0,1)
data_list=data.split('.')
label=int(data_list[1])
position_pre=data_list[2:6]
position=[int(i)/300 for i in position_pre]
sorts=int(data_list[6])
return np.float32(new_img),np.float32(label),np.float32(position),np.float32(sorts)
if __name__ == '__main__':
data=MyDataset(r'D:\NewUser2\ApplicationDataLocation\DatasetLocation\yellow_data')
"""
a=data.__getitem__(3)
for i in a:
print(i)
"""
总结
代码学习是在b站上跟着视频敲的: