glob 模块的主要方法是 glob,它返回所有与给定模式匹配的文件路径列表(list)。该方法需要一个参数来指定要匹配的路径模式字符串(可以是绝对路径,也可以是相对路径)。默认情况下它只在模式所指定的那一层目录中进行匹配,不会递归进入子文件夹;如需递归匹配,可在模式中使用 `**` 并传入 `recursive=True`(Python 3.5 及以上)。
eg1: glob.glob(r'c:\*.txt')
获取C盘根目录下的所有txt文件
eg2: glob.glob(r'E:\pic\*\*.jpg')
获取 E:\pic 下各个一级子文件夹中的所有jpg文件
eg3: glob.glob(r'../*.py')
使用相对路径,获取上一级目录中的所有 .py 文件
def read_imgs_masks(args):
    """Collect the image and mask file paths to be processed.

    Each directory is globbed non-recursively with the pattern ``*.*[gG]``,
    i.e. entries whose name contains a dot and ends in ``g`` or ``G``
    (for example ``.jpg``, ``.png`` or ``.jpeg``).

    Args:
        args: object exposing ``images`` (image directory) and ``masks``
            (mask directory) path strings.

    Returns:
        ``(paths_img, paths_mask)``: the two path lists after ``sort``.
        The script in this file pairs them positionally with ``zip``, so the
        two directories are expected to hold corresponding files that sort
        into the same order.
    """
    pattern = '*.*[gG]'
    paths_img = glob.glob(os.path.join(args.images, pattern))
    paths_mask = glob.glob(os.path.join(args.masks, pattern))
    # NOTE(review): `sort` is not a Python builtin; it must be a helper
    # defined elsewhere in this file (presumably returning a sorted list).
    paths_img = sort(paths_img)
    paths_mask = sort(paths_mask)
    print('#imgs: ' + str(len(paths_img)))
    # Report the mask count under its own label (was mislabelled '#imgs').
    print('#masks: ' + str(len(paths_mask)))
    print(paths_img)
    print(paths_mask)
    return paths_img, paths_mask
# Command-line options. The defaults are the values that used to be
# hard-coded, so running the script without arguments behaves as before.
parser = argparse.ArgumentParser()
parser.add_argument('--images', default='../samples/testset',
                    help='input image directory')
parser.add_argument('--masks', default='../samples/maskset',
                    help='input mask directory')
parser.add_argument('--output_dir', default='./results',
                    help='output directory')
parser.add_argument('--multiple', type=int, default=6,
                    help='multiples of image resizing')
args = parser.parse_args()

paths_img, paths_mask = read_imgs_masks(args)
if len(paths_img) != len(paths_mask):
    # zip() below silently drops the unmatched tail; make that visible.
    print('Warning: {} images but {} masks; only the first {} pairs '
          'will be processed.'.format(len(paths_img), len(paths_mask),
                                      min(len(paths_img), len(paths_mask))))

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

with tf.Graph().as_default():
    # Load the serialized graph definition into the default graph.
    with open('./pb/hifill.pb', "rb") as f:
        output_graph_def = tf.GraphDef()
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        # Look up the imported graph's input and output tensors by name.
        image_ph = sess.graph.get_tensor_by_name('img:0')
        mask_ph = sess.graph.get_tensor_by_name('mask:0')
        inpainted_512_node = sess.graph.get_tensor_by_name('inpainted:0')
        attention_node = sess.graph.get_tensor_by_name('attention:0')
        mask_512_node = sess.graph.get_tensor_by_name('mask_processed:0')

        for path_img, path_mask in zip(paths_img, paths_mask):
            raw_img = cv2.imread(path_img)
            raw_mask = cv2.imread(path_mask)
            # cv2.imread signals an unreadable file by returning None rather
            # than raising; skip the pair instead of crashing in inpaint().
            if raw_img is None or raw_mask is None:
                print('Skipping unreadable pair: {} / {}'.format(
                    path_img, path_mask))
                continue
            inpainted = inpaint(raw_img, raw_mask, sess, inpainted_512_node,
                                attention_node, mask_512_node, image_ph,
                                mask_ph, args.multiple)
            # The output name keeps the input's extension,
            # e.g. "x.jpg" -> "x.jpg_inpainted.jpg".
            filename = os.path.join(args.output_dir,
                                    os.path.basename(path_img))
            out_path = filename + '_inpainted.jpg'
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(out_path, inpainted):
                print('Failed to write ' + out_path)