1、当使用如下代码,想保存一个tensor的一部分时
large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small, 'small.pt')
loaded_small = torch.load('small.pt')
loaded_small.storage().size()
# 999
最后保存的结果却不是,large[0:5]
,而是整个large
,
想要解决这个问题,需要加一个使用tensor
的clone
函数。
small.clone()
完整保存代码
large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small.clone(), 'small.pt') # saves a clone of small
loaded_small = torch.load('small.pt')
loaded_small.storage().size()
# 5