Dataset的用法
Dataset提供一种方式去获取数据及其label
如何获取每一个数据及其label
告诉我们总共有多少的数据
DataLoader 为后面的网络提供不同的数据形式
PIL 中有 Image 通过 Image.open(file_name) 访问图片的信息
os 中有 os.listdir(file_bunch_name) 得到图片路径列表
import os
from PIL import Image
from torch.utils.data import Dataset
class MuData(Dataset):
    """Dataset over the files of one folder, labelled with the folder's name.

    Every entry found in ``root_dir/label_dir`` is treated as one sample, and
    ``label_dir`` itself is returned as the label of each sample.

    Args:
        root_dir: Directory that contains the label folder.
        label_dir: Name of the folder holding the images; also used as the
            label for every sample.
    """

    def __init__(self, root_dir: str, label_dir: str):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and platforms.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx: int):
        """Open the file at position ``idx`` and return ``(img, label)``."""
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self) -> int:
        """Return the number of entries listed in the label folder."""
        return len(self.img_path)
# Demo: build a dataset from a single label folder and visit every sample.
root_dir = "C:\\Users\\16499\\Pictures"
label_dir = "Screenshots"
img_dataset = MuData(root_dir, label_dir)
# Indices start at 0 so the first sample is visited as well.
for i in range(len(img_dataset)):
    # Despite its name, temp_path receives the first element of the
    # (img, label) pair returned by MuData.__getitem__.
    temp_path, temp_label = img_dataset[i]
from torch.utils.data import DataLoader
#dataset 的用法
# import torchvision
from torch.utils.tensorboard import SummaryWriter
# dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# train_set = torchvision.datasets.CIFAR10(root="./dataset",train = True,transform=dataset_transform,download = True)
# test_set = torchvision.datasets.CIFAR10(root="./dataset",train = False,transform=dataset_transform,download = True)
# # print(test_set[0])
# # print(test_set.classes)
# #
# # img, target = test_set[0]
# # print(img)
# # print(test_set.classes[target])
# # img.show()
# Writer whose event files go to the "logs" directory.
writer = SummaryWriter("logs")
# for i in range(10):
#     img,target = test_set[i]
#     writer.add_image("test_set",img,i)

# DataLoader usage notes:
#   batch_size=2  -> each batch draws two samples (like taking two cards at a time)
#   shuffle       -> True: the sample order differs each pass; False: the order stays the same
#   num_workers=0 -> load in the main process; values > 0 can cause problems on Windows
#   drop_last     -> what to do with a final incomplete batch: True discards it, False keeps it
import torchvision

test_data = torchvision.datasets.CIFAR10(
    root="./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
# A single sample is the (img, target) pair returned by the dataset's
# __getitem__; the loader below yields such samples grouped into batches.
img, target = test_data[0]

# Log every batch of each epoch under its own tag. The finally clause makes
# sure the writer is closed even if loading or logging fails midway.
try:
    for epoch in range(2):
        for step, data in enumerate(test_loader):
            imgs, target = data
            # print(img.shape)
            # print(target)
            writer.add_images("Epoch:{}".format(epoch), imgs, step)
finally:
    writer.close()
Transform的常用用法
from fileinput import close
from idlelib.pyparse import trans
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img = Image.open("train_set_1/train/ants_image/0013035.jpg")

# Notes on transforms:
#   - a transform object is applied by calling it (its __call__), which
#     returns the transformed output
#   - ToTensor converts a PIL image or a numpy.ndarray into a tensor
writer = SummaryWriter("logs")
try:
    trans_to_tensor = transforms.ToTensor()
    img_tensor = trans_to_tensor(img)
    # writer.add_image("ToTensor",img_tensor)  # img_tensor (torch.Tensor, numpy.ndarray, or string): Image data

    # Normalize: output[channel] = (input[channel] - mean[channel]) / std[channel]
    trans_norm = transforms.Normalize([6, 3, 5], [3, 2, 1])
    img_norm = trans_norm(img_tensor)
    writer.add_image("Normalize", img_norm, 1)

    # Resize is fed the PIL image here.
    # NOTE(review): the original note says Resize requires a PIL input;
    # confirm against the installed torchvision version.
    print(img.size)
    trans_resize = transforms.Resize((512, 512))
    # img (PIL) -> Resize -> img_resize (PIL)
    img_resize = trans_resize(img)
    # img_resize (PIL) -> ToTensor -> img_resize (tensor)
    img_resize = trans_to_tensor(img_resize)
    writer.add_image("Resize", img_resize, 1)

    # Compose: chain a second Resize (single int argument) with ToTensor.
    trans_resize_2 = transforms.Resize(512)
    trans_compose = transforms.Compose([trans_resize_2, trans_to_tensor])
    # trans_compose = transforms.Compose([trans_to_tensor,trans_resize_2])  # reported to raise an error
    # Compose() usage:
    #   - takes a list
    #   - every element must be a transform
    #   - each transform's output type must match the next transform's input
    #   so the call looks like Compose([transforms_1, transforms_2, ...])
    img_resize_2 = trans_compose(img)
    writer.add_image("Rize", img_resize_2, 3)

    # RandomCrop: log ten independent random crops of the same image.
    trans_random = transforms.RandomCrop(512)
    trans_compose_2 = transforms.Compose([trans_random, trans_to_tensor])
    for i in range(10):
        img_crop = trans_compose_2(img)
        writer.add_image("test", img_crop, i)
finally:
    # Close the writer so all pending events are written out, even when one
    # of the steps above fails.
    writer.close()
"""
总结:关注输入和输出,多看官方文档
关注需要什么参数
不知道返回值的时候print 或者 print(type())
"""