Ryan Dahl用Tensorflow实现了ResNet,并在Github上开源了其代码。这份代码写得比较复杂,用了比较底层的方式实现。相比Tensorflow官方提供的基于slim库实现的代码要复杂很多。Ryan Dahl自己也在该代码的Readme文件夹中说这份代码很老了,建议用slim库中的代码。但作为初学者学习Tensorflow,硬啃这份代码,必然对编程能力的提高有帮助。
我最近花了点时间把这份代码中Config.py文件中的代码弄明白了,并做了详细的注释,现放上来,供需要的朋友参考,希望能有所帮助。
Config.py文件实现的是通过作用域的管理来控制变量在不同作用域的共享。代码定义了一个变量作用域管理的Config类,并实现了一个小例子。在例子中先定义一个Config对象,然后向其列表中的字典中添加新项,相同作用域下添加的项会自动放在同一个字典中。例子也展示了作用域外对作用域内添加的项是不可见的,若试图从作用域外读取作用域内添加的项,则会报错。
要弄懂这份代码,需要掌握的python知识包括列表、字典、类、魔法函数等,建议通过设置断点,在调试模式下查看函数的调用顺序。通过在适当的地方加入print函数,查看各变量的值的变化,了解各部分代码的作用.
简书上有一篇介绍这份代码的文章,但不是很详细,可以把它和本文结合起来学习。
# This is a variable scope aware configuation object for TensorFlow
# 本代码实现一个变量作用域管理的类,并包含了一个小例子,便于读者了解其用法
#参考链接
# https://www.jianshu.com/p/1a7a050e043f
# 源码链接
# https://github.com/ry/tensorflow-resnet/blob/master/config.py
import tensorflow as tf # 导入Tensorflow模块
FLAGS = tf.app.flags.FLAGS # flags解析.主要用于从对应的命令行参数取出参数
class Config:
def __init__(self):
root = self.Scope('') # 创建Scope的空对象(空字典)
for k in FLAGS:
v = FLAGS[k].value # 把FLAGS中所有的值依次付给v
root[k] = v # 向root中相应的键中添加FLAGS中对应的值,root是scope对象,近似一个字典
self.stack = [ root ] # 把root放入列表中(后续用该列表实现栈的功能,所以命名为stack)
# stack为包含字典的列表,形如[{'flag_string': 'yes', 'flag_float': 0.01, 'flag_int': 400, 'flag_bool': True}]
def iteritems(self): # 获取所stack中所有元素(字典)中所有的更新后的项
return self.to_dict().iteritems() # self.to_dict()返回的是一个字典
# 后面的iteritems()函数是字典类的成员函数(返回字典的迭代器,该函数在python3中被items()替代了),不是Config类的同名成员函数
def to_dict(self): # 把stack中的所有元素中的键和值放入一个字典,其中重复的用新的代替旧的
self._pop_stale() # 把当前作用域名放到stack列表的的第一个元素中
out = {} # 创建空字典
# Work backwards from the flags to top fo the stack
# overwriting keys that were found earlier.(把stack中靠后面的元素中的键用靠前面的元素中同名的键覆盖,stack每个元素是一个字典)
for i in range(len(self.stack)):
cs = self.stack[-i] # 获取stack中倒数第i个元素(是个字典)
for name in cs:
out[name] = cs[name] # 把字典cs付给字典out(如果是out中已有的键,则会将其值覆盖)
return out # 返回最新的out
def _pop_stale(self): # 若当前作用域名不是以stack中第一个元素的名字(初始值为'')开头(用于判断后一个作用域是否包含于前一个作用域),把之前在它前面的元素都弹出
var_scope_name = tf.get_variable_scope().name # 获取变量作用域名
top = self.stack[0] # 把stack列表中的第一个元素(字典)赋给top
# print('var_scope_name的值为:', var_scope_name)
# print('stack的值为:', self.stack)
# print('top.name的值为:', top.name)
while not top.contains(var_scope_name): # 若var_scope_name不是以top的名字(初始值为'')开头
# We aren't in this scope anymore
self.stack.pop(0) # 弹出stack列表中的第一个元素
top = self.stack[0] # 把stack中第一个元素赋值给top
# print('stack的值为:', self.stack)
def __getitem__(self, name): # 重写__getitem__魔法函数找到stack的元素中name键对应的值
self._pop_stale() # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
# Recursively extract value
for i in range(len(self.stack)): # 查看stack的元素中是否有name键,有的话返回对应的值
cs = self.stack[i]
if name in cs:
return cs[name]
raise KeyError(name) # 若name不在stack的元素中,抛出异常
def set_default(self, name, value): # 查看config对象中是否有指定的键,若没有,则把指定的键名和值赋给Config对象
if not name in self:
self[name] = value # 该句执行时会自动调用__setitem__实现向Config对象中添加新项
def __contains__(self, name): # 重写__contains__魔法函数,查看stack的元素中是否有name键,有的话返回Ture,否则返回False
self._pop_stale() # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
for i in range(len(self.stack)): #查看stack的元素中是否有name键,有的话返回Ture,否则返回False
cs = self.stack[i]
if name in cs:
return True
return False
def __setitem__(self, name, value): # 重写__setitem__魔法函数,当向类的对象添加新项时,自动调用
self._pop_stale() # 把包含当前作用域名的stack元素放到stack列表首位,把之前在它前面的元素都弹出
top = self.stack[0] # 把stack的第一个元素赋给top
var_scope_name = tf.get_variable_scope().name # 获得当前作用域名
assert top.contains(var_scope_name) # 断言语句,若表达式的值为假,则抛出异常
if top.name != var_scope_name: # 若当前作用域名与top的名字不同
top = self.Scope(var_scope_name) #重新创建以'var_scope_name'命名的Scope对象赋给top
self.stack.insert(0, top) # 把新的top放到stack列表首位
top[name] = value # 向top中添加键和值,该操作同样作用于stack[0],因为他们是指向同一存储区域的
class Scope(dict): # 定义Scope(作用域)类,它继承自dict类
def __init__(self, name): # 构造函数
self.name = name # 把字符串传递给对象名字
def contains(self, var_scope_name): # 检查是否包含指定的Scope名字
return var_scope_name.startswith(self.name) # startswith() 方法用于检查字符串是否是以指定子字符串开头,如果是则返回 True,否则返回 False
# Test
# 下面的例子,先定义一个Config对象,然后向其列表中的字典中添加新项,相同作用域下添加的项会自动放在同一个字典中
if __name__ == '__main__': # 判断是否在运行这个文件,如果是作为模块导入到其他文件中运行,则__name__属性值为‘config’
def assert_raises(exception, fn): # 定义assert_raises函数,当函数出错时抛出异常
try:
fn()
except exception: # 若发生指定的异常
pass # 什么也不执行
else:
assert False, "Expected exception" #assert为断言语句,相当于raise-if-not,若表达式为假,则弹出异常
c = Config() # 新建Config对象
# 此时c.stack的值为[{}](包含一个空字典的列表)
c['hello'] = 1 # 自动调用c.__setitem__()(魔法函数,向c添加键和值时自动调用),把新的键和值添加到c.stack里
# 此时c.stack的值为[{'hello':1}],当前作用域名为''(空)
assert c['hello'] == 1 #检查该键和值是否正确添加,否则抛出异常
with tf.variable_scope('foo'): # 指定以下代码中变量的作用域为‘foo’。tf.variable_scope函数的作用是管理传给get_variable()的变量名称的作用域
c.set_default("bar", 10) # 添加新项
# 此时c.stack的值为[{'bar':10},{'hello':1}],当前作用域名为‘foo’
c['bar'] = 2 # 覆盖‘bar’键原来的值
# 此时c.stack的值为[{'bar':2},{'hello':1}]
assert c['bar'] == 2 # 判断c中‘bar’对应的值是否为2,不是则抛出异常
assert c['hello'] == 1
c.set_default("mario", True) # 添加新项
# 此时c.stack的值为[{'bar':2,‘mario’:True},{'hello':1}]
with tf.variable_scope('meow'):
c['dog'] = 3
# 此时c.stack的值为[{'dog':3},{'bar':2,‘mario’:True},{'hello':1}]
assert c['dog'] == 3 # 判断c中‘dog’对应的值是否为2,不是则抛出异常
assert c['bar'] == 2
assert c['hello'] == 1
assert c['mario'] == True
assert_raises(KeyError, lambda: c['dog']) # assert_raises函数的第二个参数是lambda表达式,返回c['dog']的值。若c中不存在‘dog’键,则抛出异常
# 此时c.stack的值为[{'bar':2,‘mario’:True},{'hello':1}]
# 因此会引发KeyError错误,但assert_raises的设定是当KeyError错误出现时,什么也不执行,不知道作者为什么要这样设定
# 加入此处执行assert c['dog'] == 3,则会抛出KeyError异常
assert c['bar'] == 2
assert c['hello'] == 1