The difference between tf.losses.get_regularization_loss() and tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

1. Introduction

During model training, measures such as dropout, L1 regularization, and L2 regularization are often added to prevent overfitting. This post briefly goes over how to add regularization.

2. Adding regularization

2.1 Method 1: attach the regularizer when the variable is defined (for networks built with the tf.nn API):

    dwise_weight = tf.get_variable(
        name='depthwise_weights', dtype=tf.float32, trainable=True,
        shape=(3, 3, input_c, 1),
        initializer=tf.random_normal_initializer(stddev=0.01),
        regularizer=tf.contrib.layers.l2_regularizer(0.0005))
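Passing regularizer= to tf.get_variable does not change the training objective by itself: TensorFlow evaluates the regularizer on the variable and stores the resulting penalty tensor in the tf.GraphKeys.REGULARIZATION_LOSSES collection, which is exactly what section 3 reads back. A minimal sketch of this mechanism (TF 1.x with tf.contrib assumed; the variable name here is illustrative):

    import tensorflow as tf

    # Defining a variable with a regularizer registers its L2 penalty
    # tensor in the REGULARIZATION_LOSSES collection.
    w = tf.get_variable(name='demo_weights', dtype=tf.float32,
                        shape=(3, 3, 8, 1),
                        initializer=tf.random_normal_initializer(stddev=0.01),
                        regularizer=tf.contrib.layers.l2_regularizer(0.0005))

    # One penalty tensor per regularized variable:
    print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))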

2.2 Method 2: attach the regularizer directly when defining the layer

With Keras:

    conv = tf.keras.layers.Conv2D(
        filters=filters_shape[-1], kernel_size=filters_shape[0],
        strides=strides, padding=padding, use_bias=not bn,
        kernel_regularizer=tf.keras.regularizers.l2(0.0005),
        kernel_initializer=tf.random_normal_initializer(stddev=0.01),
        bias_initializer=tf.constant_initializer(0.))(input_layer)
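With Keras layers, the regularization terms are additionally exposed on the layer object itself through its losses attribute; keeping a reference to the layer makes them easy to inspect. A small sketch (the layer configuration here is illustrative):

    conv_layer = tf.keras.layers.Conv2D(
        filters=32, kernel_size=3,
        kernel_regularizer=tf.keras.regularizers.l2(0.0005))
    conv = conv_layer(input_layer)

    # Each Keras layer tracks its own regularization penalties:
    print(conv_layer.losses)  # a list with one L2 penalty tensor for the kernel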

With tf.layers:

    conv = tf.layers.conv2d(
        inputs=inputs, filters=filters_num,
        kernel_size=kernel_size, strides=[strides, strides],
        kernel_initializer=tf.glorot_uniform_initializer(),
        padding=('SAME' if strides == 1 else 'VALID'),
        kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=5e-4),
        use_bias=use_bias, name=name)
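Note that attaching regularizers only matters if the collected penalty is actually added to the training objective. A common TF 1.x pattern (cross_entropy_loss is an illustrative task-loss name defined elsewhere):

    # Total objective = task loss + accumulated L2 penalties
    l2_loss = tf.losses.get_regularization_loss()
    total_loss = cross_entropy_loss + l2_loss
    train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)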

3. The difference between tf.losses.get_regularization_loss() and tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

Once the network skeleton is in place, there are two ways to collect the L2 regularization loss: tf.losses.get_regularization_loss() and tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES). What is the difference between them?

3.1 tf.losses.get_regularization_loss()

During network construction, write:

self.l2_loss_1 = tf.losses.get_regularization_loss()

The output is a single number, which gradually grows as the training steps go on. A sample value:

0.58757013

3.2 tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

During network construction, write:

self.reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

The output is a list of values, one per regularized variable (it is tf.losses.get_regularization_loss() that sums this list and returns a single number):

[0.0012587225, 0.0010193918, 0.0052134874, 0.00341414, 0.00074096455, 0.002031758, 0.006348608, 0.00038262826, 0.00085482444, 0.0029724198, 0.002807861, 0.0007209684, 0.00031811546, 0.0013439112, 0.0021733919, 0.00020669002, 8.275395e-05, 0.0013378835, 0.0022069137, 0.001741043, 0.0034375808, 0.00013721097, 0.001323716, 0.0014967844, 0.004495445, 0.00047215584, 0.001435072, 0.0016756648, 0.0013996841, 0.005932026, 0.00034188302, 0.014275889, 0.041253585, 0.014632461, 0.00071085553, 0.004095958, 0.006281647, 0.00018288092, 0.00031625855, 0.0023527981, 0.0026494847, 0.0021232541, 0.0025605266, 0.0005422489, 0.0037714622, 0.004139712, 0.00083760486, 0.00062693085, 0.004283203, 0.0049679745, 0.0025498145, 0.0028820795, 0.00038692102, 0.0038133918, 0.0026536714, 0.0012104163, 0.00047182888, 0.005267764, 0.004599052, 0.002836037, 0.006269276, 0.00034041578, 0.0042686653, 0.00214019, 0.0010493125, 0.000588716, 0.0054240623, 0.004182527, 0.0026253967, 0.0013253428, 0.00018922403, 0.01713413, 0.021168439, 0.00021331629, 0.0003404037, 0.0028425192, 0.0036303431, 0.0009226834, 0.00018170882, 0.0028716202, 0.0035161083, 0.0013216749, 0.0005894841, 0.00058683427, 0.0013555998, 0.0010963612, 0.0006417244, 0.00034985706, 0.0026754243, 0.0029172348, 0.0018723371, 0.0016511936, 0.00033535418, 0.002971437, 0.0021632009, 0.0015928581, 0.00063872937, 0.003048273, 0.0027831641, 0.0017480636, 0.0020496326, 0.00019814339, 0.0031466747, 0.0021316577, 0.0007197651, 0.0005315529, 0.0023107, 0.0015848394, 0.0011122554, 0.0009281139, 0.0004503858, 0.00238213, 0.001816663, 0.00097569835, 0.00031585852, 0.0022457493, 0.0015602952, 0.001193227, 0.0015331523, 0.0004179199, 0.0023959437, 0.0017357501, 0.0011387379, 0.00038396072, 0.0032778345, 0.0018264554, 0.0010251866, 0.0010517415, 0.000410388, 0.0030574333, 0.0021174678, 0.0012139543, 0.00048116146, 0.0025172373, 0.0012459956, 0.0007727607, 0.00044975439, 2.882598e-05, 0.0004120011, 0.0001914374, 0.00052413816, 0.000120028904, 0.003719843, 0.0017916383, 0.0013264294, 0.0017006957, 0.00047912705, 0.0013180161, 0.0015887399, 6.201234e-05, 0.0002548375, 1.9491777e-05, 5.6108373e-05, 0.00048251252, 1.6500117e-05, 0.00021688869, 0.00012332862, 0.00013551643, 0.00042087503, 2.729059e-05, 7.538646e-05, 9.9515884e-05, 0.00031406822, 1.4336873e-05, 4.5481807e-05, 3.7250775e-05, 6.7927074e-05, 0.00011680615, 1.8122297e-05, 1.2117926e-05, 1.5117716e-05, 0.00031482978, 4.59927e-06, 3.122417e-05, 1.9584195e-05, 3.5352397e-05, 0.0005032815, 3.6575593e-06, 6.774779e-06, 1.0360999e-05, 0.00030892712, 2.5183053e-06, 3.9015135e-06, 3.1233471e-06, 9.5274445e-06, 0.00028210497, 2.1516175e-06, 2.4816518e-05, 3.9151535e-05, 0.0006750349, 6.9654807e-06, 3.46891e-05, 2.3807737e-05, 3.670531e-05, 0.00034138284, 8.539587e-06, 1.5459173e-05, 2.2549477e-05, 0.00042386923, 4.6176965e-06, 6.9554744e-06, 4.210681e-06, 1.4128666e-05, 0.0006702608, 1.0705651e-06, 7.0838455e-06, 1.1673798e-05, 0.00017209895, 1.9372294e-06, 5.3059757e-06, 5.2798428e-06, 1.2961367e-05, 0.00060915365, 6.509151e-06, 1.3141949e-05, 1.3468712e-05, 0.00050954963, 5.740558e-06, 2.4997694e-06, 1.7281184e-06, 5.7629754e-06, 0.00034941913, 2.8325378e-06, 3.2588482e-06, 5.3775743e-06, 0.00019110511, 2.5025427e-06, 1.7237888e-05, 1.3587475e-05, 2.0598216e-05, 0.0005371568, 1.9321478e-06, 7.603406e-06, 1.1545717e-05, 0.00014172144, 7.762021e-06, 5.7557813e-06, 5.8592386e-06, 1.50709175e-05, 0.0006336278, 2.3983177e-06, 7.397871e-05, 8.0926766e-05, 0.0006022396, 5.7875645e-06, 6.246662e-06, 7.893348e-06, 1.2292606e-05, 
0.00021138701, 5.2278085e-07, 1.1165079e-05, 1.7876693e-05, 0.0006405, 1.763156e-06, 7.871037e-06, 6.097522e-06, 1.4337125e-05, 0.0003488061, 4.106041e-06, 1.1237236e-05, 2.0877109e-05, 0.0005505738, 3.4871016e-06, 6.9442794e-06, 7.3363135e-06, 2.7831536e-05, 0.00047160118, 2.940275e-06, 3.8799085e-06, 7.175076e-06, 0.0007043817, 3.360974e-06, 2.6693504e-05, 2.0144544e-05, 3.2222542e-05, 0.0008319255, 5.6719446e-06, 6.0634757e-06, 8.442739e-06, 0.00025592584, 1.6397635e-06, 4.354275e-06, 2.7345804e-06, 7.6029755e-06, 0.00021582501, 2.8183663e-07, 2.5316986e-05, 2.9046081e-05, 0.0005600996, 6.2943077e-06, 9.300074e-06, 6.3922166e-06, 2.0472015e-05, 0.00029123586, 4.126493e-06, 0.0021376628, 4.0073966e-05, 0.00032663788, 0.00019001846, 0.00019183535, 0.00010044787, 9.714952e-05, 0.0003104542, 0.00033163588, 3.901839e-05, 0.00030362647, 0.00017132689, 0.00040824036, 0.000759582, 0.0007054352, 0.00054079085, 0.0023890003, 0.00042751216, 0.0005998212, 0.02637032, 0.0026553501, 0.014795203, 0.009039405, 0.0018787832, 0.0062182304, 0.004410134, 0.0029215082, 0.00021064303, 0.0039598923, 0.03058016, 0.0022820449, 0.020663949, 0.014974551, 0.0016926205, 0.010938942, 0.009170251, 0.024396155, 0.000110875415]

Summing these values with NumPy confirms this:

import numpy as np 
b = np.array([0.0012587225, 0.0010193918, 0.0052134874, 0.00341414, 0.00074096455, 0.002031758, 0.006348608, 0.00038262826, 0.00085482444, 0.0029724198, 0.002807861, 0.0007209684, 0.00031811546, 0.0013439112, 0.0021733919, 0.00020669002, 8.275395e-05, 0.0013378835, 0.0022069137, 0.001741043, 0.0034375808, 0.00013721097, 0.001323716, 0.0014967844, 0.004495445, 0.00047215584, 0.001435072, 0.0016756648, 0.0013996841, 0.005932026, 0.00034188302, 0.014275889, 0.041253585, 0.014632461, 0.00071085553, 0.004095958, 0.006281647, 0.00018288092, 0.00031625855, 0.0023527981, 0.0026494847, 0.0021232541, 0.0025605266, 0.0005422489, 0.0037714622, 0.004139712, 0.00083760486, 0.00062693085, 0.004283203, 0.0049679745, 0.0025498145, 0.0028820795, 0.00038692102, 0.0038133918, 0.0026536714, 0.0012104163, 0.00047182888, 0.005267764, 0.004599052, 0.002836037, 0.006269276, 0.00034041578, 0.0042686653, 0.00214019, 0.0010493125, 0.000588716, 0.0054240623, 0.004182527, 0.0026253967, 0.0013253428, 0.00018922403, 0.01713413, 0.021168439, 0.00021331629, 0.0003404037, 0.0028425192, 0.0036303431, 0.0009226834, 0.00018170882, 0.0028716202, 0.0035161083, 0.0013216749, 0.0005894841, 0.00058683427, 0.0013555998, 0.0010963612, 0.0006417244, 0.00034985706, 0.0026754243, 0.0029172348, 0.0018723371, 0.0016511936, 0.00033535418, 0.002971437, 0.0021632009, 0.0015928581, 0.00063872937, 0.003048273, 0.0027831641, 0.0017480636, 0.0020496326, 0.00019814339, 0.0031466747, 0.0021316577, 0.0007197651, 0.0005315529, 0.0023107, 0.0015848394, 0.0011122554, 0.0009281139, 0.0004503858, 0.00238213, 0.001816663, 0.00097569835, 0.00031585852, 0.0022457493, 0.0015602952, 0.001193227, 0.0015331523, 0.0004179199, 0.0023959437, 0.0017357501, 0.0011387379, 0.00038396072, 0.0032778345, 0.0018264554, 0.0010251866, 0.0010517415, 0.000410388, 0.0030574333, 0.0021174678, 0.0012139543, 0.00048116146, 0.0025172373, 0.0012459956, 0.0007727607, 0.00044975439, 2.882598e-05, 0.0004120011, 0.0001914374, 0.00052413816, 0.000120028904, 0.003719843, 0.0017916383, 0.0013264294, 0.0017006957, 0.00047912705, 0.0013180161, 0.0015887399, 6.201234e-05, 0.0002548375, 1.9491777e-05, 5.6108373e-05, 0.00048251252, 1.6500117e-05, 0.00021688869, 0.00012332862, 0.00013551643, 0.00042087503, 2.729059e-05, 7.538646e-05, 9.9515884e-05, 0.00031406822, 1.4336873e-05, 4.5481807e-05, 3.7250775e-05, 6.7927074e-05, 0.00011680615, 1.8122297e-05, 1.2117926e-05, 1.5117716e-05, 0.00031482978, 4.59927e-06, 3.122417e-05, 1.9584195e-05, 3.5352397e-05, 0.0005032815, 3.6575593e-06, 6.774779e-06, 1.0360999e-05, 0.00030892712, 2.5183053e-06, 3.9015135e-06, 3.1233471e-06, 9.5274445e-06, 0.00028210497, 2.1516175e-06, 2.4816518e-05, 3.9151535e-05, 0.0006750349, 6.9654807e-06, 3.46891e-05, 2.3807737e-05, 3.670531e-05, 0.00034138284, 8.539587e-06, 1.5459173e-05, 2.2549477e-05, 0.00042386923, 4.6176965e-06, 6.9554744e-06, 4.210681e-06, 1.4128666e-05, 0.0006702608, 1.0705651e-06, 7.0838455e-06, 1.1673798e-05, 0.00017209895, 1.9372294e-06, 5.3059757e-06, 5.2798428e-06, 1.2961367e-05, 0.00060915365, 6.509151e-06, 1.3141949e-05, 1.3468712e-05, 0.00050954963, 5.740558e-06, 2.4997694e-06, 1.7281184e-06, 5.7629754e-06, 0.00034941913, 2.8325378e-06, 3.2588482e-06, 5.3775743e-06, 0.00019110511, 2.5025427e-06, 1.7237888e-05, 1.3587475e-05, 2.0598216e-05, 0.0005371568, 1.9321478e-06, 7.603406e-06, 1.1545717e-05, 0.00014172144, 7.762021e-06, 5.7557813e-06, 5.8592386e-06, 1.50709175e-05, 0.0006336278, 2.3983177e-06, 7.397871e-05, 8.0926766e-05, 0.0006022396, 5.7875645e-06, 6.246662e-06, 7.893348e-06, 1.2292606e-05, 
0.00021138701, 5.2278085e-07, 1.1165079e-05, 1.7876693e-05, 0.0006405, 1.763156e-06, 7.871037e-06, 6.097522e-06, 1.4337125e-05, 0.0003488061, 4.106041e-06, 1.1237236e-05, 2.0877109e-05, 0.0005505738, 3.4871016e-06, 6.9442794e-06, 7.3363135e-06, 2.7831536e-05, 0.00047160118, 2.940275e-06, 3.8799085e-06, 7.175076e-06, 0.0007043817, 3.360974e-06, 2.6693504e-05, 2.0144544e-05, 3.2222542e-05, 0.0008319255, 5.6719446e-06, 6.0634757e-06, 8.442739e-06, 0.00025592584, 1.6397635e-06, 4.354275e-06, 2.7345804e-06, 7.6029755e-06, 0.00021582501, 2.8183663e-07, 2.5316986e-05, 2.9046081e-05, 0.0005600996, 6.2943077e-06, 9.300074e-06, 6.3922166e-06, 2.0472015e-05, 0.00029123586, 4.126493e-06, 0.0021376628, 4.0073966e-05, 0.00032663788, 0.00019001846, 0.00019183535, 0.00010044787, 9.714952e-05, 0.0003104542, 0.00033163588, 3.901839e-05, 0.00030362647, 0.00017132689, 0.00040824036, 0.000759582, 0.0007054352, 0.00054079085, 0.0023890003, 0.00042751216, 0.0005998212, 0.02637032, 0.0026553501, 0.014795203, 0.009039405, 0.0018787832, 0.0062182304, 0.004410134, 0.0029215082, 0.00021064303, 0.0039598923, 0.03058016, 0.0022820449, 0.020663949, 0.014974551, 0.0016926205, 0.010938942, 0.009170251, 0.024396155, 0.000110875415])
print(np.sum(b))

The printed sum is:

0.58754411552538

So, in the end, the two approaches yield the same loss value (the small discrepancy between 0.58757013 and 0.58754411552538 comes from floating-point rounding of the printed values).
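The agreement is no coincidence: tf.losses.get_regularization_loss() essentially sums the tensors stored in the REGULARIZATION_LOSSES collection. A self-contained sketch that checks this on a toy graph (TF 1.x with tf.contrib assumed; variable names are illustrative):

    import tensorflow as tf

    reg = tf.contrib.layers.l2_regularizer(0.0005)
    a = tf.get_variable('a', shape=(4, 4), regularizer=reg,
                        initializer=tf.random_normal_initializer(stddev=0.01))
    b = tf.get_variable('b', shape=(4, 4), regularizer=reg,
                        initializer=tf.random_normal_initializer(stddev=0.01))

    reg_list = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)  # one tensor per variable
    sum_by_hand = tf.add_n(reg_list)                   # explicit sum over the collection
    sum_builtin = tf.losses.get_regularization_loss()  # built-in total

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        v1, v2 = sess.run([sum_by_hand, sum_builtin])
        print(v1, v2)  # the two totals match up to float rounding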
