解决Semantic-Segmentation-Suite套件预测结果为全黑

最近在做一个项目尝试用分割的方式解决 利用 https://github.com/GeorgeSeif/Semantic-Segmentation-Suite 套件 做 星星比较多(以为比较靠谱,虽然作者已经不维护了但是确实代码中存在bug)

问题 训练正常loss也正常下降但是云顶predict.py结果全为黑

心路历程. 最开始调试认为是 模型没有加载成功导致.后来发现 模型确实没有加载成功 在训练阶段model_builder.build_model 中 is_training 为 True 测试阶段 istraining 为False 虽然这样写是正常的逻辑 但是 我用的是moilenetv2 + deeplab+v3 moilenetv2 中 有很多BN层 这样就导致 在restore 的时候 如果设置 istraining False BN层的参数不能正确加载 .如果在 predict 中 istraining = True 可以正常推理.但是这并不是我想要的啊.我在推理的时候的肯定是转pbmodle 然后这BN层参数不固定怎么可以 .当然后续我也转到MNN中才发现这个问题.

解决方法

在写优化器的时候 替换
原文:99行
opt = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995).minimize(loss, var_list=[var for var in tf.trainable_variables()])
修改后:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
opt = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995).minimize(loss, var_list=[var for var in tf.trainable_variables()])

这样 BN层的参数可以正确更新

另外保存的时候也要保存修改

原文:101行
saver=tf.train.Saver(max_to_keep=1000)
替换为:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if ‘moving_mean’ in g.name]
bn_moving_vars += [g for g in g_list if ‘moving_variance’ in g.name]
var_list += bn_moving_vars
saver=tf.train.Saver(var_list=var_list,max_to_keep=5)

这样我们就可以正确更新BN层参数 并且 保存BN层参数.在预测的时候可以设置 istraining=False啦

如何正确从加载模型的方式保存pb

这部分没什么就直接poll 全部的代码了

import os,time,cv2, sys, math
import tensorflow as tf
import argparse
import numpy as np
from tensorflow.python import pywrap_tensorflow
from utils import utils, helpers
from builders import model_builder
import os
from tensorflow.python.framework import graph_util

os.environ[“CUDA_VISIBLE_DEVICES”] = “1”

def freeze_graph(sess, output_graph):
‘’’
:param input_checkpoint:
:param output_graph: PB
:return:
‘’’
# checkpoint = tf.train.get_checkpoint_state(model_folder)
# input_checkpoint = checkpoint.model_checkpoint_path

output_node_names = "logits/BiasAdd"
output_graph_def = graph_util.convert_variables_to_constants(  
    sess=sess,
    input_graph_def=sess.graph_def,
    output_node_names=output_node_names.split(","))  

with tf.gfile.GFile(output_graph, "wb") as f:  
    f.write(output_graph_def.SerializeToString()) 
print("%d ops in the final graph." % len(output_graph_def.node))  

parser = argparse.ArgumentParser()
parser.add_argument(’–image’, type=str, default=’’, help=‘The image you want to predict on. ‘)
parser.add_argument(’–checkpoint_path’, type=str, default=’’, help=‘The path to the latest checkpoint weights for your model.’)
parser.add_argument(’–crop_height’, type=int, default=512, help=‘Height of cropped input image to network’)
parser.add_argument(’–crop_width’, type=int, default=512, help=‘Width of cropped input image to network’)
parser.add_argument(’–model’, type=str, default=‘DeepLabV3_plus’, help=‘The model you are using’)
parser.add_argument(’–dataset’, type=str, default="", required=False, help=‘The dataset you are using’)

args = parser.parse_args()

class_names_list, label_values = helpers.get_label_info(os.path.join(args.dataset, “class_dict.csv”))

num_classes = len(label_values)

print("\n***** Begin prediction *****")
print(“Dataset -->”, args.dataset)
print(“Model -->”, args.model)
print(“Crop Height -->”, args.crop_height)
print(“Crop Width -->”, args.crop_width)
print(“Num Classes -->”, num_classes)
print(“Image -->”, args.image)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess=tf.Session(config=config)
#net_input = tf.placeholder(tf.float32,shape=[1,512,512,3])#这个位置固定输入是为了转mnn 用的
net_input = tf.placeholder(tf.float32,shape=[1,512,512,3])#这个位置固定输入是为了转mnn 用的
net_output = tf.placeholder(tf.float32,shape=[None,None,None,num_classes])

network, _ = model_builder.build_model(args.model,
frontend=“MobileNetV2”,
net_input=net_input,
num_classes=num_classes,
crop_width=args.crop_width,
crop_height=args.crop_height,
is_training=False)

sess.run(tf.global_variables_initializer())

print(‘Loading model checkpoint weights’)

saver=tf.train.Saver(max_to_keep=1000)
saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_path))

out_pb_path=“pbmodel/frozen_model.pb”

freeze_graph(sess,out_pb_path)

参考:

https://www.cnblogs.com/baijing1/p/9842321.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值