PyTorch Usage -- Learning Notes

目录

1. model.train(), model.eval()

2. tensor.unfold()

3. torch.nn.Unfold

4. PyTorch tensor reduction (einsum)

5. torch.nn.init


1. model.train(), model.eval()

model.train(): puts BatchNorm and Dropout layers into training mode.

model.eval(): puts BatchNorm and Dropout layers into evaluation mode (Dropout is disabled, BatchNorm uses its running statistics).

After training on train_datasets, call model.eval() before feeding test data with model(test_datasets). Otherwise, even though no training is happening, simply passing data through the model still updates the BatchNorm running statistics, because the model contains batch normalization layers.
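
A minimal sketch of the usual pattern; the toy model, data and optimizer below are made up purely for illustration:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(),
                      nn.Dropout(0.5), nn.Linear(16, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
train_x, train_y = torch.randn(32, 8), torch.randint(0, 2, (32,))
test_datasets = torch.randn(10, 8)

model.train()                      # BatchNorm updates running stats, Dropout is active
loss = loss_fn(model(train_x), train_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

model.eval()                       # BatchNorm uses running stats, Dropout is disabled
with torch.no_grad():              # also avoid building the autograd graph at test time
    preds = model(test_datasets)
print(preds.shape)                 # torch.Size([10, 2])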

2. tensor.unfold()

unfold(dim, size, step) → Tensor

Returns a tensor containing all slices of size size taken from dimension dim; the step between two consecutive slices is step. If sizedim is the original size of dimension dim, then dimension dim of the returned tensor has size (sizedim - size) / step + 1, and an additional dimension of size size is appended to the returned tensor.

>>> x = torch.arange(1, 8)
>>> x

 1
 2
 3
 4
 5
 6
 7
[torch.FloatTensor of size 7]

>>> x.unfold(0, 2, 1)

 1  2
 2  3
 3  4
 4  5
 5  6
 6  7
[torch.FloatTensor of size 6x2]

>>> x.unfold(0, 2, 2)

 1  2
 3  4
 5  6
[torch.FloatTensor of size 3x2]
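
As a quick application sketch (my own illustration, not from the original notes): the overlapping windows returned by unfold can be reduced along the last dimension, for example to compute a moving average.

import torch

x = torch.arange(1., 8.)           # tensor([1., 2., 3., 4., 5., 6., 7.])
windows = x.unfold(0, 2, 1)        # shape (6, 2): overlapping windows of size 2
moving_avg = windows.mean(dim=1)   # tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5])
print(moving_avg)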

3. torch.nn.Unfold

 

torch.nn.Unfold(kernel_size: Union[T, Tuple[T, ...]], dilation: Union[T, Tuple[T, ...]] = 1, padding: Union[T, Tuple[T, ...]] = 0, stride: Union[T, Tuple[T, ...]] = 1)

Extracts sliding local blocks from a batched input tensor.

Parameters: kernel size, dilation, padding and stride.

The input has shape (N, C, *), where N is the batch dimension, C is the channel dimension and * stands for arbitrary spatial dimensions. Within the spatial dimensions, every sliding block of size kernel_size becomes one column of the output's third dimension.

The output of Unfold has shape (N, C \times \prod(\text{kernel\_size}), L), where \prod(\text{kernel\_size}) is the product of the kernel's height and width, and L is the number of blocks obtained by sliding a kernel_size window over the spatial dimensions of each channel:

L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor

Here d ranges over all spatial dimensions, and spatial_size refers to the spatial dimensions of the input (the * part of the shape).

Each block (i.e. each output column) therefore contains C \times \prod(\text{kernel\_size}) values.


nn.Unfold flattens every kernel_size sliding-window block of each input channel into one column of the output.
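
The output below can be reproduced with a snippet like the following; kernel_size=2 and stride=2 are my assumption, inferred from the printed shapes (1, 2, 4, 4) -> (1, 8, 4):

import torch

x = torch.randn(1, 2, 4, 4)                  # (N, C, H, W)
unfold = torch.nn.Unfold(kernel_size=2, stride=2)
out = unfold(x)
# output channels = C * prod(kernel_size) = 2 * 2 * 2 = 8
# L = ((4 - 2) // 2 + 1) ** 2 = 4 blocks
print(x.size())
print(x)
print(out.size())
print(out)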

torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688,  0.5384],
          [-0.4693, -0.0775, -0.7504,  0.2283],
          [-0.1414,  1.0006, -0.0942,  2.2981],
          [-0.9429,  1.1908,  0.9374, -1.3168]],

         [[-1.8184, -0.3926,  0.1875,  1.3847],
          [-0.4124,  0.9766, -1.3303, -0.0970],
          [ 1.7679,  0.6961, -1.6445,  0.7482],
          [ 0.1729, -0.3196, -0.1528,  0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],
         [-0.1026,  0.5384,  1.0006,  2.2981],
         [-0.4693, -0.7504, -0.9429,  0.9374],
         [-0.0775,  0.2283,  1.1908, -1.3168],
         [-1.8184,  0.1875,  1.7679, -1.6445],
         [-0.3926,  1.3847,  0.6961,  0.7482],
         [-0.4124, -1.3303,  0.1729, -0.1528],
         [ 0.9766, -0.0970, -0.3196,  0.2180]]])

import torch

def unfold_x():
    # Input of shape (N*M, C, T, V); slide a window of size 3 along T
    # (window_size=3, dilation=1, stride=1), with padding=1 so T is preserved.
    x = torch.randn(2, 3, 5, 4)
    print(f'----------------x \n{x}')
    unfold = torch.nn.Unfold(kernel_size=(3, 1), dilation=(1, 1), stride=(1, 1), padding=(1, 0))
    un_x = unfold(x)
    print('un_x shape {}'.format(un_x.size()))
    print(f'............un_x\n{un_x}')
    return un_x

Calling unfold_x() prints x of shape (2, 3, 5, 4) followed by un_x of shape (2, 9, 20), for example:

tensor([[[[-0.1007, -0.1986,  0.2615,  0.2375],
          [ 0.3395, -0.2650, -2.3015, -0.5818],
          [-2.4892,  1.3659,  0.9418,  0.5290],
          [-0.1242, -1.1327, -0.7105, -0.3952],
          [-0.3351, -0.3885, -1.0516,  0.0144]],

         [[ 1.1538, -0.2460, -0.6409,  2.3420],
          [-1.6041, -0.0226, -1.1131, -1.2851],
          [ 1.5435,  2.1038,  0.1150,  0.7285],
          [-0.8543,  0.5684, -0.0907, -1.5588],
          [-0.1338,  1.2914,  0.5947, -0.1871]],

         [[-0.5479,  0.0572, -1.3323,  0.2371],
          [-0.3639,  0.8004, -2.4990, -2.6908],
          [-0.3635,  0.5411,  0.6723, -1.1272],
          [ 1.7912,  1.1216,  0.2887,  0.8244],
          [ 0.2222,  1.1524,  1.2438,  0.4919]]],


        [[[ 0.5019,  1.0633,  0.3409, -0.4121],
          [ 1.1162,  0.0055,  1.2277, -1.4919],
          [ 0.0533, -1.6769, -0.9581,  1.7418],
          [ 1.9506, -0.7145, -0.3485,  0.0497],
          [ 1.7571, -1.0860,  0.1596,  0.4369]],

         [[-0.9666, -0.7096,  0.3977,  0.9115],
          [-0.0983,  0.3316,  0.1486,  0.2869],
          [ 0.7518, -0.7357,  0.2328, -1.5851],
          [ 0.2918,  0.4178,  0.0045, -1.1917],
          [-1.2200, -1.2876,  1.9524, -2.4134]],

         [[ 2.6374,  1.4099,  0.8991, -0.7087],
          [ 0.2047, -0.6513, -0.8530,  0.7599],
          [ 0.2445,  0.5106, -2.3711,  0.5012],
          [ 1.2275, -0.0866,  0.6022, -0.0259],
          [ 0.0051, -0.0894, -0.1819, -0.7296]]]])
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000, -0.1007, -0.1986,  0.2615,
           0.2375,  0.3395, -0.2650, -2.3015, -0.5818, -2.4892,  1.3659,
           0.9418,  0.5290, -0.1242, -1.1327, -0.7105, -0.3952],
         [-0.1007, -0.1986,  0.2615,  0.2375,  0.3395, -0.2650, -2.3015,
          -0.5818, -2.4892,  1.3659,  0.9418,  0.5290, -0.1242, -1.1327,
          -0.7105, -0.3952, -0.3351, -0.3885, -1.0516,  0.0144],
         [ 0.3395, -0.2650, -2.3015, -0.5818, -2.4892,  1.3659,  0.9418,
           0.5290, -0.1242, -1.1327, -0.7105, -0.3952, -0.3351, -0.3885,
          -1.0516,  0.0144,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  1.1538, -0.2460, -0.6409,
           2.3420, -1.6041, -0.0226, -1.1131, -1.2851,  1.5435,  2.1038,
           0.1150,  0.7285, -0.8543,  0.5684, -0.0907, -1.5588],
         [ 1.1538, -0.2460, -0.6409,  2.3420, -1.6041, -0.0226, -1.1131,
          -1.2851,  1.5435,  2.1038,  0.1150,  0.7285, -0.8543,  0.5684,
          -0.0907, -1.5588, -0.1338,  1.2914,  0.5947, -0.1871],
         [-1.6041, -0.0226, -1.1131, -1.2851,  1.5435,  2.1038,  0.1150,
           0.7285, -0.8543,  0.5684, -0.0907, -1.5588, -0.1338,  1.2914,
           0.5947, -0.1871,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5479,  0.0572, -1.3323,
           0.2371, -0.3639,  0.8004, -2.4990, -2.6908, -0.3635,  0.5411,
           0.6723, -1.1272,  1.7912,  1.1216,  0.2887,  0.8244],
         [-0.5479,  0.0572, -1.3323,  0.2371, -0.3639,  0.8004, -2.4990,
          -2.6908, -0.3635,  0.5411,  0.6723, -1.1272,  1.7912,  1.1216,
           0.2887,  0.8244,  0.2222,  1.1524,  1.2438,  0.4919],
         [-0.3639,  0.8004, -2.4990, -2.6908, -0.3635,  0.5411,  0.6723,
          -1.1272,  1.7912,  1.1216,  0.2887,  0.8244,  0.2222,  1.1524,
           1.2438,  0.4919,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.5019,  1.0633,  0.3409,
          -0.4121,  1.1162,  0.0055,  1.2277, -1.4919,  0.0533, -1.6769,
          -0.9581,  1.7418,  1.9506, -0.7145, -0.3485,  0.0497],
         [ 0.5019,  1.0633,  0.3409, -0.4121,  1.1162,  0.0055,  1.2277,
          -1.4919,  0.0533, -1.6769, -0.9581,  1.7418,  1.9506, -0.7145,
          -0.3485,  0.0497,  1.7571, -1.0860,  0.1596,  0.4369],
         [ 1.1162,  0.0055,  1.2277, -1.4919,  0.0533, -1.6769, -0.9581,
           1.7418,  1.9506, -0.7145, -0.3485,  0.0497,  1.7571, -1.0860,
           0.1596,  0.4369,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, -0.9666, -0.7096,  0.3977,
           0.9115, -0.0983,  0.3316,  0.1486,  0.2869,  0.7518, -0.7357,
           0.2328, -1.5851,  0.2918,  0.4178,  0.0045, -1.1917],
         [-0.9666, -0.7096,  0.3977,  0.9115, -0.0983,  0.3316,  0.1486,
           0.2869,  0.7518, -0.7357,  0.2328, -1.5851,  0.2918,  0.4178,
           0.0045, -1.1917, -1.2200, -1.2876,  1.9524, -2.4134],
         [-0.0983,  0.3316,  0.1486,  0.2869,  0.7518, -0.7357,  0.2328,
          -1.5851,  0.2918,  0.4178,  0.0045, -1.1917, -1.2200, -1.2876,
           1.9524, -2.4134,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  2.6374,  1.4099,  0.8991,
          -0.7087,  0.2047, -0.6513, -0.8530,  0.7599,  0.2445,  0.5106,
          -2.3711,  0.5012,  1.2275, -0.0866,  0.6022, -0.0259],
         [ 2.6374,  1.4099,  0.8991, -0.7087,  0.2047, -0.6513, -0.8530,
           0.7599,  0.2445,  0.5106, -2.3711,  0.5012,  1.2275, -0.0866,
           0.6022, -0.0259,  0.0051, -0.0894, -0.1819, -0.7296],
         [ 0.2047, -0.6513, -0.8530,  0.7599,  0.2445,  0.5106, -2.3711,
           0.5012,  1.2275, -0.0866,  0.6022, -0.0259,  0.0051, -0.0894,
          -0.1819, -0.7296,  0.0000,  0.0000,  0.0000,  0.0000]]])

 

4. PyTorch tensor reduction (einsum)

agg = torch.einsum('vu,nctu->nctv', a_n, b_n)

 

import numpy as np
import torch

A = np.arange(25, dtype=np.float64).reshape(5, 5)        # shape (v, u) = (5, 5)
b = np.arange(30, dtype=np.float64).reshape(2, 3, 1, 5)  # shape (n, c, t, u) = (2, 3, 1, 5)
a_n = torch.from_numpy(A)
b_n = torch.from_numpy(b)
print(a_n)
print(b_n)
agg = torch.einsum('vu,nctu->nctv', a_n, b_n)
# The last dimension u of both tensors is contracted (multiply and sum),
# producing each element along v:
# agg[n, c, t, v] = sum_u a_n[v, u] * b_n[n, c, t, u]
print(agg)
print(agg.size())

 

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24.]], dtype=torch.float64)
tensor([[[[ 0.,  1.,  2.,  3.,  4.]],

         [[ 5.,  6.,  7.,  8.,  9.]],

         [[10., 11., 12., 13., 14.]]],


        [[[15., 16., 17., 18., 19.]],

         [[20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29.]]]], dtype=torch.float64)
tensor([[[[  30.,   80.,  130.,  180.,  230.]],

         [[  80.,  255.,  430.,  605.,  780.]],

         [[ 130.,  430.,  730., 1030., 1330.]]],


        [[[ 180.,  605., 1030., 1455., 1880.]],

         [[ 230.,  780., 1330., 1880., 2430.]],

         [[ 280.,  955., 1630., 2305., 2980.]]]], dtype=torch.float64)
torch.Size([2, 3, 1, 5])
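
The same contraction can also be written as a broadcasted matrix multiplication; a minimal sketch to check the equivalence, reusing a_n, b_n and agg from above:

# 'vu,nctu->nctv' contracts u, which is exactly b_n @ a_n.T with matmul
# broadcasting over the leading (n, c, t) dimensions.
agg_mm = torch.matmul(b_n, a_n.t())
print(torch.allclose(agg, agg_mm))   # True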

5. torch.nn.init

torch.nn.init.uniform_(tensor, a=0, b=1): fills the tensor in place with values sampled from the uniform distribution U(a, b).
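
A small sketch showing typical use on a layer's parameters (the layer sizes and bounds here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

layer = nn.Linear(16, 8)
nn.init.uniform_(layer.weight, a=-0.1, b=0.1)   # weights ~ U(-0.1, 0.1), in place
nn.init.zeros_(layer.bias)                      # biases set to 0, in place
print(layer.weight.min().item(), layer.weight.max().item())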

 
