train network

def train_network(
    config,
    shuffle=1,
    trainingsetindex=0,
    max_snapshots_to_keep=5,
    displayiters=None,
    saveiters=None,
    maxiters=None,
    allow_growth=True,
    gputouse=None,
    autotune=False,
    keepdeconvweights=True,
    modelprefix="",
    superanimal_name="",
    superanimal_transfer_learning=False,
):
    """Train the pose-estimation network described by a project's ``pose_cfg.yaml``.

    The project configuration is read from ``config``; the training folder is
    located from the training fraction at ``trainingsetindex``, ``shuffle`` and
    ``modelprefix``. If ``<model folder>/train/pose_cfg.yaml`` does not exist,
    a hint is printed and the function returns without training.

    Trainer selection:
      * ``superanimal_name`` non-empty: the matching pretrained weights are
        located (and downloaded if the weight folder is absent), then the
        multi-animal trainer is started from those weights.
      * ``"multi-animal"`` found in the pose config's ``dataset_type``: the
        multi-animal trainer is used.
      * otherwise: the single-animal trainer is used.

    Args:
        config: Path of the project configuration file.
        shuffle: Shuffle index used to locate the model folder.
        trainingsetindex: Index into ``cfg["TrainingFraction"]``.
        max_snapshots_to_keep: Forwarded to the trainer as ``max_to_keep``.
        displayiters: Forwarded to the trainer unchanged.
        saveiters: Forwarded to the trainer unchanged.
        maxiters: Forwarded to the trainer unchanged.
        allow_growth: If truthy, sets ``TF_FORCE_GPU_ALLOW_GROWTH=true`` before
            TensorFlow is imported; also forwarded to the trainer.
        gputouse: If not ``None``, written to ``CUDA_VISIBLE_DEVICES``.
        autotune: If anything other than ``False``, ``TF_CUDNN_USE_AUTOTUNE``
            is set to ``"0"``.
        keepdeconvweights: Forwarded to the trainer unchanged.
        modelprefix: Forwarded to ``get_model_folder``.
        superanimal_name: Name of a pretrained SuperAnimal model, or ``""``.
        superanimal_transfer_learning: When a SuperAnimal model is used, its
            truthiness is passed to the trainer as ``remove_head``.

    Returns:
        None.

    Raises:
        FileNotFoundError: A SuperAnimal model was requested but no
            ``snapshot-*.index`` file exists in its weight folder.
        Exception: Anything raised by the selected trainer propagates after
            the working directory has been restored.
    """
    # Function-scope imports, in the same style as the other imports of this
    # block, so the function does not rely on module-level names.
    import os
    from pathlib import Path

    # Must be set before TensorFlow is imported.
    if allow_growth:
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

    import tensorflow as tf

    # reload logger.
    import importlib
    import logging

    importlib.reload(logging)
    logging.shutdown()

    from deeplabcut.utils import auxiliaryfunctions

    tf.compat.v1.reset_default_graph()
    start_path = os.getcwd()

    # Read file path for pose_config file. >> pass it on
    cfg = auxiliaryfunctions.read_config(config)
    modelfoldername = auxiliaryfunctions.get_model_folder(
        cfg["TrainingFraction"][trainingsetindex], shuffle, cfg, modelprefix=modelprefix
    )
    poseconfigfile = Path(
        os.path.join(
            cfg["project_path"], str(modelfoldername), "train", "pose_cfg.yaml"
        )
    )
    if not poseconfigfile.is_file():
        print("The training datafile ", poseconfigfile, " is not present.")
        print(
            "Probably, the training dataset for this specific shuffle index was not created."
        )
        print(
            "Try with a different shuffle/trainingsetfraction or use function 'create_training_dataset' to create a new trainingdataset with this shuffle index."
        )
        # Nothing to train on: stop here instead of trying to read the
        # missing file and then reporting a successful training run.
        return

    # Set environment variables
    # see: https://github.com/tensorflow/tensorflow/issues/13317
    # NOTE(review): any value other than False *disables* cuDNN autotune here
    # (the variable is set to "0"); confirm this inversion is intended.
    if autotune is not False:
        os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"
    if gputouse is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gputouse)

    try:
        cfg_dlc = auxiliaryfunctions.read_plainconfig(poseconfigfile)

        # Keyword arguments shared by every trainer variant.
        trainer_kwargs = {
            "max_to_keep": max_snapshots_to_keep,
            "keepdeconvweights": keepdeconvweights,
            "allow_growth": allow_growth,
        }

        if superanimal_name != "":
            from dlclibrary.dlcmodelzoo.modelzoo_download import (
                download_huggingface_model,
                MODELOPTIONS,
            )
            import glob

            dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()
            weight_folder = str(
                Path(dlc_root_path)
                / "pose_estimation_tensorflow"
                / "models"
                / "pretrained"
                / (superanimal_name + "_weights")
            )

            if superanimal_name in MODELOPTIONS:
                if not os.path.exists(weight_folder):
                    download_huggingface_model(superanimal_name, weight_folder)
                else:
                    print(f"{weight_folder} exists, using the downloaded weights")
            else:
                print(
                    f"{superanimal_name} not available. Available ones are: ",
                    MODELOPTIONS,
                )

            snapshots = glob.glob(os.path.join(weight_folder, "snapshot-*.index"))
            if not snapshots:
                # Fail with an actionable message rather than a bare
                # IndexError on snapshots[0].
                raise FileNotFoundError(
                    f"No 'snapshot-*.index' file found in {weight_folder}; "
                    f"cannot initialize weights for '{superanimal_name}'."
                )
            init_weights = os.path.abspath(snapshots[0]).replace(".index", "")

            from deeplabcut.pose_estimation_tensorflow.core.train_multianimal import (
                train,
            )

            print("Selecting multi-animal trainer")
            trainer_kwargs["init_weights"] = init_weights
            # superanimal_name is known to be non-empty in this branch.
            trainer_kwargs["remove_head"] = bool(superanimal_transfer_learning)

        elif "multi-animal" in cfg_dlc["dataset_type"]:
            from deeplabcut.pose_estimation_tensorflow.core.train_multianimal import (
                train,
            )

            print("Selecting multi-animal trainer")
        else:
            from deeplabcut.pose_estimation_tensorflow.core.train import train

            print("Selecting single-animal trainer")

        # pass on path and file name for pose_cfg.yaml!
        train(
            str(poseconfigfile),
            displayiters,
            saveiters,
            maxiters,
            **trainer_kwargs,
        )
    finally:
        # Restore the caller's working directory whether or not training
        # succeeded; exceptions propagate unchanged.
        os.chdir(str(start_path))

    print(
        "The network is now trained and ready to evaluate. Use the function 'evaluate_network' to evaluate the network."
    )

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值