现有一棵由n个节点组成的无向树,节点按从0到n - 1编号。输入一个整数n和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi, wi] 表示树中存在一条位于节点 ui 和节点 vi 之间、权重为 wi 的边。
另给你一个长度为 m 的二维整数数组 queries ,其中 queries[i] = [ai, bi] 。对于每条查询,请你找出使从 ai 到 bi 路径上每条边的权重相等所需的 最小操作次数 。在一次操作中,你可以选择树上的任意一条边,并将其权重更改为任意值。
注意:
查询之间 相互独立 的,这意味着每条新的查询时,树都会回到 初始状态 。
从 ai 到 bi的路径是一个由 不同 节点组成的序列,从节点 ai 开始,到节点 bi 结束,且序列中相邻的两个节点在树中共享一条边。
返回一个长度为 m 的数组 answer ,其中 answer[i] 是第 i 条查询的答案。
n为节点数,edges为存在边的两个节点之间的距离,queries表示的两节点间有一条路径,其中每边长度可能不同,修改最少次数让每边相等。
最近公共祖先
以节点0为根节点,使用数组count[i]
记录节点i到根节点0的路径上边权重的数量,即count[i][j]
表示节点i到根节点0的路径上权重为j的边数量。对于查询queries[i]=[ai,bi]
,记节点lcai为节点ai与bi的最近公共祖先,那么从节点ai到节点bi的路径上,权重为j的边数量tj的计算如下:
tj=count[ai][j]+count[bi][j]−2×count[lcai][j]
为了让节点 ai到节点 bi路径上每条边的权重都相等,贪心地将路径上所有的边都更改为边数量最多的权重即可,即从节点 ai到节点 bi路径上每条边的权重都相等所需的最小操作次数 resi的计算如下:
resi=∑j=1Wtj−max1≤j≤Wtj
其中W=26表示权重的最大值。
最近公共祖先节点的求解可以采用Tarjan算法。
const int W = 26;
typedef struct Node {
int node;
int index;
struct Node *next;
} Node;
typedef struct {
int key;
int val;
UT_hash_handle hh;
} HashItem;
Node *creatNode(int node, int index) {
Node *obj = (Node *)malloc(sizeof(Node));
obj->node = node;
obj->index = index;
obj->next = NULL;
return obj;
}
HashItem *hashFindItem(HashItem **obj, int key) {
HashItem *pEntry = NULL;
HASH_FIND_INT(*obj, &key, pEntry);
return pEntry;
}
bool hashAddItem(HashItem **obj, int key, int val) {
if (hashFindItem(obj, key)) {
return false;
}
HashItem *pEntry = (HashItem *)malloc(sizeof(HashItem));
pEntry->key = key;
pEntry->val = val;
HASH_ADD_INT(*obj, key, pEntry);
return true;
}
bool hashSetItem(HashItem **obj, int key, int val) {
HashItem *pEntry = hashFindItem(obj, key);
if (!pEntry) {
hashAddItem(obj, key, val);
} else {
pEntry->val = val;
}
return true;
}
int hashGetItem(HashItem **obj, int key, int defaultVal) {
HashItem *pEntry = hashFindItem(obj, key);
if (!pEntry) {
return defaultVal;
}
return pEntry->val;
}
void hashFree(HashItem **obj) {
HashItem *curr = NULL, *tmp = NULL;
HASH_ITER(hh, *obj, curr, tmp) {
HASH_DEL(*obj, curr);
free(curr);
}
}
void freeList(Node *list) {
while (list) {
Node *cur = list;
list = list->next;
free(cur);
}
}
int find(int* uf, int i) {
if (uf[i] == i) {
return i;
}
uf[i] = find(uf, uf[i]);
return uf[i];
}
void tarjan(int node, int parent, HashItem **neighbors, int **count, int *uf, int *visited, int *lca, Node **queryArr) {
if (parent != -1) {
memcpy(count[node], count[parent], sizeof(int) * (W + 1));
count[node][hashGetItem(&neighbors[node], parent, 0)]++;
}
uf[node] = node;
for (HashItem *pEntry = neighbors[node]; pEntry; pEntry = pEntry->hh.next) {
int child = pEntry->key;
if (child == parent) {
continue;
}
tarjan(child, node, neighbors, count, uf, visited, lca, queryArr);
uf[child] = node;
}
for (Node *p = queryArr[node]; p; p = p->next) {
int node1 = p->node;
int index = p->index;
if (node != node1 && !visited[node1]) {
continue;
}
lca[index] = find(uf, node1);
}
visited[node] = 1;
};
int* minOperationsQueries(int n, int** edges, int edgesSize, int* edgesColSize, int** queries, int queriesSize, int* queriesColSize, int* returnSize) {
int m = queriesSize;
HashItem *neighbors[n];
Node *queryArr[n];
for (int i = 0; i < n; i++) {
neighbors[i] = NULL;
queryArr[i] = NULL;
}
for (int i = 0; i < edgesSize; i++) {
int u = edges[i][0];
int v = edges[i][1];
int w = edges[i][2];
hashAddItem(&neighbors[u], v, w);
hashAddItem(&neighbors[v], u, w);
}
for (int i = 0; i < m; i++) {
int a = queries[i][0];
int b = queries[i][1];
Node *node1 = creatNode(b, i);
node1->next = queryArr[a];
queryArr[a] = node1;
Node *node2 = creatNode(a, i);
node2->next = queryArr[b];
queryArr[b] = node2;
}
int *count[n];
int visited[n], uf[n], lca[m];
memset(visited, 0, sizeof(visited));
memset(uf, 0, sizeof(uf));
memset(lca, 0, sizeof(lca));
for (int i = 0; i < n; i++) {
count[i] = (int *)malloc(sizeof(int) * (W + 1));
memset(count[i], 0, sizeof(int) * (W + 1));
}
tarjan(0, -1, neighbors, count, uf, visited, lca, queryArr);
int *res = (int *)malloc(sizeof(int) * m);
for (int i = 0; i < m; i++) {
int totalCount = 0, maxCount = 0;
for (int j = 1; j <= W; j++) {
int t = count[queries[i][0]][j] + count[queries[i][1]][j] - 2 * count[lca[i]][j];
maxCount = fmax(maxCount, t);
totalCount += t;
}
res[i] = totalCount - maxCount;
}
*returnSize = m;
for (int i = 0; i < n; i++) {
free(count[i]);
freeList(queryArr[i]);
hashFree(&neighbors[i]);
}
return res;
}
class Solution:
def find(self, uf: List[int], i: int) -> int:
if uf[i] == i:
return i
uf[i] = self.find(uf, uf[i])
return uf[i]
def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
m, W = len(queries), 26
neighbors = [dict() for i in range(n)]
for edge in edges:
neighbors[edge[0]][edge[1]] = edge[2]
neighbors[edge[1]][edge[0]] = edge[2]
queryArr = [[] for i in range(n)]
for i in range(m):
queryArr[queries[i][0]].append([queries[i][1], i])
queryArr[queries[i][1]].append([queries[i][0], i])
count = [[0 for j in range(W + 1)] for i in range(n)]
visited, uf, lca = [0 for i in range(n)], [0 for i in range(n)], [0 for i in range(m)]
def tarjan(node: int, parent: int):
if parent != -1:
count[node] = count[parent].copy()
count[node][neighbors[node][parent]] += 1
uf[node] = node
for child in neighbors[node].keys():
if child == parent:
continue
tarjan(child, node)
uf[child] = node
for [node1, index] in queryArr[node]:
if node != node1 and not visited[node1]:
continue
lca[index] = self.find(uf, node1)
visited[node] = 1
tarjan(0, -1)
res = [0 for i in range(m)]
for i in range(m):
totalCount, maxCount = 0, 0
for j in range(1, W+1):
t = count[queries[i][0]][j] + count[queries[i][1]][j] - 2 * count[lca[i]][j]
maxCount = max(maxCount, t)
totalCount += t
res[i] = totalCount - maxCount
return res