使用spark分布式环境,训练和加载tensorflow。通常会将加载的模型广播出去,
这个时候涉及两个问题:
1.加载的模型的路径是hdfs,tensorflow.SavedModelBundle的load使用的是本地路径,所以需要使用sc.addFiles("",true),这个样保证hdfs数据get到环境路径下(本地路径);
2.同时,加载的模型往往需要广播到各个节点,但是这个时候也容易报错,可以采用在udf内部或者各个partition中加载模型。
补充:如果spark将hdfscopy到本地,从本地加载也是可以的(最好加上随机数,防止多次任务间干扰),这样再使用广播变量。(注意:别忘记将该路径删掉)