1python学习中两法宝函数
dir(torch)
dir(torch.cuda)
dir(torch.cuda.is_available)
输出双下划线的变量,表示私有属性,你无法修改它
help(torch.cuda.is_available)
返回这个函数怎么用
is_available() -> bool
Returns a bool indicating if CUDA is currently available.
2pytorch中加载数据
两个类 Dataset,DataLoader
Dataset
自定义数据集需要用到它
from torch.utils.data import Dataset
command+点击Dataset进入
可以看到Dataset是个abstract class,所有的数据集必须 subclass(继承)这个类。所有的子类应该重写(overwrite)__getitem__这个方法。
3transform
transforms相当于一个工具箱,里面有很多工具。
小tips:
点结构可以看整个文件的所有类
Totensor
用PIL读取为PIL格式,用opencv读取是ndarray。
from PIL import Image
from torchvision import transforms
img_path = "/Users/apple/PycharmProjects/learn_pytorch_tudui/hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)
trans = transforms.ToTensor()# 相当于执行ToTensor类的__call__方法
tensor_img = trans(img)
print(tensor_img)
Normalize