本文整理汇总了Python中torchvision.transforms.Compose方法的典型用法代码示例。如果您正苦于以下问题：Python transforms.Compose方法的具体用法？Python transforms.Compose怎么用？Python transforms.Compose使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torchvision.transforms的用法示例。
在下文中一共展示了transforms.Compose方法的26个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: load_data
点赞 8
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def load_data(root_path, dir, batch_size, phase, num_workers=4):
    """Build a shuffled DataLoader over an ImageFolder dataset.

    Args:
        root_path: Dataset root directory. It is joined to ``dir`` by plain
            string concatenation, so it should end with a path separator.
        dir: Sub-directory under ``root_path`` laid out as ``ImageFolder``
            expects (one folder per class).
        batch_size: Number of images per batch.
        phase: ``'src'`` selects the augmenting pipeline (random resized
            crop + horizontal flip); ``'tar'`` selects a deterministic
            resize-only pipeline.
        num_workers: Worker processes for the DataLoader (defaults to 4,
            the value that used to be hard-coded).

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles every epoch and keeps
        the last partial batch.

    Raises:
        KeyError: If ``phase`` is neither ``'src'`` nor ``'tar'``.
    """
    if phase not in ('src', 'tar'):
        # KeyError is what the original dict lookup raised; keep the type so
        # existing callers are unaffected, but say what went wrong.
        raise KeyError("phase must be 'src' or 'tar', got %r" % (phase,))

    # ImageNet channel statistics, used by both pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if phase == 'src':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Resize with an int only rescales the shorter side, so images with
        # different aspect ratios would yield differently-shaped tensors and
        # the default collate_fn could not stack them into a batch. An
        # explicit (h, w) guarantees 224x224 like the 'src' pipeline, and is
        # identical to Resize(224) for square images.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ])

    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True,
                                       drop_last=False, num_workers=num_workers)
开发者ID:jindongwang,项目名称:transferlearning,代码行数:20,
示例2: get_data_loader
点赞 6
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def get_data_loader(opt):
    """Build the sequence DataLoader selected by ``opt.dset_name``.

    Args:
        opt: Options object. Always read: ``dset_name``, ``dset_path``,
            ``is_train``, ``n_frames_input``, ``n_frames_output``,
            ``batch_size``, ``n_workers``. Additionally ``num_objects`` for
            ``'moving_mnist'`` and ``image_size`` for ``'bouncing_balls'``.

    Returns:
        A ``data.DataLoader`` over the selected dataset; it shuffles only
        when ``opt.is_train`` is true and uses pinned memory.

    Raises:
        NotImplementedError: If ``opt.dset_name`` is not a supported dataset.
    """
    name = opt.dset_name
    if name == 'moving_mnist':
        transform = transforms.Compose([vtransforms.ToTensor()])
        dset = MovingMNIST(opt.dset_path, opt.is_train, opt.n_frames_input,
                           opt.n_frames_output, opt.num_objects, transform)
    elif name == 'bouncing_balls':
        transform = transforms.Compose([vtransforms.Scale(opt.image_size),
                                        vtransforms.ToTensor()])
        # Only the first element of image_size is handed to the dataset;
        # presumably image_size is a square (size, size) pair -- confirm.
        dset = BouncingBalls(opt.dset_path, opt.is_train, opt.n_frames_input,
                             opt.n_frames_output, opt.image_size[0], transform)
    else:
        # Name the rejected value so misconfigured runs are easy to diagnose.
        raise NotImplementedError(
            "Unsupported dset_name %r; expected 'moving_mnist' or "
            "'bouncing_balls'" % (name,))
    return data.DataLoader(dset, batch_size=opt.batch_size, shuffle=opt.is_train,
                           num_workers=opt.n_workers, pin_memory=True)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:20,
示例3: transform_for_train
点赞 6
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def transform_for_train(fixed_scale=512, rotate_prob=15):
    """Return the training-time augmentation pipeline.

    Steps, applied in this order: ``RandomSized(fixed_scale)``,
    ``RandomRotate(rotate_prob)``, ``RandomHorizontalFlip()``, ImageNet
    mean/std ``Normalize``, and finally ``ToTensor()``. The step classes are
    referenced without the ``transforms.`` prefix, so they are presumably the
    project's own sample transforms; only ``Compose`` comes from torchvision.

    Args:
        fixed_scale: Size argument forwarded to ``RandomSized``.
        rotate_prob: Argument forwarded to ``RandomRotate``. Despite the name,
            the default of 15 suggests a rotation bound in degrees rather than
            a probability -- TODO confirm against ``RandomRotate``.

    Returns:
        A ``transforms.Compose`` chaining the steps above.

    Note:
        ``Normalize`` runs before ``ToTensor``, so it presumably operates on
        the pre-tensor sample -- confirm against the project's ``Normalize``.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    pipeline = [
        RandomSized(fixed_scale),
        RandomRotate(rotate_prob),
        RandomHorizontalFlip(),
        Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensor(),
    ]
    return transforms.Compose(pipeline)
开发者ID:songdejia,项目名称:DeepLab_v3_plus,代码行数:24,
示例4: __init__
点赞 6
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def __init__(self, config):
    """Set up the data loader described by ``config``.

    Only ``config.data_mode == "imgs"`` is implemented. Images under
    ``config.data_folder`` are read with ``ImageFolder``, converted with
    ``ToTensor`` and normalized with mean 0.5 / std 0.5 on each of the
    three channels.

    Attributes set in ``"imgs"`` mode:
        config: The configuration object passed in.
        dataset_len: Number of samples in the dataset.
        num_iterations: Batches per epoch, counting a final partial batch
            (ceil(dataset_len / batch_size)).
        loader: Shuffling ``DataLoader`` built from ``config.batch_size``,
            ``config.data_loader_workers`` and ``config.pin_memory``.

    Raises:
        NotImplementedError: If ``config.data_mode == "numpy"``.
        ValueError: For any other unsupported ``data_mode``.
    """
    self.config = config
    mode = config.data_mode
    if mode == "numpy":
        raise NotImplementedError("This mode is not implemented YET")
    if mode != "imgs":
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch it; include the offending value for debugging.
        raise ValueError(
            "Please specify in the json a specified mode in data_mode "
            "(got %r; supported: 'imgs')" % (mode,))

    transform = v_transforms.Compose(
        [v_transforms.ToTensor(),
         v_transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    dataset = v_datasets.ImageFolder(self.config.data_folder, transform=transform)
    self.dataset_len = len(dataset)
    # Integer ceiling division: a trailing partial batch still counts.
    self.num_iterations = (self.dataset_len + config.batch_size - 1) // config.batch_size
    self.loader = DataLoader(dataset,
                             batch_size=config.batch_size,
                             shuffle=True,
                             num_workers=config.data_loader_workers,
                             pin_memory=config.pin_memory)
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:25,
示例5: __init__
点赞 6
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def __init__(self, args, train=True):
    """Select the dataset split and build the image/segmentation transforms.

    Args:
        args: Options object. Reads ``data``, ``use_test_for_val``,
            ``read_features``, ``features_dir``, ``image_size`` and
            ``segmentation_size``.
        train: If true, use ``train_set_list``. Otherwise use
            ``test_set_list`` when ``args.use_test_for_val`` is set, else
            ``val_set_list``.

    Attributes set:
        root_dir, args, read_features, features_dir,
        data_set_list: File names formatted as ``'%06d.png'`` from the
            selected split's entries.
        transform: Square resize to ``image_size``, ToTensor, ImageNet
            normalization.
        transform_segmentation: Square resize to ``segmentation_size``,
            then ToTensor.
    """
    self.root_dir = args.data
    if train:
        split_ids = train_set_list
    elif args.use_test_for_val:
        split_ids = test_set_list
    else:
        split_ids = val_set_list
    # '%06d' requires integer-like entries; names are zero-padded to 6 digits.
    self.data_set_list = ['%06d.png' % x for x in split_ids]
    self.args = args
    self.read_features = args.read_features
    self.features_dir = args.features_dir
    # transforms.Scale was only a deprecated alias of transforms.Resize and
    # no longer exists in current torchvision (AttributeError). Resize takes
    # the same (h, w) argument with the same default interpolation.
    self.transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # NOTE(review): Resize interpolates bilinearly by default; if the
    # segmentation targets are discrete label maps, nearest-neighbour
    # interpolation may be intended -- confirm.
    self.transform_segmentation = transforms.Compose([
        transforms.Resize((args.segmentation_size, args.segmentation_size)),
        transforms.ToTensor(),
    ])
开发者ID:ehsanik,项目名称:dogTorch,代码行数:27,
示例6: load_data
点赞 6
# 需要导入模块: from torchvision import transforms [as 别名]
# 或者: from torchvision.transforms import Compose [as 别名]
def load_data(data_folder, batch_size, phase='train', train_val_split=Tru