EfficientNetV2 Finetune on CIFAR10出现head_1/dense/bias is not found in ckpt

  1. 起因:我是个小白,想在自己的图片数据集上使用efficientnetv2,就先运行了GitHub上的示例代码(需要指定其中的路径,这个等下再细说):
    python main.py --mode=train  --model_name=efficientnetv2-s  --dataset_cfg=cifar10Ft --model_dir=$DIR --hparam_str="train.ft_init_ckpt=$PRETRAIN_CKPT_PATH"

    但是会出现在ckpt文件里找不到head_1/dense/bias或者head_1/dense/kernel的问题,在这之前还出现过类似的找不到。

  2. 失败尝试:

    1. 修改命令里的路径:在GitHub上下载了好几个ckpt压缩包,拿里面的ckpt和model文件来运行,但无论哪个压缩包都不对

    2. 检查代码:输出下载的ckpt的变量名,想看看究竟有没有,结果是,有head的,没有head_1的。偶然看到一篇博客说如果多次载入一张graph,就会出现由于一个变量名已经存在了,于是会创建一个新的变量,类似于已经有weight,就会在第二次读取graph的时候命名为weight_1,如果要避免,在每次读完graph之后就添加一句ops.reset_default_graph()。于是我拼命找,找"with"、找"session",每个函数都看过去,到底哪里多次载入了?我找不准,添加的这句要么是没有报错,但是找不到变量的错误仍然在,要么就是说不可以中途打断nested graph,建议我新建一张graph。我再输出代码里面与ckpt里变量有关的trainable variable,发现head_1是本应该有的,不是多次载入graph的结果,因为列表里dense/bias和dense/kernel就只有

      <tf.Variable 'efficientnetv2-s/head_1/dense/kernel:0' shape=(1280, 1000) dtype=float32>, <tf.Variable 'efficientnetv2-s/head_1/dense/bias:0' shape=(1000,) dtype=float32>

      这两个变量,如果多读的话,是应该至少有四个的,也就是efficientnetv2-s/head/dense/kernel、efficientnetv2-s/head/dense/bias、efficientnetv2-s/head_1/dense/kernel、efficientnetv2-s/head_1/dense/bias这样。但是没有,所以不是多次载入的问题。

  3. 成功的转机:服务器连不上了,重启了一次,重新连上之后有些包找不到了,为了兼容就给好些包都换了版本,之后再运行发现代码跑通了!!出现了类似下面这样的刷屏信息:

    INFO:tensorflow:global_step/sec: 11.9026
    I0327 20:51:49.209239 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.9026
    INFO:tensorflow:examples/sec: 6094.11
    I0327 20:51:49.209511 139832147883200 tpu_estimator.py:2392] examples/sec: 6094.11
    INFO:tensorflow:global_step/sec: 11.8865
    I0327 20:51:49.293369 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.8865
    INFO:tensorflow:examples/sec: 6085.9
    I0327 20:51:49.293624 139832147883200 tpu_estimator.py:2392] examples/sec: 6085.9
    INFO:tensorflow:global_step/sec: 11.7802
    I0327 20:51:49.378258 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.7802
    INFO:tensorflow:examples/sec: 6031.48
    I0327 20:51:49.378525 139832147883200 tpu_estimator.py:2392] examples/sec: 6031.48
    INFO:tensorflow:global_step/sec: 11.8018
    I0327 20:51:49.462988 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.8018
    INFO:tensorflow:examples/sec: 6042.54
    I0327 20:51:49.463258 139832147883200 tpu_estimator.py:2392] examples/sec: 6042.54
    INFO:tensorflow:global_step/sec: 11.8562
    I0327 20:51:49.547334 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.8562
    INFO:tensorflow:examples/sec: 6070.35
    I0327 20:51:49.547589 139832147883200 tpu_estimator.py:2392] examples/sec: 6070.35
    INFO:tensorflow:global_step/sec: 11.8676
    I0327 20:51:49.631590 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.8676
    INFO:tensorflow:examples/sec: 6076.22
    I0327 20:51:49.631861 139832147883200 tpu_estimator.py:2392] examples/sec: 6076.22
    INFO:tensorflow:global_step/sec: 11.8057
    I0327 20:51:49.716304 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.8057
    
    
    
    ……
    
    
    
    I0327 20:52:10.817095 139832147883200 tpu_estimator.py:2391] global_step/sec: 11.751
    INFO:tensorflow:examples/sec: 6016.51
    I0327 20:52:10.817527 139832147883200 tpu_estimator.py:2392] examples/sec: 6016.51
    INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000...
    I0327 20:52:10.901108 139832147883200 basic_session_run_hooks.py:627] Calling checkpoint listeners before saving checkpoint 10000...
    INFO:tensorflow:Saving checkpoints for 10000 into /mnt/data2/zjy/efficientnetv2/model.ckpt.
    I0327 20:52:10.901316 139832147883200 basic_session_run_hooks.py:632] Saving checkpoints for 10000 into /mnt/data2/zjy/efficientnetv2/model.ckpt.
    INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000...
    I0327 20:52:12.875514 139832147883200 basic_session_run_hooks.py:639] Calling checkpoint listeners after saving checkpoint 10000...
    INFO:tensorflow:global_step/sec: 0.484929
    I0327 20:52:12.879274 139832147883200 tpu_estimator.py:2391] global_step/sec: 0.484929
    INFO:tensorflow:examples/sec: 248.284
    I0327 20:52:12.879500 139832147883200 tpu_estimator.py:2392] examples/sec: 248.284
    INFO:tensorflow:Loss for final step: 3.2697053.
    I0327 20:52:13.525003 139832147883200 estimator.py:361] Loss for final step: 3.2697053.
    INFO:tensorflow:training_loop marked as finished
    I0327 20:52:13.526473 139832147883200 error_handling.py:115] training_loop marked as finished

    所以总结一下自己错误的地方:

    1. 命令里的路径错误,下载好GitHub的代码之后,给的命令像下面这样就可以了,我理解的是这两个路径是用来给代码存放生成的数据的,不需要下载GitHub的ckpt来供给:

      python main.py --mode=train --model_name=efficientnetv2-s --dataset_cfg=cifar10Ft --model_dir=/mnt/data2/efficientnetv2 --hparam_str="train.ft_init_ckpt=/mnt/data2/efficientnetv2/model"

      其中的/mnt/data2是我存放efficientv2代码文件夹的路径

    2. 可能有些包的版本不对,贴一下能成功运行的,我觉得主要是tensorflow和numpy的版本号:

      ~$ pip list
      Package                      Version
      ---------------------------- -------------------
      absl-py                      1.4.0
      array-record                 0.5.0
      astunparse                   1.6.3
      cachetools                   5.3.3
      certifi                      2024.2.2
      charset-normalizer           3.3.2
      click                        8.1.7
      contourpy                    1.2.0
      cycler                       0.11.0
      dill                         0.3.8
      dm-tree                      0.1.8
      etils                        1.5.2
      filelock                     3.13.1
      flatbuffers                  24.3.25
      fonttools                    4.25.0
      fsspec                       2024.3.1
      gast                         0.4.0
      google-auth                  2.29.0
      google-auth-oauthlib         1.0.0
      google-pasta                 0.2.0
      googleapis-common-protos     1.63.0
      grpcio                       1.62.1
      h5py                         3.1.0
      idna                         3.6
      importlib_metadata           7.1.0
      importlib-resources          6.1.1
      Jinja2                       3.1.3
      joblib                       1.3.2
      keras                        2.13.1
      keras-nightly                2.5.0.dev2021032900
      Keras-Preprocessing          1.1.2
      kiwisolver                   1.4.4
      libclang                     18.1.1
      Markdown                     3.6
      markdown-it-py               3.0.0
      MarkupSafe                   2.1.5
      matplotlib                   3.8.0
      mdurl                        0.1.2
      mkl-fft                      1.3.8
      mkl-random                   1.2.4
      mkl-service                  2.4.0
      ml-dtypes                    0.3.2
      mpmath                       1.3.0
      munkres                      1.1.4
      namex                        0.0.7
      networkx                     3.2.1
      numpy                        1.23.5
      nvidia-cublas-cu12           12.1.3.1
      nvidia-cuda-cupti-cu12       12.1.105
      nvidia-cuda-nvrtc-cu12       12.1.105
      nvidia-cuda-runtime-cu12     12.1.105
      nvidia-cudnn-cu12            8.9.2.26
      nvidia-cufft-cu12            11.0.2.54
      nvidia-curand-cu12           10.3.2.106
      nvidia-cusolver-cu12         11.4.5.107
      nvidia-cusparse-cu12         12.1.0.106
      nvidia-nccl-cu12             2.19.3
      nvidia-nvjitlink-cu12        12.4.99
      nvidia-nvtx-cu12             12.1.105
      oauthlib                     3.2.2
      opt-einsum                   3.3.0
      optree                       0.10.0
      packaging                    23.2
      pandas                       2.1.0
      pillow                       10.2.0
      pip                          24.0
      ply                          3.11
      promise                      2.3
      protobuf                     3.20.3
      psutil                       5.9.8
      pyasn1                       0.5.1
      pyasn1-modules               0.3.0
      Pygments                     2.17.2
      pyparsing                    3.0.9
      PyQt5                        5.15.10
      PyQt5-sip                    12.13.0
      python-dateutil              2.8.2
      pytz                         2024.1
      PyYAML                       6.0.1
      requests                     2.31.0
      requests-oauthlib            2.0.0
      rich                         13.7.1
      rsa                          4.9
      scikit-learn                 1.4.1.post1
      scipy                        1.12.0
      setuptools                   69.2.0
      sip                          6.7.12
      six                          1.15.0
      sympy                        1.12
      tensorboard                  2.13.0
      tensorboard-data-server      0.7.2
      tensorboard-plugin-wit       1.8.1
      tensorflow                   2.13.0
      tensorflow-addons            0.23.0
      tensorflow-datasets          4.5.0
      tensorflow-estimator         2.13.0
      tensorflow-io-gcs-filesystem 0.36.0
      tensorflow-metadata          1.14.0
      termcolor                    1.1.0
      tf-estimator-nightly         2.8.0.dev2021122109
      threadpoolctl                3.4.0
      toml                         0.10.2
      tomli                        2.0.1
      torch                        2.2.1
      tornado                      6.3.3
      tqdm                         4.66.2
      triton                       2.2.0
      typeguard                    2.13.3
      typing-extensions            3.7.4.3
      tzdata                       2024.1
      urllib3                      2.2.1
      Werkzeug                     3.0.1
      wheel                        0.43.0
      wrapt                        1.12.1
      zipp                         3.18.1

    3. 就酱,虽然还要看看怎么在自己的图片数据集上跑,但是已经很开心了!!(卡了快俩月了

——————————————————————————————————————

1、出现问题:

上次写的时候是3月27日,现在过了一周多,有写好自己的数据集了,想让efficientnetv2跑跑我自己写好的tfds格式的数据,结果又出现最最开始时不提供ckpt等文件的 “ckpt文件找不到” 的错误,即使换回用cifar10Ft也是这样。然后就像上次说的,将github上给出的ckpt放到efficientnetv2的文件夹里面,再运行 “python main.py --mode=train --model_name=efficientnetv2-s --dataset_cfg=cifar10Ft --model_dir=/mnt/data2/efficientnetv2 --hparam_str="train.ft_init_ckpt=/mnt/data2/efficientnetv2/model” 这个命令,但是又出现了找不到head_1的问题。感觉又回到了起点,好难过TAT

2、尝试:

想起上次卡住的时候有试过不给hparam_str 直接运行“python main.py --mode=train --model_name=efficientnetv2-s --dataset_cfg=cifar10Ft --model_dir=/mnt/data2/efficientnetv2” 这个命令 来生成自己的ckpt,这次也先这样,看看会怎样。然后,训练成功,training_loop marked as finished。

之后想着已经生成ckpt了,加上hparam_str看看会怎样,结果出现 “NaN loss during Training” 这样的问题,很是奇怪啊,之前运行成功都没试过的。删掉ckpt再来、重启vs code再来、重新填写launch.json,都不行,甚至中间出现过 “GCS is not accessible”这样的问题(这个看到stack overflow有人说可以换个国家的ip,我就到clash里给梯子换了个节点,解决了)。没想到,这clash节点更换 顺带解决了ckpt找不到的问题!

同样不用给github上的ckpt文件,直接执行“ python main.py --mode=train --model_name=efficientnetv2-s --dataset_cfg=cifar10Ft --model_dir=/mnt/data2/efficientnetv2 --hparam_str="train.ft_init_ckpt=/mnt/data2/efficientnetv2/model" ”可以运行成功。

3、小结:

模型finetune失败,除了上次说的包的版本问题,也有可能是网络问题,建议:检查wifi或网线有没有问题、确保梯子可用(刷新一下、换换节点【我是梯子可用,可以上google,但是也会不行,所以记得换换节点,并且重启一下vs code】)

  • 15
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值