Deep Learning with Pytorch 中文简明笔记 第四章 Real-world data representation using tensors

Deep Learning with Pytorch 中文简明笔记 第四章 Real-world data representation using tensors

Pytorch作为深度学习框架的后起之秀,凭借其简单的API和简洁的文档,收到了越来越多人的关注和喜爱。本文主要总结了 Deep Learning with Pytorch 一书第四章[Real-world data representation using tensors]的主要内容,并加以简单明了的解释,作为自己的学习记录,也供大家学习和参考。

主要内容

  • 使用Pytorch的Tensor来表示现实世界的数据
  • 使用混合的数据类型
  • 从文件中读取数据
  • 将数据转换为Tensor
  • 改变Tensor的形状以适应模型

1. 平面图像

在消费级相机中,一般使用8-bit整数所表示的范围(0-255)来刻画像素,但是在科学研究、医疗和工业应用中,12-bit和16-bit的编码也并不罕见。

一般而言,彩色图片分为RGB三个颜色通道,灰度图则只有一个通道。不失一般性,均可以表示为H×W×C,即height×width×channel。使用imageio模块可以时间读取一个图片为numpy array的目的。

# In[2]: 
import imageio
img_arr = imageio.imread('../data/p1ch4/image-dog/bobby.jpg') 
img_arr.shape

# Out[2]: 
(720, 1280, 3)

Pytorch适用于处理C×H×W的图片数据,故使用permute()进行变换,注意此时并没有复制Tensor,使用的是同一份数据。

img = torch.from_numpy(img_arr) 
out = img.permute(2, 0, 1)

对于mini-batch的情况,则将数据叠加为N×C×H×W

# In[4]: 
batch_size = 3 
batch = torch.zeros(batch_size, 3, 256, 256, dtype=torch.uint8)

图片的正则化,需要计算出数据集的均值和方差,然后减去均值除以标准差,如果有多个通道,则需要分别对每个通道单独运算处理。

# In[6]: 
batch = batch.float() 
batch /= 255.0

# In[7]: 
n_channels = batch.shape[1] 
for c in range(n_channels): 
	mean = torch.mean(batch[:, c]) 
	std = torch.std(batch[:, c])
	batch[:, c] = (batch[:, c] - mean) / std

2. 3D图像

对于一些医疗图像(CT扫描)或者深度图像,在N×C×H×W的基础上增加一个维度深度depth,变为N×C×D×H×W。但是医疗图像中一般不会出现彩色图片,一般为灰度图,所以C=1,也并不影响其他的操作。

# In[2]: 
import imageio
dir_path = "../data/p1ch4/volumetric-dicom/2-LUNG 3.0 B70f-04083" 
vol_arr = imageio.volread(dir_path, 'DICOM') 
vol_arr.shape

# Out[2]: 
Reading DICOM (examining files): 1/99 files (1.0%99/99 files (100.0%) 
	Found 1 correct series. 
Reading DICOM (loading data): 31/99 (31.392/99 (92.999/99 (100.0%)
(99, 512, 512)

之后使用unsqueeze在第0维(最前面)添加一个维度,即深度d

# In[3]: 
vol = torch.from_numpy(vol_arr).float() 
vol = torch.unsqueeze(vol, 0)
vol.shape

# Out[3]: 
torch.Size([1, 99, 512, 512])

3. 扁平数据

这里的扁平化数据一般指的是可以用一张二维表来表示的数据。如书中的例子,一些不同葡萄酒的样品的不同指标,和此葡萄酒对应的品级。可以很轻易的在二维表中表示出来。
在这里插入图片描述

通常这些数据以csv(Comma-separated values)文件存储,或者是xls的表格文件存储。书中以csv文件举例,使用numpy读取。

# In[2]: 
import csv wine_path = "../data/p1ch4/tabular-wine/winequality-white.csv" 
wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=";", skiprows=1)
wineq_numpy

# Out[2]: 
array([[ 7. , 0.27, 0.36, ...
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>