tensorflow regularizer(正则化)防止过拟合

Regularizer是防止网络过拟合的一种有效方法。这篇文章主要探讨如何在自己的网络模型中加入正则化,防止过拟合。

首先我们看一下正则化的基本使用方法,这篇博客给出了一个使用的例子:

http://www.cnblogs.com/linyuanzhou/p/6923607.html

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tf_regularization.py
#Author: Wang 
#Mail: wang****@hotmail.com
#Created Time:2017-08-23 11:53:34
############################

import tensorflow as tf 
from tensorflow.contrib import layers

myreg1 = layers.l1_regularizer(0.01)     #创建一个正则化方法, 0.01为系数,相当于给每个参数前乘以0.01,当然这里也可以是l2方法或者sum混合方法

with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = myreg1):    #高能!:参数里面指明了regularizer
    weight = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())

with tf.variable_scope('var2', initializer = tf.random_normal_initializer(), regularizer = myreg1):
    weight2 = tf.get_variable('weight', shape=[8], initializer = tf.ones_initializer())

regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))        #get_collection 获得list, reduce_sum进行对list求和

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
with sess.as_default():
    result = regularization_loss.eval() 
print result

最后的输出结果为0.16。

那么当我们需要在自己的网络中加入正则化时该怎么做? 继续上代码。

首先创建一个net.py文件,这个是我们自己的网络模型:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: net.py
#Author: Wang
#Mail: wang**@hotmail.com
#Created Time:2017-08-23 12:10:48
############################

import tensorflow as tf
import numpy as np
from tensorflow.contrib import layers

class mynet:
    
    def __init__(self):
        self.myreg1 = layers.l1_regularizer(0.01)
        self.inference()

    def inference(self):
        with tf.variable_scope('var', initializer = tf.random_normal_initializer(), regularizer = self.myreg1):
	    weight = tf.get_variable('weight', shape = [8], initializer = tf.ones_initializer())


然后是我们训练网络的主程序,这里面需要定义数据和loss,学习方法等:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: test.py
#Author: Wang
#Mail: wang***@hotmail.com
#Created Time:2017-08-23 12:10:28
############################

import tensorflow as tf
from net import mynet

sess = tf.Session()

mnet = mynet()

init = tf.global_variables_initializer()
sess.run(init)

regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

with sess.as_default():
    result = regularization_loss.eval()

print result

最后的输出结果为0.08。






  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值