Pytorch学习笔记(一)
一、Pytorch加载数据
1、Dataset
- 作用:提供一种方式获取数据和label并编号,可以根据编号提取相应的数据,同时获取数据对应的相应的label
class MyData(Dataset):
def __init__(self, root_dir, label_dir): #为后面的函数提供全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) #获取训练集中每一个图片
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) #获取图片的相对路径
img = Image.open(img_item_path) #读取图片
label = self.label_dir
def __len__(self):
return len(self.img_path)
2、Dataloader
- 作用:对获取的数据进行打包(大小为batch_size),为后面的网络提供不同的数据形式
二、TensorBoard的使用
1、SummaryWriter类的使用
add_scalar
writer = SummaryWriter("logs") #创建实例
for i in range(100): ##实现y=x的图像
writer.add_scalar("y=x", i, i) #里面的参数分别是tag,scalar_value(y轴),global_step(x轴)
writer.close()
运行程序后,此时logs文件中产生一个事件。打开方式:
- 在命令行中输入tensorboard --logdir=logs,打开下面的链接即可看到加载的函数图像 # logdir=事件文件所在文件夹名,可以在后面加–port=xxxx来指定端口名
如果图像出现错误,可以删除事件文件夹中的所有事件重新运行程序,打开新的事件
add_image
参数:tag,img_tensor(数据类型:torch.tensor, numpy_array;数据大小为(3, H, W),如果是(H, W, 3)的话需要改参数dataformats为HWC),global_step, walltime,dataformats
writer = SummaryWriter("logs")
image_path = "data/train/ants_image/0013035.jpg" #输入图片的相对路径
img_PIL = Image.open(image_path) #使用PIL.Image打开的图片类型默认为PIL类型,需要转化为numpy类型或者tensor类型
img_array = np.array(img_PIL)
writer.add_image("test", img_array, 1, dataformats='HWC') #global_step可以相同的tag下以global_step的步数大小展示数据集里的不同图片
writer.close()
运行程序后,此时logs文件中产生一个事件。打开方式:
在命令行中输入tensorboard --logdir=logs,打开下面的链接即可看到加载的图片 # logdir=事件文件所在文件夹名,可以在后面加–port=xxxx来指定端口名
注:如果图像出现错误,可以删除事件文件夹中的所有事件重新运行程序,打开新的事件
三、Transforms的使用
注:__call__()是一个内置函数,调用时可以直接使用对象(参数)的方法调用
1、ToTensor()
将PIL或者是numpy数据类型转换成tensor数据类型
img_path = "data/train/ants_image/0013035.jpg"
img = Image.open(img_path)
tensor_trans = transforms.ToTensor() #创建ToTensor类的对象
tensor_img = tensor_trans(img) #调用ToTensor里的内置函数将img转化为tensor数据类型并返回
Tensor数据类型包装了有关反向神经网络的一些相关参数
读取得到numpy.ndarray的数据使用opencv
cv_img = cv2.imread(img_path)
读取得到PIL的数据使用PIL.Image
PIL_img = Image.open(img_path)
2、Normalize()
注:__call__()是一个内置函数,调用时可以直接使用对象(参数)的方法调用
输入:mean(均值),std(标准差)
归一化:input[channel] = (input[channel] - mean[channel]) / std[channel]
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
3、Resize()
作用:改变图片大小,输入为PIL数据类型
输入:size数据类型为序列或是int
trans_resize = transforms.Resize((512, 512)) #将图片大小变成512*512
img_resize = trans_resize(img) #img_resize是PIL类型,如果需要add_image(),需要先将PIL的数据类型通过ToTensor转换成tensor数据类型
trans_resize2 = transforms.Resize(512) #如果只输入一个数,便用图像的最小的边去匹配这个大小,然后进行等比缩放
img_resize2 = trans_resize2(img)
4、Compost()
作用:将几个transforms的操作组合起来使用
Compost()的参数应该是transforms类型元素的列表。即:
Compose([transforms参数1, transforms参数2,……])
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize_2, trans_totensor]) #将resize与totensor组合起来
img_resize_2 = trans_compost(img) ##图片变成tensor类型的等比缩放到最短边为512的图片
5、RandomCrop()
随机裁剪图片
输入为PIL类型
trans_random = transforms.RandomCrop(512)
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
#随机裁剪十个
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCrop", img_crop, i)
总结
查看官方文档时首先查看输入和输出类型,然后查看方法中没有默认值的参数的作用和数据类型。