pytorch
小驰驰呕吼**
热爱计算机的小可爱一枚
展开
-
完整的模型验证套路
import torchimport torchvisionfrom PIL import Imagefrom torch import nnimage_path = "./imgs/dog.jpeg"image = Image.open(image_path)# 图片保留三通道image = image.convert('RGB')transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32原创 2022-03-01 16:51:15 · 156 阅读 · 0 评论 -
使用GPU训练
GPU使用1import torchimport torchvision.datasetsfrom torch.utils.tensorboard import SummaryWriterimport time# 准备数据集from torch import nnfrom torch.utils.data import DataLoaderimport nn_seqtrain_data = torchvision.datasets.CIFAR10(root="../dataset"原创 2022-03-01 15:40:55 · 264 阅读 · 0 评论 -
完整的模型训练套路
import torchimport torchvision.datasetsfrom torch.utils.tensorboard import SummaryWriterfrom MyModule import *# 准备数据集from torch import nnfrom torch.utils.data import DataLoaderimport nn_seqtrain_data = torchvision.datasets.CIFAR10(root="./datas原创 2022-03-01 11:53:15 · 213 阅读 · 0 评论 -
网络模型的保存与读取
保存import torchimport torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1 保存模型结构 + 模型参数torch.save(vgg16, "vgg16_method1.pth")# 保存方式2 保存模型参数(官方推荐)torch.save(vgg16.state_dict(), "vgg16_method2.pth")加载import torchimport torchvi原创 2022-02-28 22:17:27 · 144 阅读 · 0 评论 -
现有网络模型的使用和修改
import torchvisionfrom torch import nnvgg_false = torchvision.models.vgg16(pretrained=False)# 参数使用在别的数据集上训练好的初始化vgg_true = torchvision.models.vgg16(pretrained=True)# 在现有的网络结构进行一些修改,使它适用于自己的网络结构train_data = torchvision.datasets.CIFAR10("./dataset",原创 2022-02-28 19:12:25 · 140 阅读 · 0 评论 -
优化器的使用
import torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2d, MaxPool2d, ReLU, Linear, Flattenfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("./dataset原创 2022-02-28 18:39:21 · 78 阅读 · 0 评论 -
损失函数和反向传播
import torchfrom torch.nn import L1Loss, MSELoss, CrossEntropyLossinputs = torch.tensor([1, 2, 3], dtype=torch.float32)targets = torch.tensor([1, 2, 5], dtype=torch.float32)inputs = torch.reshape(inputs, (1, 1, 1, 3))targets = torch.reshape(targets,原创 2022-02-28 16:42:28 · 95 阅读 · 0 评论 -
使用Sequential
不使用Sequentialimport torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2d, MaxPool2d, ReLU, Linear, Flattenfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterclass MyModule(nn.Module): de原创 2022-02-28 10:55:47 · 175 阅读 · 0 评论 -
神经网络线性层
import torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2d, MaxPool2d, ReLU, Linearfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("./dataset", train=原创 2022-02-28 10:23:08 · 572 阅读 · 0 评论 -
神经网络非线性激活层
import torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2d, MaxPool2d, ReLUfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterinput = torch.tensor([[1, -0.5], [-1, 3]])in原创 2022-02-27 14:20:03 · 83 阅读 · 0 评论 -
最大池化层的使用
import torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2d, MaxPool2dfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../dataset", train=False, transf原创 2022-02-24 16:13:30 · 109 阅读 · 0 评论 -
神经网络卷机层
import torchfrom torch import nnimport torchvisionfrom torch.nn import Conv2dfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvisi原创 2022-02-24 15:36:27 · 578 阅读 · 0 评论 -
神经网络基本骨架-nn.Module的使用
import torchfrom torch import nnclass MyModule(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, input): output = input + 1 return outputmoudle = MyModule()x = torch.tensor(1.0)o原创 2022-02-21 22:22:38 · 407 阅读 · 0 评论 -
DataLoader的使用
import torchvisionfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWriter# 准备测试集test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())# batch_size为批量大小(即每次取四张图片原创 2022-02-21 21:16:50 · 212 阅读 · 0 评论 -
torchvision数据集的使用
import torchvisionfrom torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_tr原创 2022-02-21 20:39:47 · 101 阅读 · 0 评论 -
transform的使用2
from PIL import Imagefrom torchvision import transformsfrom torch.utils.tensorboard import SummaryWriterimg = Image.open("/Users/computer/Documents/Code/pytorchLearning/imgs/five.png")trans_toTensor = transforms.ToTensor()tran_img = trans_toTensor(im原创 2022-02-20 15:28:32 · 1877 阅读 · 0 评论 -
transform的使用1
from PIL import Imagefrom torchvision import transformsfrom torch.utils.tensorboard import SummaryWriterimport cv2writer = SummaryWriter("logs")# 1. transforms如何使用img_path = "dataset/train/ants/0013035.jpg"img = Image.open(img_path)# 创建对象tensor原创 2022-02-14 18:19:59 · 667 阅读 · 0 评论 -
tensorboard使用2
from torch.utils.tensorboard import SummaryWriterfrom PIL import Imageimport numpy as npwriter = SummaryWriter("logs")img_path = "dataset/train/ants/0013035.jpg"img_PIL = Image.open(img_path)img_array = np.array(img_PIL)print(type(img_array))# 输出原创 2022-02-14 16:25:50 · 1995 阅读 · 0 评论 -
TensorBoard使用1
代码from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")for i in range(100): writer.add_scalar("y=x", i , i)writer.close()运行后生成一个文件运行图片中命令点击网址即可查看参考课程视频地址:https://www.bilibili.com/video/BV1hE411t7RN?p=8&spm_id原创 2022-02-12 14:11:10 · 157 阅读 · 0 评论 -
读取数据集
读取数据数据目录结构代码from torch.utils.data import Datasetimport osfrom PIL import Image# 使用pytorch对数据集进行操作需要继承Dataset类,重写相关方法class MyData(Dataset): # 初始化,root_dir为跟路径,data_dir为跟路径下数据集的名字,该名字为该数据集下所有图片的标签 def __init__(self, root_dir, data_dir):原创 2022-02-12 12:36:34 · 974 阅读 · 0 评论