pyTorch框架
主要介绍pyTorch框架的使用
半臻(火白)
技术栈:Python、Java、大数据、人工智能
展开
-
PyTorch基础5——自定义损失函数
自定义损失函数自定义损失函数与自定义网络类似。需要继承nn.Module类,然后重写forward方法即可# 自定义损失函数,交叉熵损失函数class MyEntropyLoss(nn.Module): def forward(self,output,target): batch_size_ = output.size()[0] # 获得batch_size num_class = output[0].size()[0] #获得类别数量 lab原创 2021-07-13 16:57:36 · 1531 阅读 · 1 评论 -
PyTorch基础4——加载模型权重
加载模型权重包括1 加载完全的模型权重2 加载某一层的模型权重3 根据tensor形状加载模型权重from torch import nnimport torch# 定义一个网络class Model(nn.Module): def __init__(self,class_num,input_channel=3): super(Model, self).__init__() self.conv1 = nn.Conv2d(in_channels=in原创 2021-07-13 11:12:13 · 3888 阅读 · 2 评论 -
pyTorch基础3——自定义数据集
自定义数据集的步骤定义一个类,并继承 torch.utils.data.Dataset在__init__(构造方法中) 写需要读取的所有数据和标签,如果是图片可以写所有的图片路径在__len__ 方法中定义数据集的总长度在__getitem__ 中写每次循环时调用的方法,index表示当前循环的下标将定义好的类,放入torch.utils.data.DataLoader之中,设置batchsize等信息用迭代器取出每一个数据# 数据处理import torchfrom torch.原创 2021-07-13 11:07:30 · 354 阅读 · 0 评论 -
PyTorch基础2——快速开始
快速开始接下来会用三个例子,展示用PyTorch训练一个网络的基本代码import torchfrom torch import nnimport random# step1:加载数据# 随机生成100张图片和标签# 每张图片是32*32的,并且有三个通道# 随机生成5个标签(类别)img_list = []label_list = []for i in range(100): img = torch.rand(3,32,32) label = random.ran原创 2021-07-12 20:38:52 · 289 阅读 · 0 评论 -
PyTorch基础1——张量
参考:《20天吃掉那只pytorch》4、张量 张量是pytroch中最重要的数据类型,神经网络中操作的数据都是张量。输入的图片是一个张量,中间的隐藏层也是张量,最后输出的结果也是张量。 所以懂得张量的基本操作就成了pytroch的基本功。 张量是一个多维数组,维度可以从0到n 如果维度为0那么就是一个常数,如果维度为1那么就是一个向量,如果维度为2那么就是一个矩阵,如果维度为3就是一个立方体,如果维度为4 …4.0 张量的介绍 Pytorch的基本数据结构是张量。 张量即多维数组。原创 2021-07-11 20:20:32 · 809 阅读 · 2 评论