Pytorch基础知识
Pytorch的一般流程:
准备数据—定义网络 —训练—可视化—测试
一、读取数据指令 torch.utils.data
torch.utils.data.Dataset是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__() 和__getitem__()
len():返回的是数据集的大小。
getitem():实现索引数据集中的某一个数据。
class dataset(data.Dataset):
# 参数预定义 重载
def __init__(self, anno_pd, transforms=None,debug=False,test=False):
self.paths = anno_pd['ImageName'].tolist() # 图像路径
self.labels = anno_pd['label'].tolist() # 图像数字标签
self.transforms = transforms # 数字增强
self.debug=debug # 程序调试
self.test=test # 判定是否训练或测试
def __len__(self): # 返回图片个数
return len(self.paths)
def __getitem__(self, item):# 获取每个图片
img_path =self.paths[item] # 图像路径
img =cv2.imread(img_path) # 读取
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) # 格式转换
if self.transforms is not None: # 是否进行数据增强
img = self.transforms(img)
label = self.labels[item] # 图像对应标签
return torch.from_numpy(img).float(), int(label) # tensor和对应标签
二、搭建网络模块的指令
torch.nn
卷积层、全连接层 等都来这个函数
三 、训练指令
优化方法的选择 torch.optim
学习率的调整 torch.optim
网络参数初始化的选择 torch.nn.init