torch.ones_like函数和torch.zeros_like函数的基本功能是根据给定张量,生成与其形状相同的全1张量或全0张量,示例如下:
input = torch.rand(2, 3)
print(input)
# 生成与input形状相同、元素全为1的张量
a = torch.ones_like(input)
print(a)
# 生成与input形状相同、元素全为0的张量
b = torch.zeros_like(input)
print(b)
效果如下:
tensor([[0.0881, 0.9002, 0.7084],
[0.3313, 0.2736, 0.0894]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
我们进一步看一下这两个函数在源码中是怎样定义的。
torch.ones_like函数:
@overload
def ones_like(self: Tensor, *, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
torch.zeros_like函数:
@overload
def zeros_like(self: Tensor, *, dtype: _dtype=None, layout: layout=strided, device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...
可以看到,在这两个函数中,我们还可以指定数据类型、设备、是否计算梯度等信息,可以结合具体场景灵活使用。