torch.utils.data.Dataset
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。
通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
总之,通过torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷
gather 函数:
gather(input(tensor),dim(int),index(LongTensor))
第一个参数:输入的张量
第二个参数:指定操作的维度,0为列, 1为行
第三个参数:需要获取维度的索引位置
收集输入的特定维度指定位置的数值
ex:x = [[1,2,3]
[4,5,6]]
y = torch.gather(x,1,[0,2]) (按行获取x中, 第一个数, 第三个数)
y = [ [1],
[6] ]
torch.log():
torch.log是以自然数e为底的对数函数
ex:
torch.log(torch.Tensor([1]) = 0
log e ^ 1 = 0 (e的0次方=1)
torch.log(torch.Tensor([2.714…]) = 0.999
log e ^2.714 = 0.99 (e的0.99次方=2.714…)
x.shape[0] :
x的第一维度那个数
ex:
x=(2,3,4)
x.shape[0]=2, [2]=3, [2]=4
raise NotImplementedError:
一般是继承的函数出错,语法,缩进错误
softmaxPytorch搭法
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
sys.path.append("..") # 为了导⼊上层⽬录的d2lzh_pytorch
import d2lzh_pytorch as d2l
# mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
# train=True, download=False, transform=transforms.ToTensor())
# mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
# train=False, download=True, transform=transforms.ToTensor())
# print(type(mnist_train))
# print(len(mnist_train), len(mnist_test))
# feature, label = mnist_train[0]
# print(feature.shape, label) # Channel x Height X Width
# 显示前十个数据
# X, y = [], []
# for i in range(10):
# X.append(mnist_train[i][0])
# y.append(mnist_train[i][1])
# d2l.show_fashion_mnist(X, d2l.get_fashion_mnist_labels(y))
batch_size = 256
# if sys.platform.startswith('win'):