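After creating a tensor, call .to(device) on it immediately. PyTorch requires all operands of an operation to be on the same device, so the target, index, and src passed to scatter_ must all live on device: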
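import torch

# Assumes logits has shape (N, C) and labels is an int64 tensor of class indices, shape (N, k).
# Builds a multi-hot target of shape (N, C): 1.0 at each index in labels, 0.0 elsewhere.
# The right-hand side is evaluated first, so it still reads the original integer labels.
labels = torch.zeros(
    logits.shape[0],
    logits.shape[1]
).to(device).scatter_(
    dim=1,
    index=labels.to(device),  # the index must be on the same device as the target
    src=torch.ones(
        labels.shape[0],
        labels.shape[1]
    ).to(device)  # src must match the target's device (and dtype, float32 here)
)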
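A minimal, self-contained sketch of the same multi-hot construction, assuming labels is an (N, k) tensor of integer class indices and logits is (N, C). The helper name multi_hot is hypothetical. Instead of creating each tensor on the CPU and then calling .to(device), it passes device= to the factory functions, which allocates directly on the target device and avoids the extra copy:

```python
import torch


def multi_hot(labels: torch.Tensor, num_classes: int, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: (N, num_classes) float32 tensor with 1.0 at each index in labels."""
    # Allocate the target directly on the device rather than creating it on the CPU first.
    target = torch.zeros(labels.shape[0], num_classes, device=device)
    # scatter_ needs an int64 index on the same device as the target.
    index = labels.to(device=device, dtype=torch.long)
    # src has the same shape as index and the same device and dtype (float32) as the target.
    src = torch.ones(index.shape, device=device)
    return target.scatter_(dim=1, index=index, src=src)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logits = torch.randn(2, 5, device=device)
labels = torch.tensor([[0, 3], [4, 4]])  # a repeated index simply writes 1.0 twice
target = multi_hot(labels, logits.shape[1], device)
print(target)
```

When every row has exactly one label, torch.nn.functional.one_hot(labels, num_classes) is an alternative; it returns an integer tensor, so a .float() call is usually needed before using it with float logits.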