TensorFlow/Keras：多线程同时训练多个模型

研究了很久，终于实现了用多线程同时训练多个模型。

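核心在于理解 TF 中的 Graph 和 Session：每个模型使用各自独立的 Graph 和 Session，并且在每个线程里都要通过 `as_default()` 重新进入它们，因为在 TF 1.x 中默认的 Graph/Session 是按线程区分的。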

废话不多说，直接上代码，看完代码就懂了！


class MyModel(object):
    """Hold two independent Keras models and train them concurrently in threads.

    Each model gets its own ``tf.Graph`` and ``tf.Session`` (TensorFlow 1.x
    API). Every operation on a model (build, predict, fit, save, load)
    re-enters that model's session and graph as the defaults. In TF 1.x the
    default graph/session are per-thread, so they must be set again inside
    each worker thread.

    NOTE(review): ``xxx`` / ``xxxx`` are placeholders from the original
    snippet for the real model inputs/outputs and training data; replace them
    before use.
    """

    def __init__(self):
        # Worker threads started by learn(); None until learn() is first called.
        self.model1Thread = None
        self.model2Thread = None

        # Counter incremented by every learn() call; must exist before then.
        self.generation = 0

        # One isolated graph + session per model so their variables never mix.
        self.model1_graph = tf.Graph()
        self.model1_sess = tf.Session(graph=self.model1_graph)

        self.model2_graph = tf.Graph()
        self.model2_sess = tf.Session(graph=self.model2_graph)

        # Keep the built models: every other method uses self.model1/self.model2.
        self.model1 = self.build_model1()
        self.model2 = self.build_model2()

    def build_model1(self):
        """Build and compile model 1 inside its own session/graph and return it."""
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])  # placeholder architecture
            # NOTE(review): real code must pass optimizer/loss to compile().
            model.compile()
            # Private Keras API: builds the predict function now, in this
            # graph, presumably so worker threads don't race to create it later.
            model._make_predict_function()
            return model

    def build_model2(self):
        """Build and compile model 2 inside its own session/graph and return it."""
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])  # placeholder architecture
            # NOTE(review): real code must pass optimizer/loss to compile().
            model.compile()
            # Private Keras API: see build_model1.
            model._make_predict_function()
            return model

    def predict(self):
        """Run model 1's ``predict`` in its own session/graph and return the result."""
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            return self.model1.predict([xxxx])

    def learn(self, wait=False):
        """Start one training thread per model and bump ``self.generation``.

        Args:
            wait: if True, block until both training threads have finished.
                The default (False) returns immediately, leaving training
                running in the background.
        """
        # Daemon threads do not keep the process alive: if the main thread
        # exits, unfinished training is abandoned.
        self.model1Thread = threading.Thread(target=self.learn_model1, args=())
        self.model1Thread.daemon = True
        self.model1Thread.start()

        self.model2Thread = threading.Thread(target=self.learn_model2, args=())
        self.model2Thread.daemon = True
        self.model2Thread.start()

        if wait:
            self.model1Thread.join()
            self.model2Thread.join()

        self.generation += 1
        # NOTE(review): flush_log is not defined in this class; it must be
        # supplied elsewhere (e.g. a subclass) or this raises AttributeError.
        self.flush_log()

    def learn_model1(self):
        """Fit model 1; intended to run in the worker thread started by learn()."""
        # Re-enter the defaults here: they don't carry over from the caller's thread.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.fit([xxxx], [xxxx])

    def learn_model2(self, obs=None, reward=None):
        """Fit model 2; intended to run in the worker thread started by learn().

        Args:
            obs, reward: not used by the current body. They are optional so
                that learn() can start this thread without arguments; a
                required signature made the thread fail with TypeError.
        """
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.fit([xxx], [xxx])

    def save_weights(self):
        """Save the weights of both models, each inside its own session/graph."""
        # NOTE(review): Keras Model.save_weights requires a filepath argument;
        # the snippet omits it.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.save_weights()

        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.save_weights()

    def load_weights(self):
        """Load the weights of both models, each inside its own session/graph."""
        # NOTE(review): Keras Model.load_weights requires a filepath argument;
        # the snippet omits it.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.load_weights()

        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.load_weights()
    
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 点我我会动 设计师: 上身试试
应支付0元
点击重新获取
扫码支付

支付成功即可阅读