pytorch
文章平均质量分 75
pytorch
万俟淋曦
CSDN专家博主,阿里云专家博主,中国人工智能学会会员。分享机器人领域技术,包括SLAM,ROS,CV,DL等,助力机器人领域研究者成长进步,为我国机器人研发与制造领域添砖加瓦。
展开
-
pytorch 学习之 学习率调整策略
学习率调整一、 所有学习率调整函数的基类class _LRSccheduler(object): def __init__(self, optimizer, last_epoch=-1):主要属性:optimizer:关联的优化器last_epoch:记录epoch数base_lrs:记录初始学习率主要方法:step():更新下一个epoch的学习率get_lr():虚函数,计算下一个epoch的学习率二、调整策略有序调整:Step、MultiStep、Expon原创 2021-01-28 21:16:08 · 375 阅读 · 0 评论 -
pytorch 学习之 优化器
优化器管理并更新模型中的可学习参数的值,使得模型输出更接近真实标签。初始化函数class Optimizer(object): def __init__(self, params, defaults): # 优化器超参数,是一个dict{超参数:值} self.defaults = defaults # 参数的缓存,如momentum的缓存 self.state = defaultdict(dict) # 管理的参原创 2021-01-28 21:12:04 · 262 阅读 · 0 评论 -
pytorch 学习之 损失函数
损失函数损失函数(loss):Loss=f(y^,y)Loss=f(\hat{y}, y)Loss=f(y^,y) ,衡量模型输出与真实标签的差异,针对一个样本。代价函数(cost):Coss=1N∑iNf(y^i,yi)Coss=\frac{1}{N}\sum_i^Nf(\hat{y}_i, y_i)Coss=N1∑iNf(y^i,yi) ,计算整个样本集loss的平均值。目标函数(objective):Obj=Cost+RegularizationObj=Cost+Regulariza原创 2021-01-28 21:07:35 · 553 阅读 · 0 评论 -
常用数据增强方法(基于pytorch)
技术不重要,而是思想。原则:让训练集与测试集更接近关于名称: 数据增强、数据扩增、数据增广 都是他。方法分类:空间位置:如平移色彩:如灰度图、色彩抖动形状:如放射变换上下文场景:如遮挡、填充具体方法:数据中心化数据标准化缩放裁剪旋转翻转填充噪声添加灰度变换线性变换仿射变换亮度、饱和度及对比度变换在深度学习模型的训练过程中,数据扩增是必不可少的环节。现有深度学习的参数非常多,一般的模型可训练的参数量基本上都是万到百万级别,而训练集样本的数量很难有这么多,数据扩增可以扩原创 2020-11-01 11:50:25 · 11582 阅读 · 1 评论 -
pytorch 学习之 训练流程
读取与加载数据从硬盘读取数据train_data = XXXDataset(data_dir=train_dir, transform=train_transform)valid_data = XXXDataset(data_dir=valid_dir, transform=valid_transform)加载数据到网络train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)vali.原创 2020-09-13 22:17:51 · 847 阅读 · 0 评论 -
pytorch 学习之 Tensor 基础
tensor作为pytorch的基本操作对象,是首先要了解的。一、tensor的8个属性:# 数据相关t.data # tensor的数据t.dtype # tensor的数据类型t.shape # tensor的形状t.device # tensor所在的设备# 梯度相关t.grad # data的梯度t.grad_fn # 创建tensor的functiont.requires_grad # 是否需要求导t.is_leaf # 是否是叶子结点二、tensor的创原创 2020-09-04 22:58:09 · 555 阅读 · 0 评论 -
Pytorch 报错及解决记录
1. ValueError: num_samples should be a positive integer value, but got num_samples=0情况描述:一般出现在DataLoader(dataset=train_data, …)行。可能原因:传入的dataset没有数据,这时的 train_data.__len__() == 0,可能是函数找不到数据文件,数据路径不对。解决方法:检查自定义的 XXDataset() 类,关于获取数据的部分,查看路径等是否正确。2. Runt原创 2020-08-05 16:27:50 · 3306 阅读 · 0 评论