torch数据类型的item方法是得到只有一个元素张量里面的元素值。
如下:
>>> x = torch.tensor(4)
>>> x.item()
4
如果对包含多个元素的torch.tensor用item()方法,则会报错如下:
>>> x = torch.tensor([1,2,3,4])
>>> x.item()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: only one element tensors can be converted to Python scalars
只要是只有一个元素,不论维度如何,都可以用item()方法,如下:
>>> x = torch.tensor([4])
>>> x.item()
4
>>> x = torch.tensor([[4]])
>>> x.item()
4