torch.flatten()函数
A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
out1 = torch.flatten(A)
out2 = torch.flatten(A,0)
out3 = torch.flatten(A,1)
print("默认 = ",out1)
print("参数为0 = ",out2)
print("参数为1 = ",out3)
运行结果
默认 = tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,19, 20, 21, 22, 23, 24])
参数为0 = tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,19, 20, 21, 22, 23, 24])
参数为1 = tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
总结:torch.flatten()函数默认参数为0.
nn.nn.Linear()函数
将flatten的数据,映射