Channel Attention: Small PyTorch Examples

I. A Brief Introduction to the Channel Attention Mechanism

The figure below gives an intuitive illustration of the channel attention mechanism.
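
In short (and this is exactly what the code in the next section implements): the module squeezes each channel of the input feature map to a single value with both average pooling and max pooling, pushes the two pooled vectors through a shared two-layer bottleneck MLP, sums the results, and applies a sigmoid:

    M_c(x) = σ( fc2(ReLU(fc1(AvgPool(x)))) + fc2(ReLU(fc1(MaxPool(x)))) )

The result is one weight in (0, 1) per channel, which is later multiplied back onto the input feature map to re-weight its channels.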

II. The Channel Attention Module

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Squeeze each channel to a single value with both average pooling and max pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared two-layer bottleneck MLP, implemented as 1x1 convolutions with a reduction ratio
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Run both pooled descriptors through the shared MLP, sum them, then squash to (0, 1)
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
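
A quick sanity check of the module (a minimal sketch; the 8×64×32×32 random feature map below is just an arbitrary example, not data from this article) is to confirm that it produces one weight per channel:

import torch
import torch.nn as nn

# assumes the ChannelAttention class defined above
ca = ChannelAttention(in_planes=64)      # 64 input channels, reduction ratio 16
feat = torch.randn(8, 64, 32, 32)        # [batch, channels, height, width]
weights = ca(feat)
print(weights.size())                    # torch.Size([8, 64, 1, 1]) -- one weight per channel
print(feat.mul(weights).size())          # torch.Size([8, 64, 32, 32]) -- re-weighted features, same shape as input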

1. A small example: using the channel attention module on its own

import torch
import torch.nn as nn
import torch.utils.data as Data


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


def get_total_train_data(H, W, C, class_count):
    """得到全部的训练数据,这里需要替换成自己的数据"""
    import numpy as np
    x_train = torch.Tensor(
        np.random.random((1000, H, W, C)))  # 维度是 [ 数据量, 高H, 宽W, 长C]
    y_train = torch.Tensor(
        np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务
    return x_train, y_train


if __name__ == '__main__':
    # ================ Training parameters =================
    epochs = 100
    batch_size = 30
    output_class = 14
    H = 40
    W = 50
    C = 30
    # ================ Prepare the data =================
    x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)
    train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_train, y_train),  # tensors wrapped in Data.TensorDataset() may have any number of dimensions
        batch_size=batch_size,  # size of each mini-batch
        shuffle=True,  # whether to shuffle the data (shuffling is usually better)
        num_workers=6,  # number of worker processes used to load the data
        drop_last=True,
    )
    # ================ Initialize the model =================
    model = ChannelAttention(in_planes=H)  # dimension 1 (of size H=40) is treated as the channel dimension
    # ================ Run the attention module =================
    for epoch in range(epochs):
        for seq, labels in train_loader:
            attention_out = model(seq)                   # [30, 40, 1, 1]
            seq_attention_out = attention_out.squeeze()  # [30, 40]
            for sample in range(seq_attention_out.size()[0]):
                print(seq_attention_out[sample])

Output:

tensor([0.5588, 0.5731, 0.5546, 0.4316, 0.5486, 0.5705, 0.4307, 0.5087, 0.4734,
        0.5527, 0.5328, 0.5670, 0.5212, 0.5357, 0.5589, 0.5588, 0.5630, 0.4738,
        0.5375, 0.4512, 0.5027, 0.4762, 0.4330, 0.4694, 0.5424, 0.4416, 0.5593,
        0.5629, 0.4485, 0.4500, 0.5506, 0.4345, 0.5415, 0.5627, 0.5681, 0.5420,
        0.5425, 0.5357, 0.4589, 0.5654], grad_fn=<SelectBackward0>)
tensor([0.5606, 0.5753, 0.5563, 0.4295, 0.5501, 0.5726, 0.4286, 0.5089, 0.4726,
        0.5544, 0.5339, 0.5691, 0.5218, 0.5368, 0.5607, 0.5606, 0.5649, 0.4730,
        0.5387, 0.4497, 0.5028, 0.4755, 0.4310, 0.4684, 0.5437, 0.4398, 0.5611,
        0.5648, 0.4470, 0.4485, 0.5521, 0.4325, 0.5427, 0.5646, 0.5701, 0.5433,
        0.5438, 0.5368, 0.4577, 0.5674], grad_fn=<SelectBackward0>)
tensor([0.5608, 0.5757, 0.5565, 0.4292, 0.5503, 0.5729, 0.4283, 0.5090, 0.4725,
        0.5546, 0.5340, 0.5694, 0.5219, 0.5369, 0.5610, 0.5608, 0.5652, 0.4729,
        0.5389, 0.4495, 0.5028, 0.4754, 0.4307, 0.4683, 0.5438, 0.4395, 0.5614,
        0.5650, 0.4467, 0.4483, 0.5523, 0.4322, 0.5429, 0.5648, 0.5704, 0.5435,
        0.5440, 0.5370, 0.4575, 0.5677], grad_fn=<SelectBackward0>)
tensor([0.5603, 0.5749, 0.5560, 0.4299, 0.5498, 0.5722, 0.4290, 0.5089, 0.4728,
        0.5541, 0.5337, 0.5687, 0.5217, 0.5366, 0.5604, 0.5603, 0.5646, 0.4731,
        0.5385, 0.4500, 0.5028, 0.4756, 0.4314, 0.4686, 0.5434, 0.4401, 0.5608,
        0.5644, 0.4473, 0.4488, 0.5518, 0.4328, 0.5425, 0.5642, 0.5698, 0.5431,
        0.5436, 0.5366, 0.4579, 0.5671], grad_fn=<SelectBackward0>)

These are the per-channel weights for one batch. The input to the model has size [30, 40, 50, 30] (dimension 1, of size 40, is treated as the channel dimension), and the resulting attention_out has size [30, 40, 1, 1]: one weight per channel for each sample.
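
To actually make use of these weights, broadcast-multiply them back onto the input; this is exactly what UseAttentionModel does in the next example. A one-line sketch using the variables from the loop above:

seq_weighted = seq * attention_out   # [30, 40, 50, 30] * [30, 40, 1, 1] -> [30, 40, 50, 30]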

2. A small example: using the channel attention module inside a model

import torch
import torch.nn as nn
import torch.utils.data as Data


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class UseAttentionModel(nn.Module):  # define whatever model you like here
    def __init__(self, H):
        super(UseAttentionModel, self).__init__()
        self.channel_attention = ChannelAttention(H)

    def forward(self, x):  # forward pass
        attention_value = self.channel_attention(x)  # [N, C, 1, 1] channel weights
        out = x.mul(attention_value)  # re-weight x channel-wise via broadcasting
        return out


def get_total_train_data(H, W, C, class_count):
    """得到全部的训练数据,这里需要替换成自己的数据"""
    import numpy as np
    x_train = torch.Tensor(
        np.random.random((1000, H, W, C)))  # 维度是 [ 数据量, 高H, 宽W, 长C]
    y_train = torch.Tensor(
        np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务
    return x_train, y_train


if __name__ == '__main__':
    # ================ Training parameters =================
    epochs = 100
    batch_size = 30
    output_class = 14
    H = 40
    W = 50
    C = 30
    # ================ Prepare the data =================
    x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)
    train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(x_train, y_train),  # tensors wrapped in Data.TensorDataset() may have any number of dimensions
        batch_size=batch_size,  # size of each mini-batch
        shuffle=True,  # whether to shuffle the data (shuffling is usually better)
        num_workers=6,  # number of worker processes used to load the data
        drop_last=True,
    )
    # ================ Initialize the model =================
    model = UseAttentionModel(H)
    cross_loss = nn.CrossEntropyLoss()  # loss function (defined here but not used in the loop below)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # optimizer (defined here but not used in the loop below)
    model.train()
    # ================ Start training =================
    for epoch in range(epochs):
        for seq, labels in train_loader:
            attention_out = model(seq)  # attention-weighted features, same size as seq
            print(attention_out.size())
            print(attention_out)

Output:

torch.Size([30, 40, 50, 30])
tensor([[[[5.8376e-03, 8.6779e-02, 1.7438e-01,  ..., 1.0239e-01,
           1.3072e-01, 1.0530e-01],
          [3.1557e-01, 4.3194e-01, 2.9906e-01,  ..., 3.2867e-01,
           2.2711e-02, 3.2728e-01],
          ...,
          [6.9277e-02, 4.6273e-01, 4.5013e-01,  ..., 3.3907e-01,
           6.9696e-02, 2.8434e-01]],

         ...,

        ]]]], grad_fn=<MulBackward0>)

(Output abridged; the size and the full tensor are printed for every batch.) The attention-weighted output has the same size as the input, [30, 40, 50, 30], because every element of the input is simply scaled by the attention weight of its channel.
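
Note that cross_loss and optimizer are created in this example but never used, so no parameters are actually updated; the loop only prints the attention-weighted features. Below is a minimal sketch of how the loop could be completed. The Flatten/Linear classification head and the combined optimizer are hypothetical additions for illustration, not part of the original example:

# Hypothetical classification head so the labels can be used with CrossEntropyLoss
classifier = nn.Sequential(nn.Flatten(), nn.Linear(H * W * C, output_class))
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(classifier.parameters()), lr=0.001)

for epoch in range(epochs):
    for seq, labels in train_loader:
        attention_out = model(seq)                    # [30, 40, 50, 30]
        logits = classifier(attention_out)            # [30, output_class]
        loss = cross_loss(logits, labels.squeeze(1))  # labels: [30, 1] -> [30]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()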
