原代码：
# Original call (before the fix): no `generator` argument is passed to
# random_split.
# NOTE(review): presumably this is what fails once the default tensor device
# is CUDA (e.g. after torch.set_default_device('cuda')); the exact error is
# not shown in this note, so confirm.
train_dataset, test_dataset = random_split(
dataset, [train_size, test_size]
)
解决方法：
# Fix: pass an explicit CUDA generator to random_split.
# NOTE(review): random_split draws its index permutation using this generator;
# presumably the default tensor device is CUDA here, so the generator's device
# has to match it. Confirm against the actual error message.
# A newly created, unseeded torch.Generator starts from a fixed default seed,
# so this split is the same on every run. Call .manual_seed(...) on the
# generator if you want to choose (or vary) the split explicitly.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator(device='cuda'),
)
DataLoader 部分也需要修改：
# Original DataLoader setup (before the fix): no `generator` argument is
# passed to either loader, although both use shuffle=True.
# NOTE(review): batch_size=test_size presumably loads the whole test split as
# a single batch, and shuffling the test loader is usually unnecessary.
# Confirm both are intended.
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
collate_fn=collate_fn
)
test_loader = DataLoader(
test_dataset, batch_size=test_size, shuffle=True,
collate_fn=collate_fn
)
修改为:
# Fixed DataLoader setup: each loader is given its own CUDA generator, which
# drives the shuffle order requested by shuffle=True.
# NOTE(review): presumably required because the default tensor device is CUDA,
# so the sampler's random permutation must come from a CUDA generator. Confirm.
# The two loaders get separate generator objects. Both start from the same
# default seed unless seeded explicitly.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator(device='cuda'),
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_size,  # presumably the whole test split in one batch; confirm
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator(device='cuda'),
)
方便复制的参数：
# Keyword argument meant to be pasted into the random_split(...) and
# DataLoader(...) calls above. Run as a standalone statement, it only binds
# a module-level name `generator`.
generator=torch.Generator(device='cuda')