tensorflow1.x版本代码迁移到2.0

 由于3090显卡只支持tf2.0以后的版本,而且随着显卡的更新换代,tf1.x版本也不支持更高级的显卡,所以有必要将1.x版本的代码转成2.0后的版本。

Tf2.0版本和tf1.0版本的主要区别

  主要区别在于tf1.x是静态图,需要先把模型结构先定好,再进行训练

Tf2.0版本则是动态图,训练前不用先构建完整的结构,而是按流程一步步构建,因此在训练的时候tf1.x相比于tf2.0占cpu内存大,训练的速度更快

代码转换主要分几个方面:输入、模型网络、训练、模型保存

1.输入

  在1.x的代码中,对于输入需要首先加placeholder,作为整个网络的入口。而tf2.0取消了这个部分,因此修改的方法是去掉这部分代码,直接在训练的时候输入数据,例如:  

Tf1.x:

1

2

3

self.inputs = tf.placeholder(tf.int32, [NoneNone], name="inputs")  # 数据输入

self.labels = tf.placeholder(tf.float32, [NoneNone], name="labels")  # 标签

  

修改后直接在训练的时候赋值就行:

1

2

3

4

5

self.inputs = batch["x"]

self.labels = batch["y"]

self.keep_prob = dropout_prob

  

2.模型网络

这部分比较好改,因为很多api可以在tensorflow官方文档上找到相应的替换函数,几个常用的如下:

tf.get_variable()变成tf.variable()

Initializer的改变

1

2

3

4

5

# embedding_w = tf.compat.v1.get_variable("embedding_w", shape=[self.vocab_size, self.config["embedding_size"]],

#initializer=tf.compat.v1.contrib.layers.xavier_initializer())

embedding_w = tf.Variable(tf.keras.initializers.glorot_normal()(shape=[self.vocab_size, self.config["embedding_size"]],dtype=tf.float32), name='embedding')

   

3.训练

训练过程包括梯度的操作、优化算法的选择,主要的操作如下:

模型训练要继承tf.Module这个api,因为训练的时候要选择状态容器以便存储模型的参数,如果用keras或estimator模块写模型也可以继承其他的api,具体的继承规则可以参考这个树形结构:

https://zhuanlan.zhihu.com/p/73575776

Trackable

  |

  |-- tf.Variable

  |

  |-- MutableHashTable

  |

  |-- AutoTrackable

        |

        |-- ListWrapper/DictWrapper

        |

        |-- tf.train.Checkpoint

        |

        |-- tf.Module

                |

                |-- tf.keras.layers.Layer

                        |

                        |-- tf.keras.Model

                                |

                                |-- tf.keras.Sequential

几种状态容器的选择准则一般为:

仅在学习和深入研究状态容器(或基于对象的储存)时使用Trackable和AutoTrackable

tf.Module: 适合自定义训练循环时使用

tf.keras.layers.Layer:适合实现一些中间层,比如Attention之类的,可以配合tf.keras.Sequential使用,极少看见大的模型继承自这个类型。

tf.keras.Model:适合一些固定套路的模型(使用compile + fit)。虽然也可以自定义训练循环,但是有一种杀鸡用牛刀的感觉。

tf.keras.Sequential:适合一条路走到黑的(子)模型。

选择完状态容器后则要进行对应的训练循环,也就是梯度下降的操作:

Tf1.x首先定义好train_op,然后session.run

Tf2.0则直接在epoch循环内使用

1

2

3

4

5

with tf.GradientTape() as t:

    grads = t.gradient(self.model.loss, self.model.trainable_variables)

optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

也就是将sess.run里面的操作换成一步步执行的函数流程

4.模型保存

  Tf1.x和2.0的模型保存变化不大,都可以保存成checkepoint和pb这两种格式,根据文档将api换一下就可以了,但是需要注意的是保存的模型加载的时候版本需要和之前一致,否则在模型加载的时候可能会报错。Summary的保存也是一样,需要把api替换掉。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 很抱歉,TensorFlow 2.0 中已经删除了 `tensorflow.contrib` 模块,因此不能直接导入 `tensorflow.contrib.learn`。不过,您可以使用 `TensorFlow 2.0` 中内置的 `tf.keras` 模块,或者使用 `TensorFlow Hub` 中的预训练模型。 ### 回答2: 要导入tensorflow.contrib.learn,您需要使用tensorflow 2.0的兼容性模块tf.compat.v1。在TensorFlow 2.0中,tf.contrib模块已被移除。然而,通过tf.compat.v1模块,您仍然可以使用一些tensorflow.contrib模块中的功能。 您可以按照以下步骤来导入tensorflow.contrib.learn: 1. 导入所需的模块: ```python import tensorflow.compat.v1 as tf from tensorflow.compat.v1 import contrib ``` 2. 启用兼容性模式: ```python tf.disable_v2_behavior() ``` 3. 现在您可以使用tf.contrib.learn及其功能: ```python contrib.learn.Estimator(...) ``` 注意:虽然这种方法使您能够导入tensorflow.contrib.learn,但由于tf.compat.v1模块是为了向后兼容而设计的,因此它可能在将来的版本中被删除。因此,最好尽量使用tensorflow 2.0的原生API。如果您使用tensorflow.contrib.learn的功能非常重要,您可以考虑使用旧版本tensorflow(如tensorflow 1.15)来支持它。 ### 回答3: 在TensorFlow 2.0中,已经不再支持`tensorflow.contrib.learn`这个模块。`tensorflow.contrib`是一个容纳实验性、不太稳定或较少使用的功能和功能组件的命名空间,而且在TensorFlow 1.X版本中是存在的。在TensorFlow 2.0中,TensorFlow团队已经将这些组件整合到了其他模块中,或者将它们作为独立的项目进行维护。因此,如果你想在TensorFlow 2.0中使用`tensorflow.contrib.learn`,你将无法直接导入它。 如果你仍然想使用类似于`tensorflow.contrib.learn`的某些功能,可以考虑以下方法: 1. 使用TensorFlow 2.0官方文档中提供的迁移指南,查找替代`tensorflow.contrib.learn`的功能或模块。官方文档通常会提供有关如何将旧版本代码迁移TensorFlow 2.0的详细说明。 2. 如果你只是需要用到一些机器学习算法,你可以考虑使用`scikit-learn`这个Python库。它是一个专门用于机器学习的库,提供了丰富的算法和工具,可以与TensorFlow 2.0进行结合使用。 总之,在TensorFlow 2.0中,将不再直接支持导入`tensorflow.contrib.learn`。如果你有特定的需求,需要找到替代的方法来实现你的目标。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值