Pytorch-5 : 视觉深度学习prepare data - ETL流程,Dataset 和 DataLoader类

ETL

  • Extract : 获得图片数据
  • Transfrom :将图片转换为tensor形式
  • Load : 将数据转为对象,以便于使用

为了实现以上流程,PyTroch提供这两个类:

  • torch.utils.data.Dataset : 抽象类(abstract class),代表dataset
  • torch.utils.data.DataLoader : 包装dataset,从而方便诸如 batch, shuffle 等操作

在继承抽象类时,必须定义某些函数。

class OHLC(Dataset):
	def __init__(self, csv_file):
		self.data = pd.read_csv(csv_file)
	def __getitem__(self, index):##MUST
		r = self.data.iloc[index]
		label = torch.tensor(r.is_up_day, dtype=torch.long)
		sample = self.normalize(torch.tensor([r.open, r.high, r.low, r.close]))
		return sample, label
	def __len__(self):##MUST
		return len(self.data)

torchvision

使用Pytorch做视觉深度学习的话,这个包必不可少。torchvision的内容:

  • Dataset:包含如MNIST的数据集
  • Models:包含VGG-16等
  • Transforms
  • Utils

接下来用FashionMNIST做一个ETL的sample:

import torch
import torchvision
import torchvision.transforms as transforms

#ET
train_set = torchvision.datasets.FashionMNIST(
	root = './data/FashionMNIST'
	,train=True							#标记为训练集
	,download=True						#注意,不要下载已有的数据
	,transform=transforms.Compose([		#转为tensor
		transform.ToTensor()
	])
)
#L
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

看一下数据集内容

先从数据集整体来看:

len(train_set)						#->	60000

train_set.train_labels				#-> tensor([..lables...])

train_set.train_lables.bincount()	#-> tensor([6000, 6000...., 6000])

train_set.train_lables.bincount 返回每个 train_lables 在 train_set 中出现的次数,可以看到,本数据集是均衡的(balanced)。

再从数据集中单个数据来看:

sample = next(iter(train_set))

len(sample)		#-> 2 ,一个是图片,一个是label

type(sample)	#-> tuple

image, label = sample

img.shape		#-> torch.Size([1, 28, 28])

label.shape		#-> torch.Size([])

从每个batch来看:

batch = next(iter(train_loader))##注意此处是trainloader而不是trainset

len(batch)		#-> 2 

type(batch)		#-> list

images, labels = batch

imgs.shape		#-> torch.Size([10, 1, 28, 28])

labels.shape	#-> torch.Size([10])

#展示batch中的图片内容:
grid = torchvision.utils.make_graid(images, nrow=10)#make_grad用于批量展示图片

plt.figure(figsize=(15,15))#设置每张图片显示的大小
plt.imshow(np.transpose(grid,(1,2,0)))#将grid transpose 成 imshow()接收的形式

print('labels:',labels)


Output

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值