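Code:
import torch

class_num = 10   # number of classes; labels end up as integers in [0, class_num)
batch_size = 4   # not used in this snippet

# Uninitialized float tensor of shape (2, 3, 4, 5); random_() overwrites it with
# random integer values, and % class_num maps them into [0, class_num)
label = torch.Tensor(2, 3, 4, 5).random_() % class_num
print(label.size())

# New dimension order is (old dim 1, old dim 0, old dim 2, old dim 3),
# i.e. the first two dimensions are swapped
label = label.permute(1, 0, 2, 3)
print(label.size())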
Output:
torch.Size([2, 3, 4, 5])
torch.Size([3, 2, 4, 5])
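A minimal sketch of what permute(1, 0, 2, 3) does beyond the shape change, assuming a recent PyTorch install; the variable names permuted and flat are illustrative. It checks the element mapping, the equivalence with transpose(0, 1) when only two dimensions are swapped, and the fact that permute returns a non-contiguous view of the same storage.

```python
import torch

class_num = 10
label = torch.Tensor(2, 3, 4, 5).random_() % class_num
permuted = label.permute(1, 0, 2, 3)

# permute reorders dims: new dim 0 is old dim 1, new dim 1 is old dim 0
print(permuted.shape)  # torch.Size([3, 2, 4, 5])

# Element mapping: permuted[i, j, k, l] == label[j, i, k, l]
print(torch.equal(permuted[2, 1], label[1, 2]))  # True

# Swapping only two dims gives the same result as transpose(0, 1)
print(torch.equal(permuted, label.transpose(0, 1)))  # True

# permute returns a view: it shares storage with label and is not contiguous
print(permuted.data_ptr() == label.data_ptr())  # True
print(permuted.is_contiguous())  # False

# This particular permutation is its own inverse, so applying it again restores label
print(torch.equal(permuted.permute(1, 0, 2, 3), label))  # True

# .view() needs a compatible memory layout; call .contiguous() first (or use .reshape())
flat = permuted.contiguous().view(3, -1)
print(flat.shape)  # torch.Size([3, 40])
```

Because the result is a view, modifying permuted in place also modifies label; call .clone() if an independent copy is needed.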