def _get_n_best_paf_graphs(
    data,
    metadata,
    full_graph,
    n_graphs=10,
    root=None,
    which="best",
    ignore_inds=None,
    metric="auc",
):
    """Build up to `n_graphs` nested PAF graphs of increasing size.

    Edges are ranked by the separability of their within- vs between-animal
    PAF score distributions (computed on the training set only). A minimal
    "root" skeleton is taken from a spanning tree of the separability-weighted
    graph, then progressively grown by adding the next best (or worst) edges.

    Parameters
    ----------
    data, metadata : training data structures consumed by
        `_calc_within_between_pafs` — schema defined elsewhere in the project.
    full_graph : list of edges; each entry is indexable/sortable (presumably a
        2-element list of bodypart indices — confirm against caller).
    n_graphs : maximum number of graphs to return.
    root : optional list of edge indices to always include; if None, a
        spanning tree of the separability graph is used.
    which : "best" keeps the most separable edges first; "worst" the least.
    ignore_inds : edge indices to exclude from consideration.
    metric : separability metric forwarded to `_calc_separability`.

    Returns
    -------
    (paf_inds, scores) : a list of edge-index lists (nested, growing), and a
        dict mapping each retained edge index to its separability score.

    Raises
    ------
    ValueError : if `which` is neither "best" nor "worst".
    """
    if which not in ("best", "worst"):
        raise ValueError('`which` must be either "best" or "worst"')
    (within_train, _), (between_train, _) = _calc_within_between_pafs(
        data,
        metadata,
        train_set_only=True,
    )
    # Handle unlabeled bodyparts: keep only edges with within-animal samples.
    existing_edges = set(k for k, v in within_train.items() if v)
    if ignore_inds is not None:
        existing_edges = existing_edges.difference(ignore_inds)
    existing_edges = list(existing_edges)
    if not any(between_train.values()):
        # Only 1 animal, let us return the full graph indices only
        return ([existing_edges], dict(zip(existing_edges, [0] * len(existing_edges))))
    scores, _ = zip(
        *[
            _calc_separability(between_train[n], within_train[n], metric=metric)
            for n in existing_edges
        ]
    )
    # Find minimal skeleton: spanning tree over the separability-weighted graph.
    G = nx.Graph()
    for edge, score in zip(existing_edges, scores):
        if np.isfinite(score):
            G.add_edge(*full_graph[edge], weight=score)
    if which == "best":
        order = np.asarray(existing_edges)[np.argsort(scores)[::-1]]
        if root is None:
            root = []
            for edge in nx.maximum_spanning_edges(G, data=False):
                root.append(full_graph.index(sorted(edge)))
    else:
        order = np.asarray(existing_edges)[np.argsort(scores)]
        if root is None:
            root = []
            for edge in nx.minimum_spanning_edges(G, data=False):
                root.append(full_graph.index(sorted(edge)))
    # Grow the root skeleton in (roughly) even increments up to all edges.
    n_edges = len(existing_edges) - len(root)
    lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:]
    order = order[np.isin(order, root, invert=True)]
    paf_inds = [root]
    for length in lengths:
        paf_inds.append(root + list(order[:length]))
    return paf_inds, dict(zip(existing_edges, scores))
    # NOTE(review): the original file contained an unreachable duplicate of the
    # tail of `cross_validate_paf_graphs` after this return (referencing names
    # undefined in this scope, e.g. `cfg`, `inf_cfg_temp`); it has been removed.
def cross_validate_paf_graphs(
    config,
    inference_config,
    full_data_file,
    metadata_file,
    output_name="",
    pcutoff=0.1,
    oks_sigma=0.1,
    margin=0,
    greedy=False,
    add_discarded=True,
    calibrate=False,
    overwrite_config=True,
    n_graphs=10,
    paf_inds=None,
    symmetric_kpts=None,
):
    """Benchmark candidate PAF graphs and persist the best one.

    Loads the project config and pickled assembly data, ranks candidate PAF
    graphs via `_get_n_best_paf_graphs`, benchmarks them with
    `_benchmark_paf_graphs`, writes the optimal edge set into the pose config
    ("paf_best"), and optionally pickles the raw benchmark results.

    Returns
    -------
    (results[:3], paf_scores, results[3][size_opt]) — benchmark summaries,
    the per-edge separability scores, and the entry for the optimal graph.
    """
    cfg = auxiliaryfunctions.read_config(config)
    inference_settings = auxiliaryfunctions.read_plainconfig(inference_config)
    inference_settings_tmp = inference_settings.copy()
    inference_settings_tmp["pcutoff"] = pcutoff

    with open(full_data_file, "rb") as handle:
        data = pickle.load(handle)
    with open(metadata_file, "rb") as handle:
        metadata = pickle.load(handle)

    params = _set_up_evaluation(data)
    unwanted_edges = auxfun_multianimal.filter_unwanted_paf_connections(
        cfg, params["paf_graph"]
    )
    candidate_inds, paf_scores = _get_n_best_paf_graphs(
        data,
        metadata,
        params["paf_graph"],
        ignore_inds=unwanted_edges,
        n_graphs=n_graphs,
    )
    if paf_inds is None:
        paf_inds = candidate_inds

    # An optional ground-truth calibration file refines assembly scoring.
    calibration_file = ""
    if calibrate:
        trainset_folder = auxiliaryfunctions.get_training_set_folder(cfg)
        calibration_file = os.path.join(
            cfg["project_path"],
            str(trainset_folder),
            "CollectedData_" + cfg["scorer"] + ".h5",
        )

    results = _benchmark_paf_graphs(
        cfg,
        inference_settings_tmp,
        data,
        paf_inds,
        greedy,
        add_discarded,
        oks_sigma=oks_sigma,
        margin=margin,
        symmetric_kpts=symmetric_kpts,
        calibration_file=calibration_file,
        split_inds=[
            metadata["data"]["trainIndices"],
            metadata["data"]["testIndices"],
        ],
    )

    # Select optimal PAF graph: maximize hit rate times purity.
    df = results[1]
    hit_rate = 1 - df.loc["miss", "mean"]
    purity = df.loc["purity", "mean"]
    size_opt = np.argmax(hit_rate * purity)

    # Persist the winning edge set into the pose config (backing up first
    # unless the caller asked to overwrite in place).
    pose_config = inference_config.replace("inference_cfg", "pose_cfg")
    if not overwrite_config:
        shutil.copy(pose_config, pose_config.replace(".yaml", "_old.yaml"))
    best_edges = [int(ind) for ind in list(paf_inds[size_opt])]
    auxiliaryfunctions.edit_config(pose_config, {"paf_best": best_edges})

    if output_name:
        with open(output_name, "wb") as handle:
            pickle.dump([results], handle)
    return results[:3], paf_scores, results[3][size_opt]