Python-pytorch
吃葡萄不吐葡萄皮~
这个作者很懒,什么都没留下…
展开
-
使用GPU训练模型
方法一在模型、损失函数和数据中使用cudamodel1 = model()if torch.cuda.is_available(): model1 = model1.cuda()bloss = nn.CrossEntropyLoss()if torch.cuda.is_available(): bloss = bloss.cuda()for data in train_dataloader: imgs, targets = data if to原创 2022-01-27 23:19:28 · 2209 阅读 · 0 评论 -
完整的模型训练讨论
模型训练套路——以cifar10为例1.加载数据集dataset_transform_compose = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform_compose, download=True)test_set = torchvision.datasets.C原创 2022-01-27 22:56:12 · 428 阅读 · 0 评论 -
Pytorch——模型的保存与加载
1. 方法一(保存模型及参数)vgg16_false = torchvision.models.vgg16(pretrained=False)##加载初始模型,模型参数还没有训练torch.save(vgg16_false, "method1.pth")torch.load("method1.pth") 2. 方法二(仅保存模型参数)torch.save(vgg16_false.state_dict(), "method2.pth")torch.load("method2.pth")注:当原创 2022-01-25 01:51:22 · 1286 阅读 · 0 评论 -
Pytorch——加载并修改已有模型
1. 加载已有模型——以VGG为例vgg16_true = torchvision.models.vgg16(pretrained=True)##加载参数已经训练好的模型vgg16_false = torchvision.models.vgg16(pretrained=False)##加载初始模型,模型参数还没有训练2 . 修改模型2.1 在模型的基础上增加网络结构vgg16_false.add_module("linear", nn.Linear(1000, 10))print(vgg16_原创 2022-01-25 01:36:00 · 692 阅读 · 0 评论 -
Pytorch——优化器
优化器的使用1.创建优化器2.循环内梯度清零3.计算梯度4.根据梯度优化loss例子:dataset_transform_compose = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform_compose, download=True)test_set = t原创 2022-01-25 01:01:51 · 179 阅读 · 0 评论 -
Pytorch——损失函数
1.常见损失函数1.1 L1loss例子:x = torch.Tensor([1, 2, 3])y = torch.Tensor([1, 0, 6])L1loss = nn.L1Loss()loss1 = L1loss(x, y)print(loss1)输出:tensor(1.6667)1.2 MSELoss例子:x = torch.Tensor([1, 2, 3])y = torch.Tensor([1, 0, 6])mseloss = nn.MSELoss()los原创 2022-01-25 00:40:41 · 376 阅读 · 0 评论 -
Pytorch——深度神经网络
1.卷积层(以Conv2d为例)注:输入的参数包括batchsize,通道数,图像宽度和高度,对于一般的图像通常只有宽度和高度两个参数,所以可以使用reshape函数改变尺寸。例1:以张量为例input = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9] ])print(input.shape)input = torch.原创 2022-01-24 23:59:31 · 1161 阅读 · 0 评论 -
Pytorch数据集使用
写在最前:多看pytorch官方文档1.官方文档中给出了数据的参数及使用方法2.实战dataset_transform_compose = transforms.Compose([transforms.ToTensor()])#下载数据集train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform_compose, download=True)test_set原创 2022-01-24 02:20:49 · 886 阅读 · 0 评论 -
Transform使用方法
1.整体框架例子:(以transforms.ToTensor()为例)from torchvision import transformsimage_path = "IMG_1098.JPG"image = Image.open(image_path)trans_to_tensor = transforms.ToTensor()image_tensor = trans_to_tensor(image)print(image_tensor)2.常见的transform(1)Normal原创 2022-01-24 00:03:52 · 2725 阅读 · 0 评论