# Best PAF graph selection

def _get_n_best_paf_graphs(
    data,
    metadata,
    full_graph,
    n_graphs=10,
    root=None,
    which="best",
    ignore_inds=None,
    metric="auc",
):
    """Rank PAF edges by train-set separability and build nested candidate graphs.

    Parameters
    ----------
    data, metadata
        Evaluation data/metadata as consumed by ``_calc_within_between_pafs``
        (presumably the pickles loaded by the caller — verify against caller).
    full_graph
        Sequence of bodypart-index pairs; the edge indices returned here index
        into this sequence.
    n_graphs : int
        Upper bound on the number of candidate graphs returned.
    root : list or None
        Edge indices forced into every candidate graph. If None, the
        maximum (or minimum, for ``which="worst"``) spanning tree of the
        score-weighted graph is used.
    which : {"best", "worst"}
        Whether remaining edges are added in descending or ascending score order.
    ignore_inds
        Edge indices excluded from consideration, or None.
    metric : str
        Separability metric forwarded to ``_calc_separability``.

    Returns
    -------
    (paf_inds, scores)
        ``paf_inds`` is a list of edge-index lists of increasing size (the
        first entry being the spanning-tree root); ``scores`` maps each
        retained edge index to its separability score.
    """
    if which not in ("best", "worst"):
        raise ValueError('`which` must be either "best" or "worst"')

    # Per-edge distributions of PAF costs within vs. between animals,
    # computed on the training set only.
    (within_train, _), (between_train, _) = _calc_within_between_pafs(
        data,
        metadata,
        train_set_only=True,
    )
    # Handle unlabeled bodyparts...
    # Keep only edges that actually have within-animal samples.
    existing_edges = set(k for k, v in within_train.items() if v)
    if ignore_inds is not None:
        existing_edges = existing_edges.difference(ignore_inds)
    existing_edges = list(existing_edges)

    if not any(between_train.values()):
        # Only 1 animal, let us return the full graph indices only
        return ([existing_edges], dict(zip(existing_edges, [0] * len(existing_edges))))

    # One separability score per retained edge (how well within- and
    # between-animal cost distributions are distinguished).
    scores, _ = zip(
        *[
            _calc_separability(between_train[n], within_train[n], metric=metric)
            for n in existing_edges
        ]
    )

    # Find minimal skeleton
    # Build a bodypart graph weighted by separability; the spanning tree of
    # this graph (max for "best", min for "worst") seeds every candidate.
    G = nx.Graph()
    for edge, score in zip(existing_edges, scores):
        if np.isfinite(score):
            G.add_edge(*full_graph[edge], weight=score)
    if which == "best":
        order = np.asarray(existing_edges)[np.argsort(scores)[::-1]]
        if root is None:
            root = []
            for edge in nx.maximum_spanning_edges(G, data=False):
                # Map the (u, v) bodypart pair back to its index in full_graph;
                # sorted() matches full_graph's canonical pair ordering.
                root.append(full_graph.index(sorted(edge)))
    else:
        order = np.asarray(existing_edges)[np.argsort(scores)]
        if root is None:
            root = []
            for edge in nx.minimum_spanning_edges(G, data=False):
                root.append(full_graph.index(sorted(edge)))

    # Grow graphs from the root by adding non-root edges in ranked order,
    # at up to n_graphs evenly spaced sizes.
    n_edges = len(existing_edges) - len(root)
    lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:]
    order = order[np.isin(order, root, invert=True)]
    paf_inds = [root]
    for length in lengths:
        paf_inds.append(root + list(order[:length]))
    return paf_inds, dict(zip(existing_edges, scores))
# NOTE: a truncated, mis-indented duplicate of the tail of
# `cross_validate_paf_graphs` appeared here (a copy/paste scraping artifact,
# syntactically invalid on its own). It has been removed; the complete
# function is defined below.
def cross_validate_paf_graphs(
    config,
    inference_config,
    full_data_file,
    metadata_file,
    output_name="",
    pcutoff=0.1,
    oks_sigma=0.1,
    margin=0,
    greedy=False,
    add_discarded=True,
    calibrate=False,
    overwrite_config=True,
    n_graphs=10,
    paf_inds=None,
    symmetric_kpts=None,
):
    """Benchmark candidate PAF graphs and persist the best one to the pose config.

    Ranks PAF edges on the training set, benchmarks up to ``n_graphs``
    candidate graphs (unless explicit ``paf_inds`` are supplied), selects the
    size maximizing (1 - miss rate) * purity, and writes the winning edge
    indices into the ``pose_cfg`` file paired with ``inference_config``.

    Parameters
    ----------
    config : path to the project config file.
    inference_config : path to the inference config; its sibling pose config
        is derived by name substitution and edited in place.
    full_data_file, metadata_file : pickled evaluation data/metadata.
    output_name : if non-empty, the benchmark results are pickled here.
    pcutoff : confidence cutoff injected into the temporary inference config.
    oks_sigma, margin, greedy, add_discarded, symmetric_kpts : forwarded to
        ``_benchmark_paf_graphs``.
    calibrate : when True, pass the project's CollectedData file for calibration.
    overwrite_config : when False, back up the pose config as ``*_old.yaml`` first.
    n_graphs : number of candidate graphs to rank.
    paf_inds : optional explicit list of candidate edge-index lists.

    Returns
    -------
    (results[:3], paf_scores, results[3][size_opt])
    """
    cfg = auxiliaryfunctions.read_config(config)
    # Work on a copy so the on-disk inference config is left untouched.
    inf_cfg_temp = auxiliaryfunctions.read_plainconfig(inference_config).copy()
    inf_cfg_temp["pcutoff"] = pcutoff

    with open(full_data_file, "rb") as handle:
        data = pickle.load(handle)
    with open(metadata_file, "rb") as handle:
        metadata = pickle.load(handle)

    params = _set_up_evaluation(data)
    to_ignore = auxfun_multianimal.filter_unwanted_paf_connections(
        cfg, params["paf_graph"]
    )
    candidate_inds, paf_scores = _get_n_best_paf_graphs(
        data,
        metadata,
        params["paf_graph"],
        ignore_inds=to_ignore,
        n_graphs=n_graphs,
    )
    if paf_inds is None:
        paf_inds = candidate_inds

    calibration_file = ""
    if calibrate:
        trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg)
        calibration_file = os.path.join(
            cfg["project_path"],
            str(trainingsetfolder),
            "CollectedData_" + cfg["scorer"] + ".h5",
        )

    split_inds = [
        metadata["data"]["trainIndices"],
        metadata["data"]["testIndices"],
    ]
    results = _benchmark_paf_graphs(
        cfg,
        inf_cfg_temp,
        data,
        paf_inds,
        greedy,
        add_discarded,
        oks_sigma=oks_sigma,
        margin=margin,
        symmetric_kpts=symmetric_kpts,
        calibration_file=calibration_file,
        split_inds=split_inds,
    )

    # Select optimal PAF graph: maximize hit rate (1 - miss) times purity.
    df = results[1]
    size_opt = np.argmax((1 - df.loc["miss", "mean"]) * df.loc["purity", "mean"])
    pose_config = inference_config.replace("inference_cfg", "pose_cfg")
    if not overwrite_config:
        # Preserve the previous pose config alongside the edited one.
        shutil.copy(pose_config, pose_config.replace(".yaml", "_old.yaml"))
    auxiliaryfunctions.edit_config(
        pose_config,
        {"paf_best": [int(ind) for ind in list(paf_inds[size_opt])]},
    )
    if output_name:
        with open(output_name, "wb") as handle:
            pickle.dump([results], handle)
    return results[:3], paf_scores, results[3][size_opt]

# NOTE: web-page boilerplate (CSDN like/favorite/payment widgets, in Chinese)
# was scraped into the file here. It was not part of the source code and has
# been removed.