问题描述:
本地的tensorflow部署到服务器时,tensorflow has no attribute “contrib”,这是因为服务器上的服务器是tensorflow2.0,移除了contrib,而本地的tensorflow版本是1.1.5。
解决办法:
- 降低tensorflow的版本到1.x,但是可能cuda有匹配了
- 用字典 HParams 替代 tf.contrib.training.HParams,对参数进行初始化,这样不需要更改cuda版本,十分方便。
1.x版本的contrib
def data_hparams():
params = tf.contrib.training.HParams(
# vocab
data_type='train',
data_path='data/',
thchs30=True,
)
return params
#读取数据
class get_data():
def __init__(self, args):
self.data_type = args.data_type
self.data_path = args.data_path
self.thchs30 = args.thchs30
#调用
get_data(data_hparams)
可以更改为:
data_hparams = {'data_type':'train',"data_path":'train/','thchs30'=True,}
class get_data():
def __init__(self, args):
self.data_type = args['data_type']
self.data_path = args["data_path"]
self.thchs30 = args["thchs30"]
'thchs30'
#调用
get_data(data_hparams )