TensorFlow tip: exploring a function's internals via graph.as_graph_def

TensorFlow provides the tf.Graph class for storing the computation graph. A computation graph is simply nodes connected by directed edges, and each node carries attributes such as its operation (op), value, dtype, and shape. To explore the internals of functions like tf.Variable(), we need to watch how the graph changes: which nodes get created, what their inputs are, and so on.
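Before dumping the full GraphDef, it can help to diff just the node names around the call being studied. Below is a minimal sketch under TensorFlow 1.x semantics (the node_names helper is our own, not a TensorFlow API):

import tensorflow as tf

def node_names():
    # names of all nodes currently in the default graph
    return [n.name for n in tf.get_default_graph().as_graph_def().node]

before = node_names()
a = tf.constant(1)
print(set(node_names()) - set(before))  # e.g. {'Const'}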

For example, to explore the internals of tf.constant, run the following code:

import tensorflow as tf

a = tf.constant(1)
# dump the default graph as a GraphDef protocol buffer
print(tf.get_default_graph().as_graph_def())

This prints:

node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
versions {
  producer: 22
}

From this we can tell that tf.constant still creates an Operation (its op type is Const), and that the constant's value is stored in the node's attr map.
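Since the value sits in attr as a TensorProto, it can be recovered straight from the GraphDef. A minimal sketch, assuming the graph built above (recent TF 1.x releases expose tf.make_ndarray for converting a TensorProto to a NumPy array; older ones have it as tf.contrib.util.make_ndarray):

graph_def = tf.get_default_graph().as_graph_def()
const_node = graph_def.node[0]           # the Const node shown above
proto = const_node.attr['value'].tensor  # the TensorProto stored in attr
print(tf.make_ndarray(proto))            # 1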

To dig further into the mechanics of tf.Variable, run the analogous code:

import tensorflow as tf

a = tf.ones([1])    # produces the 'ones' node in the output below
b = tf.Variable(a)  # initializes the variable from that tensor
print(tf.get_default_graph().as_graph_def())

This prints the following:

node {
  name: "ones"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Variable"
  op: "VariableV2"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}
node {
  name: "Variable/Assign"
  op: "Assign"
  input: "Variable"
  input: "ones"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@Variable"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: true
    }
  }
  attr {
    key: "validate_shape"
    value {
      b: true
    }
  }
}
node {
  name: "Variable/read"
  op: "Identity"
  input: "Variable"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@Variable"
      }
    }
  }
}
versions {
  producer: 22
}

Notice that although only a single line of code (the tf.Variable call) was run, three nodes were actually added: Variable, Variable/read, and Variable/Assign. From the names one can guess that they act as the storage node, the read interface, and the assignment interface, respectively. The presence of this Assign op is the key marker that distinguishes a Variable from a Tensor: a Variable is mutable.
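Both points are easy to check. The sketch below, assuming the graph built above, first lists the inputs of Variable/Assign (it writes 'ones' into 'Variable'), then mutates the variable inside a session:

# the Assign node copies the 'ones' tensor into the 'Variable' storage node
assign_op = tf.get_default_graph().get_operation_by_name('Variable/Assign')
print([t.name for t in assign_op.inputs])  # ['Variable:0', 'ones:0']

with tf.Session() as sess:
    sess.run(b.initializer)       # runs Variable/Assign
    print(sess.run(b))            # [1.]
    sess.run(tf.assign(b, [2.]))  # builds and runs a new Assign node
    print(sess.run(b))            # [2.] -- the Variable was mutated in place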

