批标准化

文章目录

批标准化

批标准化通俗来说就是对每一层神经网络进行标准化 (normalize) 处理, 对输入数据进行标准化能让机器学习有效率地学习.构建带有 BN 的神经网络的. BN 其实可以看做是一个 layer (BN layer).
就像平时加层一样加 BN layer 就好了. 注意, 还对输入数据进行了一个 BN 处理, 因为如果你把输入数据看出是
从前面一层来的输出数据, 我们同样也能对她进行 BN.

#批标准化
import torch
from torch import nn
from torch.nn import init
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# 超参数
N_SAMPLES = 2000
BATCH_SIZE = 64
EPOCH = 12
LR = 0.03
N_HIDDEN = 8
ACTIVATION = F.tanh     # 你可以换 relu 试试
B_INIT = -0.2   # 模拟不好的 参数初始化

# training data
x = np.linspace(-7, 10, N_SAMPLES)[:, np.newaxis]
noise = np.random.normal(0, 2, x.shape)
y = np.square(x) - 5 + noise

# test data
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noise

train_x, train_y = torch.from_numpy(x).float(), torch.from_numpy(y).float()
test_x = torch.from_numpy(test_x).float()
test_y = torch.from_numpy(test_y).float()

train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)

# show data
plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
plt.legend(loc='upper left')

#创建两个神经网络一个有Batch Normalization一个无
class Net(nn.Module):

    def __init__(self, batch_normalization=False):
        super(Net, self).__init__()
        self.do_bn = batch_normalization
        self.fcs = []
        self.bns = []
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)   # 对input也进行BN
        # (1, momentum=0.5)1代表有多少个输入值0.5用来平滑化 batch mean and stddev 的
        for i in range(N_HIDDEN):               # 建hidden layers,BN layers

            input_size = 1 if i == 0 else 10#(输入层时等于1,隐藏层时=10)
            fc = nn.Linear(input_size, 10)#输出10个神经元
            setattr(self, 'fc%i' % i, fc)       # 只打上面那句class Net(nn.Module)里不识别
            # setattr()相当于self.fcs = []。。的功能
            self._set_init(fc)                  # parameters initialization
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module
                self.bns.append(bn)

        self.predict = nn.Linear(10, 1)         # output layer
        self._set_init(self.predict)            # parameters initialization

    def _set_init(self, layer):#参数初始化
        init.normal_(layer.weight, mean=0., std=.1)
        init.constant_(layer.bias, B_INIT)
    def forward(self, x):
        pre_activation = [x]#没有经历激活函数
        if self.do_bn: x = self.bn_input(x)     # 对input数据进行batch normalization
        layer_input = [x]
        for i in range(N_HIDDEN):
            x = self.fcs[i](x)
            pre_activation.append(x)
            if self.do_bn: x = self.bns[i](x)   # batch normalization
            x = ACTIVATION(x)
            layer_input.append(x)

        out = self.predict(x)
        return out, layer_input, pre_activation#不仅输出最终的预测



nets = [Net(batch_normalization=False), Net(batch_normalization=True)]



# print(*nets)    # print net architecture



opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]



loss_func = torch.nn.MSELoss()





def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):

    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):

        [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]

        if i == 0:

            p_range = (-7, 10);the_range = (-7, 10)

        else:

            p_range = (-4, 4);the_range = (-1, 1)

        ax_pa.set_title('L' + str(i))

        ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)

        ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')

        for a in [ax_pa, ax, ax_pa_bn, ax_bn]: a.set_yticks(());a.set_xticks(())

        ax_pa_bn.set_xticks(p_range);ax_bn.set_xticks(the_range)

        axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')

    plt.pause(0.01)





if __name__ == "__main__":

    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))

    plt.ion()  # something about plotting

    plt.show()



    # training

    losses = [[], []]  # recode loss for two networks



    for epoch in range(EPOCH):

        print('Epoch: ', epoch)

        layer_inputs, pre_acts = [], []

        for net, l in zip(nets, losses):

            net.eval()              # set eval mode to fix moving_mean and moving_var

            pred, layer_input, pre_act = net(test_x)

            l.append(loss_func(pred, test_y).data.item())

            layer_inputs.append(layer_input)

            pre_acts.append(pre_act)

            net.train()             # free moving_mean and moving_var

        plot_histogram(*layer_inputs, *pre_acts)     # plot histogram



        for step, (b_x, b_y) in enumerate(train_loader):
            for net, opt in zip(nets, opts):     # train for each network
                pred, _, _ = net(b_x)
                loss = loss_func(pred, b_y)
                opt.zero_grad()
                loss.backward()
                opt.step()    # it will also learns the parameters in Batch Normalization



    plt.ioff()
 # plot training loss

    plt.figure(2)

    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')

    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')

    plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')



    # evaluation

    # set net to eval mode to freeze the parameters in batch normalization layers

    [net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var

    preds = [net(test_x)[0] for net in nets]

    plt.figure(3)

    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')

    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')

    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')

    plt.legend(loc='best')

    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值