一、transform
1. transform定义
import torchvision.transforms as transforms
# Evaluation-style pipeline: resize, center-crop to 224x224, convert to tensor.
# `transforms.Scale` was deprecated and later removed from torchvision;
# `transforms.Resize` is its replacement (an int argument resizes the
# shorter image side to that size, keeping the aspect ratio).
# Bound to its own name so the second pipeline below no longer overwrites it.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Training-style augmentation pipeline. The 32x32 crop suggests small inputs
# such as CIFAR-10 (assumption -- adjust the size for other datasets).
# This is the pipeline bound to `transform` and used by the dataset below.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    # Pad 4 pixels on every side, then take a random 32x32 crop.
    transforms.RandomCrop((32, 32), padding=4),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5: maps values in [0, 1] (what ToTensor
    # produces for 8-bit images) to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
2. 使用方法
# Index file read by MyDataset: one tab-separated "<image path>, <label>"
# entry per line.
txt_path = "G:/train/path.txt"
class MyDataset(Dataset):
    """Image dataset backed by a tab-separated index file.

    Each non-blank line of the index file is expected to hold an image path
    and a label separated by a TAB; only the first two TAB-separated fields
    are used. Labels are kept as the raw strings read from the file unless a
    ``target_transform`` (e.g. ``int``) is supplied.

    NOTE: relies on ``Dataset`` (torch.utils.data) and ``Image`` (PIL) being
    imported at module level; neither import is visible in this snippet.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        """Parse the index file into a list of ``(image_path, label)`` pairs.

        Args:
            txt_path: path to the index file.
            transform: optional callable applied to each loaded RGB image.
            target_transform: optional callable applied to each label string.

        Raises:
            ValueError: if a non-blank line has no TAB-separated label.
        """
        imgs = []
        # `with` guarantees the index file is closed (it previously leaked).
        with open(txt_path, 'r') as fh:
            for lineno, line in enumerate(fh, start=1):
                line = line.rstrip('\r\n')
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line)
                    # instead of crashing with IndexError.
                    continue
                fields = line.split('\t')
                if len(fields) < 2:
                    raise ValueError(
                        f"{txt_path}, line {lineno}: expected "
                        f"'<image path><TAB><label>', got {line!r}"
                    )
                imgs.append((fields[0], fields[1]))
        self.imgs = imgs
        self.transform = transform
        # Previously a typo (`arget_transform`) made this a dead local, so the
        # attribute was never set.
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of ``(image_path, label)`` entries."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load entry ``index`` and return ``(img, label)``.

        ``img`` is the image converted to RGB, then passed through
        ``transform`` if one was given. ``label`` is the label string, passed
        through ``target_transform`` if one was given.
        """
        img_path, label = self.imgs[index]
        # The context manager closes the file handle; convert() returns an
        # independent image, so it stays valid after the handle is closed.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # target_transform was accepted but never applied before.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label
# Build the training dataset from the index file, applying the pipeline
# currently bound to `transform` to every loaded image.
train_set = MyDataset(txt_path=txt_path, transform=transform)