import sys
import tensorflow as tf
#from icon_reg_net import GoogleNet
from det_icon_reg_deploy import GoogleNet_Reg
from det_icon_cls_deploy import GoogleNet_Cls
import numpy as np
image_path = './1113014_7.jpg'
mean = np.array([148., 148., 148.])
new_shape = [224, 224]
def device_for_node(n):
if n.type == "MatMul":
return "/gpu:0"
else:
return "/cpu:0"
def test(image_path):
g = tf.Graph()
with g.as_default():
with g.device(device_for_node): #直接写 "/gpu:0" 会出问题,详见下述1
file_data = tf.read_file(image_path)
# Decode the image data
img = tf.image.decode_jpeg(file_data, channels=3)
#img = tf.reverse(img, [False, False, True])
img = tf.image.resize_images(img, new_shape[0], new_shape[1])
img = tf.to_float(img) - mean
with tf.Session(graph=g) as sess1:
#tf.initialize_all_variables().run()
print type(img), img.get_shape()
img = sess1.run(img) # 这里需要先执行这句,详见下述2
print type(img), img.shape
img = np.reshape(img, (1, 224, 224, 3))
input_node = tf.placeholder(tf.float32, shape=(None, new_shape[0], new_shape[0], 3))
net = GoogleNet_Reg({'data': input_node})
model_path = './det_icon_reg_iter_110000.npy'
net.load(data_path=model_path, session=sess1)
probs = sess1.run(net.get_output(), feed_dict={input_node: img})
print probs
pos = probs[0]
x1 = int(pos[0] * new_shape[0])
y1 = int(pos[1] * new_shape[1])
x2 = int(pos[2] * new_shape[0])
y2 = int(pos[3] * new_shape[1])
#tf.reset_default_graph() #如果前面没有g = tf.Graph(),那么如果不加上这句可能会出错,详见下述3
#g2 = tf.Graph()
roiimg = tf.slice(img, begin=tf.pack([0, x1, y1, 0]), size=tf.pack([1, x2-x1, y2-y1, 3]))
roiimg = tf.image.resize_images(roiimg, new_shape[0], new_shape[1])
#g = tf.Graph()
with tf.Session() as sess2:
roiimg = sess2.run(roiimg)
roiimg = np.reshape(img, (1, 224, 224, 3))
print g
print tf.get_default_graph()
test = tf.constant(1)
print test.graph #这里使用的是默认的图,tf.get_default_graph() == test.graph
input_node = tf.placeholder(tf.float32, shape=(None, new_shape[0], new_shape[0], 3))
net = GoogleNet_Cls({'data': input_node})
model_path = './det_icon_cls_iter_50000.npy'
net.load(data_path=model_path, session=sess2)
probs = sess2.run(net.get_output(), feed_dict={input_node: roiimg})
print probs
scores = probs[0]
rank = np.argsort(-scores)
print rank[0], scores[rank[0]]
if __name__ == '__main__':
if len(sys.argv)>1 :
print sys.argv
func = getattr(sys.modules[__name__], sys.argv[1])
func(*sys.argv[2:])
else:
print >> sys.stderr,'%s command' % (__file__)
1. 在指定使用GPU时,如果直接指定 "/gpu:0"
with g.device("/gpu:0"):
这样会报错,类似:
tensorflow.python.framework.errors.InvalidArgumentError: Cannot assign a device to node 'GradientDescent/update_Variable_2/ScatterSub': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/GPU:0'
[[Node: GradientDescent/update_Variable_2/ScatterSub = ScatterSub[T=DT_FLOAT, Tindices=DT_INT64, use_locking=false](Variable_2, gradients/concat_1, GradientDescent/update_Variable_2/mul)]]
Caused by op u'GradientDescent/update_Variable_2/ScatterSub', defined at:
原因是并不是图中的所有的操作都支持GPU运算:
It seems a whole bunch of operations used in this example aren't supported on a GPU. A quick workaround is to restrict operations such that only matrix muls are ran on the GPU.
There's an example in the docs: http://tensorflow.org/api_docs/python/framework.md
See the section on tf.Graph.device(device_name_or_function)
简答的办法是指定只有矩阵乘法才在GPU上进行:def device_for_node(n):
if n.type == "MatMul":
return "/gpu:0"
else:
return "/cpu:0"
with graph.as_default():
with graph.device(device_for_node):
...
2. 在没有执行 img = sess1.run(img) 前,img 是 tf.read_file 然后 tf.image.decode_jpeg 后得到的,但是这里都是添加到图中的操作,并没有真正被执行,此时 img 的类型是
<class 'tensorflow.python.framework.ops.Tensor'> (224, 224, 3)
在 img = sess1.run(img) 后,img 变成了
<type 'numpy.ndarray'> (224, 224, 3)
3. 所有操作如果不指定图,则会使用默认图 tf.get_default_graph() ,上述代码中加载了两个模型,如果两个模型中出现里相同的name,就会出错。
g = tf.Graph()
with g.as_default():
...
<pre name="code" class="python">with tf.Session(graph=g) as sess1:
...
上面的方式将第一个模型放在了图g中,session执行的是图g,而不是默认图,这样后面就可以不用显示指定图,直接使用默认图。如果没有使用图g, 那么在第二个模型时需要先将默认图重置以清空默认图中之前的添加的ops
tf.reset_default_graph()