torch.ones_like
torch.ones_like(input, dtype=None, layout=None, device=None,
requires_grad=False) → Tensor
方法解释:
返回一个填充了标量值1的张量,其大小与之相同 input。
torch.ones_like(input)相当于 torch.ones(input.size(), dtype=input.dtype,
layout=input.layout, device=input.device)
import torch
input = torch.empty(2, 3)
torch.ones_like(input)
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.distributions.Normal()
class Normal(ExponentialFamily):
r"""
Creates a normal (also called Gaussian) distribution parameterized by
:attr:`loc` and :attr:`scale`.
Example::
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample() # normally distributed with loc=0 and scale=1
tensor([ 0.1046])
Args:
loc (float or Tensor): mean of the distribution (often referred to as mu)
scale (float or Tensor): standard deviation of the distribution
(often referred to as sigma)
"""
from torch.distributions import Normal
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
m.sample()
tensor([-0.0836])
pytorch的torch.distributions中可以定义正态分布
如下:
import torch
from torch.distributions import Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)
torch.distributions.sample()
sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:
c=normal.sample()
print("c:",c)
torch.distributions.rsample()
rsample()不是在定义的正太分布上采样,而是先对标准正太分布N(0,1)进行采样,然后输出:mean+std×采样值
a=normal.rsample()
a: tensor([ 0.0530, 2.8396])
torch.distributions.log_prob(value)
log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是
,对其取对数可得
这里我们通过对数概率还原其对应的真实概率:
print("c log_prob:",normal.log_prob(c).exp())
c log_prob: tensor([ 0.1634, 0.2005])
torch.distributions相关测试补充说明
import torch
import math
from torch.distributions import Normal
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
action=m.sample()
print('action=',action)
# 这里的action相当于在概率密度函数中的x,不是指f(x);注意区分分布函数与概率密度函数单位区别
print('log_prob=',m.log_prob(action))
print("exp_log_prob=",torch.exp(m.log_prob(action)))
fx=(1/math.sqrt(2*math.pi))*torch.exp((-action**2)/2)
print(f)
print(torch.log(f))
# 这里的三行是对m.log_prob(action)的一个说明
# 表示将采样到的action,即x,带入正态分布的概率密度函数中计算的到f(x)值,
# 然后将值再计算log(f(x))