Ryan Dahl的tensorflow-resnet中Config类实现代码详细注释

      Ryan Dahl用Tensorflow实现了ResNet,并在Github上开源了其代码。这份代码写得比较复杂,用了比较底层的方式实现。相比Tensorflow官方提供的基于slim库实现的代码要复杂很多。Ryan Dahl自己也在该代码的Readme文件夹中说这份代码很老了,建议用slim库中的代码。但作为初学者学习Tensorflow,硬啃这份代码,必然对编程能力的提高有帮助。

      我最近花了点时间把这份代码中Config.py文件中的代码弄明白了,并做了详细的注释,现放上来,供需要的朋友参考,希望能有所帮助。

      Config.py文件实现的是通过作用域的管理来控制变量在不同作用域的共享。代码定义了一个变量作用域管理的Config类,并实现了一个小例子。在例子中先定义一个Config对象,然后向其列表中的字典中添加新项,相同作用域下添加的项会自动放在同一个字典中。例子也展示了作用域外对作用域内添加的项是不可见的,若试图从作用域外读取作用域内添加的项,则会报错。

      要弄懂这份代码,需要掌握的python知识包括列表、字典、类、魔法函数等,建议通过设置断点,在调试模式下查看函数的调用顺序。通过在适当的地方加入print函数,查看各变量的值的变化,了解各部分代码的作用.

       简书上有一篇介绍这份代码的文章,但不是很详细,可以把它和本文结合起来学习。

# This is a variable scope aware configuation object for TensorFlow

# 本代码实现一个变量作用域管理的类,并包含了一个小例子,便于读者了解其用法

#参考链接
# https://www.jianshu.com/p/1a7a050e043f

# 源码链接
# https://github.com/ry/tensorflow-resnet/blob/master/config.py

import tensorflow as tf    # 导入Tensorflow模块

FLAGS = tf.app.flags.FLAGS    # flags解析.主要用于从对应的命令行参数取出参数

class Config:
    def __init__(self):
        root = self.Scope('')    # 创建Scope的空对象(空字典)
        for k in FLAGS:    
            v = FLAGS[k].value    # 把FLAGS中所有的值依次付给v
            root[k] = v    # 向root中相应的键中添加FLAGS中对应的值,root是scope对象,近似一个字典
        self.stack = [ root ]    # 把root放入列表中(后续用该列表实现栈的功能,所以命名为stack)
        # stack为包含字典的列表,形如[{'flag_string': 'yes', 'flag_float': 0.01, 'flag_int': 400, 'flag_bool': True}]

    def iteritems(self):    # 获取所stack中所有元素(字典)中所有的更新后的项
        return self.to_dict().iteritems()    # self.to_dict()返回的是一个字典
    # 后面的iteritems()函数是字典类的成员函数(返回字典的迭代器,该函数在python3中被items()替代了),不是Config类的同名成员函数 

    def to_dict(self):    # 把stack中的所有元素中的键和值放入一个字典,其中重复的用新的代替旧的 
        self._pop_stale()    # 把当前作用域名放到stack列表的的第一个元素中
        out = {}    # 创建空字典
        # Work backwards from the flags to top fo the stack
        # overwriting keys that were found earlier.(把stack中靠后面的元素中的键用靠前面的元素中同名的键覆盖,stack每个元素是一个字典)
        for i in range(len(self.stack)):    
            cs = self.stack[-i]    # 获取stack中倒数第i个元素(是个字典)
            for name in cs:
                out[name] = cs[name]    # 把字典cs付给字典out(如果是out中已有的键,则会将其值覆盖)
        return out    # 返回最新的out    

    def _pop_stale(self):    # 若当前作用域名不是以stack中第一个元素的名字(初始值为'')开头(用于判断后一个作用域是否包含于前一个作用域),把之前在它前面的元素都弹出
        var_scope_name = tf.get_variable_scope().name    # 获取变量作用域名       
        top = self.stack[0]    # 把stack列表中的第一个元素(字典)赋给top       
        # print('var_scope_name的值为:', var_scope_name)
        # print('stack的值为:', self.stack)
        # print('top.name的值为:', top.name)
        
        while not top.contains(var_scope_name):    # 若var_scope_name不是以top的名字(初始值为'')开头
            # We aren't in this scope anymore
            self.stack.pop(0)    # 弹出stack列表中的第一个元素
            top = self.stack[0]    # 把stack中第一个元素赋值给top
        # print('stack的值为:', self.stack)

    def __getitem__(self, name):     # 重写__getitem__魔法函数找到stack的元素中name键对应的值
        self._pop_stale()     # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
        # Recursively extract value        
        for i in range(len(self.stack)):    # 查看stack的元素中是否有name键,有的话返回对应的值
            cs = self.stack[i]
            if name in cs:
                return cs[name]    

        raise KeyError(name)   # 若name不在stack的元素中,抛出异常

    def set_default(self, name, value):    # 查看config对象中是否有指定的键,若没有,则把指定的键名和值赋给Config对象
        if not name in self:
            self[name] = value    # 该句执行时会自动调用__setitem__实现向Config对象中添加新项

    def __contains__(self, name):    # 重写__contains__魔法函数,查看stack的元素中是否有name键,有的话返回Ture,否则返回False
        self._pop_stale()    # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
        for i in range(len(self.stack)):    #查看stack的元素中是否有name键,有的话返回Ture,否则返回False
            cs = self.stack[i]
            if name in cs:
                return True
        return False

    def __setitem__(self, name, value):     # 重写__setitem__魔法函数,当向类的对象添加新项时,自动调用
        self._pop_stale()     # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
        top = self.stack[0]    # 把stack的第一个元素赋给top       
        var_scope_name = tf.get_variable_scope().name    # 获得当前作用域名
        
        assert top.contains(var_scope_name)    # 断言语句,若表达式的值为假,则抛出异常

        if top.name != var_scope_name:    # 若当前作用域名与top的名字不同
            top = self.Scope(var_scope_name)    #重新创建以'var_scope_name'命名的Scope对象赋给top
            self.stack.insert(0, top)    # 把新的top放到stack列表首位

        top[name] = value    # 向top中添加键和值,该操作同样作用于stack[0],因为他们是指向同一存储区域的


    class Scope(dict):    # 定义Scope(作用域)类,它继承自dict类
        def __init__(self, name):    # 构造函数
            self.name = name     # 把字符串传递给对象名字   

        def contains(self, var_scope_name):    # 检查是否包含指定的Scope名字
            return var_scope_name.startswith(self.name)    # startswith() 方法用于检查字符串是否是以指定子字符串开头,如果是则返回 True,否则返回 False



# Test
# 下面的例子,先定义一个Config对象,然后向其列表中的字典中添加新项,相同作用域下添加的项会自动放在同一个字典中
if __name__ == '__main__':    # 判断是否在运行这个文件,如果是作为模块导入到其他文件中运行,则__name__属性值为‘config’
    def assert_raises(exception, fn):    # 定义assert_raises函数,当函数出错时抛出异常
        try:
            fn()
        except exception:    # 若发生指定的异常
            pass    # 什么也不执行
        else:
            assert False, "Expected exception"    #assert为断言语句,相当于raise-if-not,若表达式为假,则弹出异常

    c = Config()    # 新建Config对象
    # 此时c.stack的值为[{}](包含一个空字典的列表)

    c['hello'] = 1    # 自动调用c.__setitem__()(魔法函数,向c添加键和值时自动调用),把新的键和值添加到c.stack里   
    # 此时c.stack的值为[{'hello':1}],当前作用域名为''(空)
    assert c['hello'] == 1    #检查该键和值是否正确添加,否则抛出异常

    with tf.variable_scope('foo'):    # 指定以下代码中变量的作用域为‘foo’。tf.variable_scope函数的作用是管理传给get_variable()的变量名称的作用域
        c.set_default("bar", 10)    # 添加新项
        # 此时c.stack的值为[{'bar':10},{'hello':1}],当前作用域名为‘foo’
        c['bar'] = 2    # 覆盖‘bar’键原来的值
        # 此时c.stack的值为[{'bar':2},{'hello':1}]
        
        assert c['bar'] == 2    # 判断c中‘bar’对应的值是否为2,不是则抛出异常
        assert c['hello'] == 1

        c.set_default("mario", True)    # 添加新项
        # 此时c.stack的值为[{'bar':2,‘mario’:True},{'hello':1}]

        with tf.variable_scope('meow'):
            c['dog'] = 3
            # 此时c.stack的值为[{'dog':3},{'bar':2,‘mario’:True},{'hello':1}]
            assert c['dog'] == 3    # 判断c中‘dog’对应的值是否为2,不是则抛出异常
            assert c['bar'] == 2
            assert c['hello'] == 1

            assert c['mario'] == True

        assert_raises(KeyError, lambda: c['dog'])    # assert_raises函数的第二个参数是lambda表达式,返回c['dog']的值。若c中不存在‘dog’键,则抛出异常
        # 此时c.stack的值为[{'bar':2,‘mario’:True},{'hello':1}]
        # 因此会引发KeyError错误,但assert_raises的设定是当KeyError错误出现时,什么也不执行,不知道作者为什么要这样设定
        # 加入此处执行assert c['dog'] == 3,则会抛出KeyError异常
        assert c['bar'] == 2
        assert c['hello'] == 1

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值