目录
1,model.train(),model.eval()
model.train():启用 BatchNormalization 和 Dropout
model.eval():进入评估模式,BatchNormalization 改用训练阶段累积的 running mean/var(不再用当前 batch 的统计量),Dropout 不再随机丢弃神经元
训练完 train_datasets 之后,model 要对测试样本进行评估。在把测试数据输入模型 model(test_datasets) 之前,需要先调用 model.eval()。否则,即使不做反向传播,只要有数据输入,BatchNormalization 层的 running mean/var 统计量也会被更新(改变的不是权值,而是 BN 的统计量),导致测试结果不稳定。这是 model 中含有 batch normalization 层所带来的影响。
2,tensor.unfold()
unfold(dim, size, step) → Tensor
返回一个张量,其中包含 dim 维上所有大小为 size、步长为 step 的分片。设 sizedim 为 dim 维的原始大小,则返回张量中 dim 维的大小为 (sizedim - size)/step + 1,并且会在返回张量的末尾附加一个大小为 size 的新维度。
>>> x = torch.arange(1, 8)
>>> x
1
2
3
4
5
6
7
[torch.FloatTensor of size 7]
>>> x.unfold(0, 2, 1)
1 2
2 3
3 4
4 5
5 6
6 7
[torch.FloatTensor of size 6x2]
>>> x.unfold(0, 2, 2)
1 2
3 4
5 6
[torch.FloatTensor of size 3x2]
3,torch.nn.Unfold
torch.nn.Unfold(kernel_size: Union[T, Tuple[T, ...]], dilation: Union[T, Tuple[T, ...]] = 1, padding: Union[T, Tuple[T, ...]] = 0, stride: Union[T, Tuple[T, ...]] = 1)
从输入张量的一个Batch数据中提取滑动的局部块,
参数:卷积核的尺寸,空洞大小,填充大小和步长。
输入形状为 (N, C, *),其中 N 是 batch 维度,C 是通道维度,* 是任意的空间维度。该运算在输入的空间维度内对每一个 kernel_size 大小的滑动窗口区块做展平,形成输出第二个维度中的列。
Unfold 的输出形状为 (N, C × ∏(kernel_size), L),其中 C × ∏(kernel_size) 为通道数乘以 kernel_size 各维尺寸的乘积,L 是按 kernel_size 在空间维度上滑动裁剪后得到的区块数量:
L = ∏_d floor((spatial_size[d] + 2 × padding[d] − dilation[d] × (kernel_size[d] − 1) − 1) / stride[d] + 1)
其中 spatial_size 指输入的空间维度(即上面的 *),d 遍历所有空间维度。
每个区块(输出的每一列)的大小为 C × ∏(kernel_size)。
(原文配图中的公式有误,以上面计算区块数量的公式为准。)
nn.Unfold 对输入每个通道上的每一个 kernel_size 大小的滑动窗口区块做了展平操作。
torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688, 0.5384],
[-0.4693, -0.0775, -0.7504, 0.2283],
[-0.1414, 1.0006, -0.0942, 2.2981],
[-0.9429, 1.1908, 0.9374, -1.3168]],
[[-1.8184, -0.3926, 0.1875, 1.3847],
[-0.4124, 0.9766, -1.3303, -0.0970],
[ 1.7679, 0.6961, -1.6445, 0.7482],
[ 0.1729, -0.3196, -0.1528, 0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],
[-0.1026, 0.5384, 1.0006, 2.2981],
[-0.4693, -0.7504, -0.9429, 0.9374],
[-0.0775, 0.2283, 1.1908, -1.3168],
[-1.8184, 0.1875, 1.7679, -1.6445],
[-0.3926, 1.3847, 0.6961, 0.7482],
[-0.4124, -1.3303, 0.1729, -0.1528],
[ 0.9766, -0.0970, -0.3196, 0.2180]]])
def unfold_x():
    """Demo of torch.nn.Unfold on a random 4-D tensor.

    Builds an input of shape (2, 3, 5, 4) — here labelled N*M, C, T, V —
    and extracts sliding local blocks with kernel_size=(3, 1),
    dilation=(1, 1), stride=(1, 1), padding=(1, 0).

    With these settings the number of blocks is
    L = ((5 + 2*1 - 3)/1 + 1) * ((4 + 0 - 1)/1 + 1) = 5 * 4 = 20
    and each column holds C * 3 * 1 = 9 values, so the output
    shape is (2, 9, 20).

    Returns:
        torch.Tensor: the unfolded tensor of shape (2, 9, 20).
    """
    # N*M, C, T, V input; window_size=3, dilation=1, stride=1
    x = torch.randn(2, 3, 5, 4)
    print(f'----------------x \n{x}')
    unfold = torch.nn.Unfold(kernel_size=(3, 1), dilation=(1, 1),
                             stride=(1, 1), padding=(1, 0))
    un_x = unfold(x)
    print('un_x shape {}'.format(un_x.size()))
    print(f'............un_x\n{un_x}')
    # Return the result instead of discarding it (was a bare `return`),
    # so the demo output can be inspected or tested programmatically.
    return un_x
x=torch.randn(2,3,5,4)
print(f'----------------x \n{x}')
tensor([[[[-0.1007, -0.1986, 0.2615, 0.2375],
[ 0.3395, -0.2650, -2.3015, -0.5818],
[-2.4892, 1.3659, 0.9418, 0.5290],
[-0.1242, -1.1327, -0.7105, -0.3952],
[-0.3351, -0.3885, -1.0516, 0.0144]],
[[ 1.1538, -0.2460, -0.6409, 2.3420],
[-1.6041, -0.0226, -1.1131, -1.2851],
[ 1.5435, 2.1038, 0.1150, 0.7285],
[-0.8543, 0.5684, -0.0907, -1.5588],
[-0.1338, 1.2914, 0.5947, -0.1871]],
[[-0.5479, 0.0572, -1.3323, 0.2371],
[-0.3639, 0.8004, -2.4990, -2.6908],
[-0.3635, 0.5411, 0.6723, -1.1272],
[ 1.7912, 1.1216, 0.2887, 0.8244],
[ 0.2222, 1.1524, 1.2438, 0.4919]]],
[[[ 0.5019, 1.0633, 0.3409, -0.4121],
[ 1.1162, 0.0055, 1.2277, -1.4919],
[ 0.0533, -1.6769, -0.9581, 1.7418],
[ 1.9506, -0.7145, -0.3485, 0.0497],
[ 1.7571, -1.0860, 0.1596, 0.4369]],
[[-0.9666, -0.7096, 0.3977, 0.9115],
[-0.0983, 0.3316, 0.1486, 0.2869],
[ 0.7518, -0.7357, 0.2328, -1.5851],
[ 0.2918, 0.4178, 0.0045, -1.1917],
[-1.2200, -1.2876, 1.9524, -2.4134]],
[[ 2.6374, 1.4099, 0.8991, -0.7087],
[ 0.2047, -0.6513, -0.8530, 0.7599],
[ 0.2445, 0.5106, -2.3711, 0.5012],
[ 1.2275, -0.0866, 0.6022, -0.0259],
[ 0.0051, -0.0894, -0.1819, -0.7296]]]])
tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, -0.1007, -0.1986, 0.2615,
0.2375, 0.3395, -0.2650, -2.3015, -0.5818, -2.4892, 1.3659,
0.9418, 0.5290, -0.1242, -1.1327, -0.7105, -0.3952],
[-0.1007, -0.1986, 0.2615, 0.2375, 0.3395, -0.2650, -2.3015,
-0.5818, -2.4892, 1.3659, 0.9418, 0.5290, -0.1242, -1.1327,
-0.7105, -0.3952, -0.3351, -0.3885, -1.0516, 0.0144],
[ 0.3395, -0.2650, -2.3015, -0.5818, -2.4892, 1.3659, 0.9418,
0.5290, -0.1242, -1.1327, -0.7105, -0.3952, -0.3351, -0.3885,
-1.0516, 0.0144, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 1.1538, -0.2460, -0.6409,
2.3420, -1.6041, -0.0226, -1.1131, -1.2851, 1.5435, 2.1038,
0.1150, 0.7285, -0.8543, 0.5684, -0.0907, -1.5588],
[ 1.1538, -0.2460, -0.6409, 2.3420, -1.6041, -0.0226, -1.1131,
-1.2851, 1.5435, 2.1038, 0.1150, 0.7285, -0.8543, 0.5684,
-0.0907, -1.5588, -0.1338, 1.2914, 0.5947, -0.1871],
[-1.6041, -0.0226, -1.1131, -1.2851, 1.5435, 2.1038, 0.1150,
0.7285, -0.8543, 0.5684, -0.0907, -1.5588, -0.1338, 1.2914,
0.5947, -0.1871, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.5479, 0.0572, -1.3323,
0.2371, -0.3639, 0.8004, -2.4990, -2.6908, -0.3635, 0.5411,
0.6723, -1.1272, 1.7912, 1.1216, 0.2887, 0.8244],
[-0.5479, 0.0572, -1.3323, 0.2371, -0.3639, 0.8004, -2.4990,
-2.6908, -0.3635, 0.5411, 0.6723, -1.1272, 1.7912, 1.1216,
0.2887, 0.8244, 0.2222, 1.1524, 1.2438, 0.4919],
[-0.3639, 0.8004, -2.4990, -2.6908, -0.3635, 0.5411, 0.6723,
-1.1272, 1.7912, 1.1216, 0.2887, 0.8244, 0.2222, 1.1524,
1.2438, 0.4919, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5019, 1.0633, 0.3409,
-0.4121, 1.1162, 0.0055, 1.2277, -1.4919, 0.0533, -1.6769,
-0.9581, 1.7418, 1.9506, -0.7145, -0.3485, 0.0497],
[ 0.5019, 1.0633, 0.3409, -0.4121, 1.1162, 0.0055, 1.2277,
-1.4919, 0.0533, -1.6769, -0.9581, 1.7418, 1.9506, -0.7145,
-0.3485, 0.0497, 1.7571, -1.0860, 0.1596, 0.4369],
[ 1.1162, 0.0055, 1.2277, -1.4919, 0.0533, -1.6769, -0.9581,
1.7418, 1.9506, -0.7145, -0.3485, 0.0497, 1.7571, -1.0860,
0.1596, 0.4369, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, -0.9666, -0.7096, 0.3977,
0.9115, -0.0983, 0.3316, 0.1486, 0.2869, 0.7518, -0.7357,
0.2328, -1.5851, 0.2918, 0.4178, 0.0045, -1.1917],
[-0.9666, -0.7096, 0.3977, 0.9115, -0.0983, 0.3316, 0.1486,
0.2869, 0.7518, -0.7357, 0.2328, -1.5851, 0.2918, 0.4178,
0.0045, -1.1917, -1.2200, -1.2876, 1.9524, -2.4134],
[-0.0983, 0.3316, 0.1486, 0.2869, 0.7518, -0.7357, 0.2328,
-1.5851, 0.2918, 0.4178, 0.0045, -1.1917, -1.2200, -1.2876,
1.9524, -2.4134, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 2.6374, 1.4099, 0.8991,
-0.7087, 0.2047, -0.6513, -0.8530, 0.7599, 0.2445, 0.5106,
-2.3711, 0.5012, 1.2275, -0.0866, 0.6022, -0.0259],
[ 2.6374, 1.4099, 0.8991, -0.7087, 0.2047, -0.6513, -0.8530,
0.7599, 0.2445, 0.5106, -2.3711, 0.5012, 1.2275, -0.0866,
0.6022, -0.0259, 0.0051, -0.0894, -0.1819, -0.7296],
[ 0.2047, -0.6513, -0.8530, 0.7599, 0.2445, 0.5106, -2.3711,
0.5012, 1.2275, -0.0866, 0.6022, -0.0259, 0.0051, -0.0894,
-0.1819, -0.7296, 0.0000, 0.0000, 0.0000, 0.0000]]])
4, pytorch张量约减操作
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
# Demo of tensor contraction with torch.einsum.
# `A` and `b` are numpy arrays defined earlier in the post (not shown here).
a_n=torch.from_numpy(A)
b_n=torch.from_numpy(b)
print(a_n)
print(b_n)
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
# The last index u of each tensor is contracted: multiply element-wise
# along u and sum, producing each element of the new axis v.
# agg[n,c,t,v] is the dot product of row a_n[v,:] with b_n[n,c,t,:].
print(agg)
print(agg.size())
tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]], dtype=torch.float64)
tensor([[[[ 0., 1., 2., 3., 4.]],
[[ 5., 6., 7., 8., 9.]],
[[10., 11., 12., 13., 14.]]],
[[[15., 16., 17., 18., 19.]],
[[20., 21., 22., 23., 24.]],
[[25., 26., 27., 28., 29.]]]], dtype=torch.float64)
tensor([[[[ 30., 80., 130., 180., 230.]],
[[ 80., 255., 430., 605., 780.]],
[[ 130., 430., 730., 1030., 1330.]]],
[[[ 180., 605., 1030., 1455., 1880.]],
[[ 230., 780., 1330., 1880., 2430.]],
[[ 280., 955., 1630., 2305., 2980.]]]], dtype=torch.float64)
torch.Size([2, 3, 1, 5])
5,torch.nn.init
torch.nn.init.uniform_(tensor, a=0, b=1)
用服从均匀分布 U(a, b) 的随机值填充 tensor