Pytorch
文章平均质量分 90
混个毕业罢了
悟兰因w
天雨虽宽,不润无根之草。
展开
-
常用组件详解(十一):torch.nn.Linear()
【定义三个样本】批次为3,每个样本特征数量为5。【定义全连接层】由于输入特征为5,故。维向量,是模型要学习的偏置(可通过。则是输出神经元的个数。,想要输出神经元数为10,故。是模型要学习的权重,原创 2024-10-09 15:54:07 · 730 阅读 · 0 评论 -
常用组件详解(十):保存与加载模型、检查点机制的使用
模型检查点(checkpoint)是指模型训练过程中保存的模型状态,包括模型参数(权重与偏置)、优化器状态等其他相关的训练信息。通过保存检查点,可以实现在训练过程中定期保存模型的当前状态,以便在需要时恢复训练或用于模型评估和推理。用于返回模型的状态字典,其中保存了模型的可学习参数。其中,只有可学习参数的层(卷积层、全连接层等)和注册缓冲区(batchnorm’s running_mean)才会作为模型参数保存(优化器也有状态字典,也可进行保存)。),即只保存学习到的模型参数,并通过。来保存模型的状态字典(原创 2024-10-06 10:26:11 · 909 阅读 · 0 评论 -
常用组件详解(九):学习率更新策略
适合的学习率能够更好地训练模型,为此衍生出多种学习率调整策略。一般来说,在训练初期希望学习率大一些,使得网络收敛迅速,在训练后期希望学习率小一些,使得网络更好的收敛到最优解。在包下提供了大量的学习率调整模块,如StepLR。对学习率的更新由学习率更新对象执行step()原创 2024-10-04 14:33:45 · 861 阅读 · 0 评论 -
常见组件详解(八):torch.nn.functional.interpolate()
【代码】常见组件详解(八):torch.nn.functional.interpolate()原创 2024-09-29 20:58:32 · 149 阅读 · 0 评论 -
常见组件详解(六):torch.nn.AvgPool2d()、torch.nn.AdaptiveAvgPool2d()
用于在由多个平面组成的输入信号上应用二维平均池化操作,输入尺寸为NCinHWNCinHW,输出尺寸为NCinHoutWoutNCinHoutWout。输出特征图形状为:其中,HinWinHinWin分别表输入图像的高度与宽度,默认向下取整。参数名功能池化核大小,整数或元组。stride步幅,默认等于kernel_size,整数或元组。padding填充,整数或元组。原创 2024-09-28 10:39:10 · 881 阅读 · 0 评论 -
常用组件详解(五):torch.nn.BatchNorm2d()
在卷积神经网络的卷积层之后通常会添加进行数据的归一化处理,将数据规范到均值为0,方差为一的分布上,使得数据在进行Relu时不会因为数据过大而导致网络性能的不稳定。原创 2024-09-25 19:05:15 · 1047 阅读 · 0 评论 -
Pytorch实战(二):VGG神经网络
本案例中使用FashionMNIST数据集,所以输入通道数为1nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),nn.ReLU(),return x。原创 2024-07-04 09:40:43 · 1213 阅读 · 1 评论 -
Pytorch实战(一):LeNet神经网络
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档。原创 2024-06-26 09:57:38 · 1156 阅读 · 2 评论 -
常用组件详解(四):torch.nn.MaxPool2d
参数名功能池化核的大小,可以是单值或元组stride池化操作的步长,可以是单值或元组,默认与kernel_size相等padding对原图像的填充大小,可以是单值或元组dilation池化核单元的距离,可用于实现空洞池化为 True 时,返回的是包含池化结果,以及元素对应下标的张量元组。ceil_mode为True时,使用ceil模式(2.31取2,向下取整,舍去小数,保留边缘剩余数据);原创 2024-09-24 14:49:05 · 880 阅读 · 0 评论 -
常用组件详解(二):torchsummary
库是一个好用的模型可视化工具,用于帮助开发者把握每个网络层级的细节,包括其中的连接和维度。表示输入数据的大小。约等于0.19MB。原创 2024-06-30 00:00:42 · 1280 阅读 · 0 评论 -
常用组件详解(二):torch.nn.Flatten、torch.flatten()、torch.Tensor.view()
是Pytorch提供的类,常用于将输入数据进行展平,而函数与之功能相同。原创 2024-06-26 16:54:04 · 1008 阅读 · 0 评论 -
常用组件详解(一):nn.Conv2d、nn.functional.conv2d()
点击跳转。类对象以NCinHWNCinHW作为输入图像数据,以NCoutHoutWoutNCoutHoutWout作为输出图像数据(仅批量大小NNN保持不变)。NNN:批量大小。CCC:通道数。HHH:图像高度。WWW:图像宽度。原创 2024-06-26 14:10:25 · 1271 阅读 · 0 评论 -
Pytorch基础
在Pytorch中可以创建自定义的数据集来加载数据,创建时需继承下的Dataset类,并实现__init____len__和方法。import os# 传入标签目录、数据目录、transform、target_transform# 返回样本个数# 返回指定索引处样本数据及标签#获取样本数据的路径#读取数据文件#读取对应标签#查看是否进行在Pytorch中可以创建自定义的数据集来加载数据,创建时需继承Dataset类,并实现__init____len__和方法。__init__:在实例化。原创 2024-06-21 17:55:27 · 778 阅读 · 0 评论 -
深度学习基础
深度学习入门原创 2024-06-16 20:51:21 · 1206 阅读 · 2 评论