1. torch.linspace(start,end,steps=100,dtype) <线性间距向量>
作用是返回一个一维的tensor(张量),包含在区间start和end上均匀间隔的step个点,输出张量的长度由steps决定,其中dtype是返回的数据类型。
import torch
print(torch.linspace(-1,1,5))
输出结果为:tensor([-1.0000, -0.5000, 0.0000, 0.5000, 1.0000])
2. unsqueeze()函数
作用是在指定位置增加维度。
import torch
a=torch.arange(0,6) #a是一维向量
b=a.reshape(2,3) #b是二维向量
c=b.unsqueeze(1) #c是三维向量,在b的第二维上增加一个维度
print(a)
print(b)
print(c)
print(c.size())
输出结果为:
tensor([0, 1, 2, 3, 4, 5])
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[[0, 1, 2]],
[[3, 4, 5]]])
torch.Size([2, 1, 3])
a的维度为1x6
b的维度为2x3
c的维度为2x1x3
若想在倒数第二个维度增加一个维度,则c=b.unsqueeze(-1)
3. squeeze()函数
可去掉维度为1的维度。
import torch
a=torch.arange(0,6) #a是一维向量
b=a.reshape(2,3)
c=b.unsqueeze(1)
print(c)
print(c.size())
d=c.squeeze(1)
print(d)
print(d.size())
输出结果为:
tensor([[[0, 1, 2]],
[[3, 4, 5]]])
torch.Size([2, 1, 3])
tensor([[0, 1, 2],
[3, 4, 5]])
torch.Size([2, 3])
4. torch.rand(*sizes, out=None) → Tensor
均匀分布
返回一个张量,包含了从区间[0, 1)的均匀分布中抽取的一组随机数。张量的形状由参数sizes定义。
- 参数
- sizes (int…) - 整数序列,定义了输出张量的形状
- out (Tensor, optinal) - 结果张量
import torch
e = torch.rand(2, 3) # 同rand((2, 3))
e1 = torch.rand(3)
print(e)
print(e.size())
print(e1)
print(e1.size())
输出结果为:
tensor([[0.5202, 0.3955, 0.7335],
[0.2064, 0.7310, 0.0490]])
torch.Size([2, 3])
tensor([0.9491, 0.7499, 0.0585])
torch.Size([3])
5. torch.randn(*sizes, out=None) → Tensor
标准正态分布
返回一个张量,包含了从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的一组随机数。张量的形状由参数sizes定义。
-
参数:
-
sizes (int…) - 整数序列,定义了输出张量的形状
-
out (Tensor, optinal) - 结果张量
import torch
f = torch.randn(2, 3)
f1 = torch.randn(3)
print(f)
print(f.size())
print(f1)
print(f1.size())
输出结果为:
tensor([[-0.8658, 1.9324, -1.8821],
[ 0.2715, -0.2389, 2.2982]])
torch.Size([2, 3])
tensor([-0.4860, 0.5520, -0.4666])
torch.Size([3])
6. torch.normal(means, std, out=None) → → Tensor
离散正态分布
返回一个张量,包含了从指定均值means和标准差std的离散正态分布中抽取的一组随机数。
标准差std是一个张量,包含每个输出元素相关的正态分布标准差。
- 参数
- means (float, optional) - 均值
- std (Tensor) - 标准差
- out (Tensor) - 输出张量
import torch
torch.normal(mean=0.5, std=torch.arange(1., 6.))
输出结果为:
tensor([ 0.2444, 0.5445, -0.5246, -1.5175, -0.7741])