DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
上面的DataLoader加载dataset的代码相当于如下
dataset=UnetDataset(data_path) # 继承自Dataset
dataset[0] # 迭代,接下来调用dataset[1], [2], [3], ...
其中每一次对象名加中括号的方式dataset[0]
是调用一次dataset对象的__getitem__()方法
Transform
from torchvision import transforms
from PIL import Image
transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize()])
transforms(image),transforms(label_image)
transforms专门对image图片进行处理,而上述代码即对image和label_image两个图片对象进行ToTensor()和Normalize()处理。其中使用对象名加小括号的方法transforms(image)
是调用一次transforms对象的call()函数
,上述方法同transforms.call(image)
。
transforms.call(image)
transforms=transforms.Compose([transforms.ToTensor()(image)])
transforms=transforms.Compose([transforms.ToTensor().call(image)])
argparse 用法
argparse主要有两个作用,一个是生成字典,如parser.add_argument("--data-path", default="./", help="DRIVE root")
就是生成key = "--data-path"
, value = "./"
的键值对。
还有一个作用就是作为terminal中的参数传入,如上所示为python test.py --data-path="./"
import argparse
parser = argparse.ArgumentParser(description="pytorch unet training")
parser.add_argument("--data-path", default="./", help="DRIVE root")
# 将parser转换为字典返回
args = parser.parse_args()
return args
数据预处理
你可以在重现Dataset的__getitem__()
方法,或者transforms的列表中添加参数来达到数据域处理的效果。