10.16Pytorch学习日志

本文介绍了PyTorch中torch.utils.data.Dataset和DataLoader的使用,它们简化了自定义数据集的读取和处理。Dataset是抽象类,通过定义__len__和__getitem__方法实现定制数据。DataLoader则用于按batch读取数据,支持shuffle和多线程。此外,还提到了torch.gather函数的作用,以及torch.log的用法。最后,提到了在PyTorch中实现softmax的技巧,并祝作者的朋友欣欣10.16生日快乐。
摘要由CSDN通过智能技术生成

torch.utils.data.Dataset
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。
通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
总之,通过torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷

gather 函数
gather(input(tensor),dim(int),index(LongTensor))
第一个参数:输入的张量
第二个参数:指定操作的维度,0为列, 1为行
第三个参数:需要获取维度的索引位置
收集输入的特定维度指定位置的数值
ex:x = [[1,2,3]
[4,5,6]]
y = torch.gather(x,1,[0,2]) (按行获取x中, 第一个数, 第三个数)
y = [ [1],
[6] ]

torch.log():
torch.log是以自然数e为底的对数函数
ex:
torch.log(torch.Tensor([1]) = 0
log e ^ 1 = 0 (e的0次方=1)

torch.log(torch.Tensor([2.714…]) = 0.999
log e ^2.714 = 0.99 (e的0.99次方=2.714…)

x.shape[0] :
x的第一维度那个数
ex:
x=(2,3,4)
x.shape[0]=2, [2]=3, [2]=4

raise NotImplementedError:
一般是继承的函数出错,语法,缩进错误


softmaxPytorch搭法

import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

sys.path.append("..")  # 为了导⼊上层⽬录的d2lzh_pytorch
import d2lzh_pytorch as d2l

# mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
#                                                 train=True, download=False, transform=transforms.ToTensor())
# mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
#                                                train=False, download=True, transform=transforms.ToTensor())

# print(type(mnist_train))
# print(len(mnist_train), len(mnist_test))


# feature, label = mnist_train[0]
# print(feature.shape, label)  # Channel x Height X Width

# 显示前十个数据
# X, y = [], []
# for i in range(10):
#     X.append(mnist_train[i][0])
#     y.append(mnist_train[i][1])
# d2l.show_fashion_mnist(X, d2l.get_fashion_mnist_labels(y))

batch_size = 256
# if sys.platform.startswith('win'):
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值