def train_network(
    config,
    shuffle=1,
    trainingsetindex=0,
    max_snapshots_to_keep=5,
    displayiters=None,
    saveiters=None,
    maxiters=None,
    allow_growth=True,
    gputouse=None,
    autotune=False,
    keepdeconvweights=True,
    modelprefix="",
    superanimal_name="",
    superanimal_transfer_learning=False,
):
    """Train the network for one shuffle of a project's training dataset.

    Reads the project configuration, locates the ``pose_cfg.yaml`` of the
    requested shuffle / training fraction, picks the single- or multi-animal
    trainer and runs it. If the ``pose_cfg.yaml`` file does not exist, a hint
    is printed and the function returns without training.

    Parameters
    ----------
    config
        Path of the project configuration file, handed to
        ``auxiliaryfunctions.read_config``.
    shuffle
        Shuffle index used to resolve the model folder.
    trainingsetindex
        Index into ``cfg["TrainingFraction"]`` selecting the training fraction.
    max_snapshots_to_keep
        Forwarded to the trainer as ``max_to_keep``.
    displayiters, saveiters, maxiters
        Forwarded positionally to the trainer (presumably overrides for the
        display / save / maximum iteration settings -- confirm in the trainer).
    allow_growth
        When truthy, sets ``TF_FORCE_GPU_ALLOW_GROWTH=true`` in the
        environment. Also forwarded to the trainer.
    gputouse
        When not ``None``, written to ``CUDA_VISIBLE_DEVICES``.
    autotune
        When anything other than ``False``, ``TF_CUDNN_USE_AUTOTUNE`` is set
        to ``"0"`` (see the review note in the body).
    keepdeconvweights
        Forwarded to the trainer.
    modelprefix
        Forwarded to ``auxiliaryfunctions.get_model_folder``.
    superanimal_name
        When non-empty, pretrained weights for this model are located
        (downloaded if necessary) and the multi-animal trainer is started
        from them.
    superanimal_transfer_learning
        Only used together with ``superanimal_name``; forwarded to the
        trainer as ``remove_head``.

    Raises
    ------
    FileNotFoundError
        If ``superanimal_name`` is given but no ``snapshot-*.index`` file can
        be found in the corresponding weights folder.
    """
    if allow_growth:
        # Set before TensorFlow is imported below.
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

    import tensorflow as tf

    # Reload the logging module and flush/close its handlers.
    import importlib
    import logging

    importlib.reload(logging)
    logging.shutdown()

    from deeplabcut.utils import auxiliaryfunctions

    tf.compat.v1.reset_default_graph()
    start_path = os.getcwd()

    # Resolve the pose_cfg.yaml of the requested shuffle / training fraction.
    cfg = auxiliaryfunctions.read_config(config)
    modelfoldername = auxiliaryfunctions.get_model_folder(
        cfg["TrainingFraction"][trainingsetindex], shuffle, cfg, modelprefix=modelprefix
    )
    poseconfigfile = Path(
        os.path.join(
            cfg["project_path"], str(modelfoldername), "train", "pose_cfg.yaml"
        )
    )

    if not poseconfigfile.is_file():
        print("The training datafile ", poseconfigfile, " is not present.")
        print(
            "Probably, the training dataset for this specific shuffle index was not created."
        )
        print(
            "Try with a different shuffle/trainingsetfraction or use function 'create_training_dataset' to create a new trainingdataset with this shuffle index."
        )
        return

    # Set environment variables.
    # NOTE(review): cuDNN autotune is switched OFF when `autotune` is anything
    # other than False, which reads as inverted relative to the parameter
    # name. Behaviour is kept unchanged here; confirm the intent.
    # see: https://github.com/tensorflow/tensorflow/issues/13317
    if autotune is not False:
        os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"
    if gputouse is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gputouse)

    try:
        cfg_dlc = auxiliaryfunctions.read_plainconfig(poseconfigfile)

        extra_train_kwargs = {}
        if superanimal_name != "":
            # SuperAnimal runs always go through the multi-animal trainer,
            # initialised from the pretrained snapshot.
            extra_train_kwargs["init_weights"] = _find_superanimal_weights(
                superanimal_name
            )
            extra_train_kwargs["remove_head"] = bool(superanimal_transfer_learning)
            use_multianimal_trainer = True
        else:
            use_multianimal_trainer = "multi-animal" in cfg_dlc["dataset_type"]

        if use_multianimal_trainer:
            from deeplabcut.pose_estimation_tensorflow.core.train_multianimal import (
                train,
            )

            print("Selecting multi-animal trainer")
        else:
            from deeplabcut.pose_estimation_tensorflow.core.train import train

            print("Selecting single-animal trainer")

        # Pass on path and file name for pose_cfg.yaml!
        train(
            str(poseconfigfile),
            displayiters,
            saveiters,
            maxiters,
            max_to_keep=max_snapshots_to_keep,
            keepdeconvweights=keepdeconvweights,
            allow_growth=allow_growth,
            **extra_train_kwargs,
        )
    finally:
        # Restore the working directory whether or not training succeeded.
        os.chdir(str(start_path))

    # Only reached when training finished without raising.
    print(
        "The network is now trained and ready to evaluate. Use the function 'evaluate_network' to evaluate the network."
    )


def _find_superanimal_weights(superanimal_name):
    """Return the checkpoint prefix of the pretrained ``superanimal_name`` weights, downloading them if absent."""
    import glob

    from deeplabcut.utils import auxiliaryfunctions
    from dlclibrary.dlcmodelzoo.modelzoo_download import (
        MODELOPTIONS,
        download_huggingface_model,
    )

    weight_folder = str(
        Path(auxiliaryfunctions.get_deeplabcut_path())
        / "pose_estimation_tensorflow"
        / "models"
        / "pretrained"
        / (superanimal_name + "_weights")
    )

    if superanimal_name in MODELOPTIONS:
        if not os.path.exists(weight_folder):
            download_huggingface_model(superanimal_name, weight_folder)
        else:
            print(f"{weight_folder} exists, using the downloaded weights")
    else:
        print(
            f"{superanimal_name} not available. Available ones are: ",
            MODELOPTIONS,
        )

    snapshots = glob.glob(os.path.join(weight_folder, "snapshot-*.index"))
    if not snapshots:
        # Previously this surfaced as a bare IndexError on `snapshots[0]`.
        raise FileNotFoundError(
            f"No 'snapshot-*.index' file found in {weight_folder}; "
            f"cannot initialise weights for '{superanimal_name}'."
        )

    # Strip only the trailing ".index" extension; a plain str.replace would
    # also alter any directory name that happens to contain ".index".
    return os.path.splitext(os.path.abspath(snapshots[0]))[0]