def load_videos(
    train_seq_path,
    Batch_size,
    num_samples=4,
    max_frames=2,
    num_workers=8,
    drop_last=False,
):
    """Build a deterministic training DataLoader over (a subset of) a video dataset.

    Clips are loaded with ``VideoFolder`` using a fixed interval and fixed
    temporal order. Each clip is converted to a tensor and center-cropped to
    256x256. The loader does not shuffle, so every pass over it yields the same
    data in the same order.

    Args:
        train_seq_path: Root path forwarded to ``VideoFolder``.
        Batch_size: Number of samples per batch.
        num_samples: Keep only the first ``num_samples`` items of the dataset.
            The value is clamped to the dataset length. ``None`` keeps the whole
            dataset.
        max_frames: Forwarded to ``VideoFolder``. Per the original author's note,
            it bounds the number of frames per item.
        num_workers: Number of DataLoader worker processes.
        drop_last: Forwarded to ``DataLoader``. If true, the final incomplete
            batch is dropped.

    Returns:
        A ``(train_dataloader, total_data)`` tuple. ``total_data`` is the number
        of items in the (possibly truncated) dataset.
    """
    train_transforms = transforms.Compose(
        # CenterCrop (instead of a random crop) keeps the data identical on
        # every iteration over the loader.
        [transforms.ToTensor(), transforms.CenterCrop((256, 256))]
    )
    train_dataset = VideoFolder(
        train_seq_path,
        rnd_interval=False,
        rnd_temp_order=False,
        split="train",
        transform=train_transforms,
        max_frames=max_frames,
    )
    if num_samples is not None:
        # Clamp to the real dataset size. Without this, a dataset with fewer
        # items than requested yields out-of-range indices and an inflated
        # total_data.
        num_samples = min(num_samples, len(train_dataset))
        train_dataset = torch.utils.data.Subset(
            train_dataset, indices=range(num_samples)
        )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=Batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=drop_last,
        # `device` is a module-level name defined outside this function.
        pin_memory=(device == "cuda"),
    )
    total_data = len(train_dataloader.dataset)
    return train_dataloader, total_data
# 08-15