
pytorch
风会记得一朵花的香:)
这个作者很懒,什么都没留下…
展开
-
神经网络——利用GPU训练
有两种方式可以实现目录方式1(.cuda)方式2(.to(device))方式1(.cuda)将网络模型、数据(输入、标注)和损失函数引入cuda()网络模型: if torch.cuda.is_available(): test = test.cuda() # 使用GPU损失函数:if torch.cuda.is_available(): loss_fn = loss_fn.cuda()数据:(训练数据和测试数据)imgs, targets = dataif to原创 2022-01-27 22:32:40 · 2050 阅读 · 0 评论 -
神经网络——完整的模型训练套路
采用CIFAR10数据集步骤:1. 准备数据集2. 运用dataloader下载数据3. 搭建神经网络4. 创建损失函数5. 创建优化器6. 设置训练轮数实例:# 神经网络——完整的模型训练套路import torchimport torchvisionfrom torch import nnfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterfrom原创 2022-01-27 21:20:16 · 1166 阅读 · 0 评论 -
神经网络——模型的保存,模型的加载
目录模型的保存(torch.save)方式1(模型结构+模型参数)方式2(模型参数)模型的加载(torch.load)对应保存方式1对应保存方式2方式1存储,加载时需注意事项模型的保存(torch.save)方式1(模型结构+模型参数)参数:保存位置# 创建模型vgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1——模型结构+模型参数torch.save(vgg16, "vgg16_method1.pth")方式2(模型参数)原创 2022-01-27 12:30:55 · 2649 阅读 · 0 评论 -
现有网络模型的使用及修改
在现有网络上进行修改假设已有vgg16网络。网络结构:添加一个线性层:vgg16_True.add_module('add_linear', nn.Linear(1000, 10))网络结构:修改vgg16_False.classifier[6] = nn.Linear(4096, 10)表示修改vgg16_False中classifier中的第7层(从0开始)原始:修改后:...原创 2022-01-27 11:31:38 · 730 阅读 · 0 评论 -
神经网络——损失函数、反向传播与优化器
lossloss越小越好计算实际输出和目标之间的差距为我们更新输出提供一定的依据(反向传播)调用torch中已有损失函数:result_loss = loss(output, target)backward反向传播:计算每一个参数的梯度result_loss.backward()优化器...原创 2022-01-27 09:01:04 · 528 阅读 · 0 评论 -
神经网络——非线性激活
这一部分比较简单ReLU公式:图像对应参数和输入输出的要求:sigmoid公式:图像:输入输出要求:原创 2022-01-26 10:55:56 · 189 阅读 · 0 评论 -
神经网络——最大池化的使用
MaxPool2d图1. 参数stride:默认为池化核大小ceil_mode:为True则保留原创 2022-01-26 09:49:16 · 447 阅读 · 0 评论 -
神经网络——卷积层(conv2d)
in_channels:输入通道个数。例如彩色图片有3个通道out_channels:输出通道个数。out_channals=2时,会生成2个卷积核与输入图像进行卷积。kernel_size:卷积核大小。padding_mode:padding填充模式。groups:一般设置成1bias:偏置。一般设置为Trueimport torchimport torchvisionfrom torch.nn import Conv2dfrom torch.utils.data import D..原创 2022-01-25 12:28:09 · 1893 阅读 · 0 评论 -
pytorch中的dataset和datalaoder
1. 使用Dataset数据集import torchvision使用方法:# 1. 如何使用dataset标准数据集train_set = torchvision.datasets.CIFAR10(root="./dataset2", train=True, transform=dataset_transform, download=True)test_set = torchvision.datasets.CIFAR10(root="./dataset2", train=False, tran原创 2022-01-24 17:45:21 · 703 阅读 · 0 评论