class UnionFindSet:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses union by rank plus path compression, so ``find`` and ``union``
    run in near-constant amortized time.
    """

    def __init__(self, n: int) -> None:
        """Create ``n`` singleton sets, one per element ``0 .. n-1``."""
        # parent[i] == i marks i as the root (representative) of its set.
        self.parent = list(range(n))
        # Upper bound on the height of the tree rooted at each index.
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative of the set containing ``x``.

        Every node visited on the way to the root is re-pointed directly at
        the root (path compression). This is done iteratively, so even a
        long parent chain cannot exceed the interpreter's recursion limit.
        """
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: compress the path from x up to the root.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(self, x: int, y: int) -> None:
        """Merge the sets containing ``x`` and ``y``; no-op if already joined."""
        xroot, yroot = self.find(x), self.find(y)
        if xroot == yroot:
            return
        # Attach the shallower tree under the deeper one to keep trees short.
        if self.rank[xroot] < self.rank[yroot]:
            xroot, yroot = yroot, xroot
        self.parent[yroot] = xroot
        # Height grows only when two trees of equal rank are merged.
        if self.rank[xroot] == self.rank[yroot]:
            self.rank[xroot] += 1
class Solution:
    def countPairs(self, n: int, edges: List[List[int]]) -> int:
        """Count unordered node pairs that lie in different components.

        Nodes are ``0 .. n-1``; each ``edges`` entry is a ``[u, v]`` pair
        whose endpoints are assumed to lie in that range.

        Each component of size ``c`` contributes ``c * (n - c)`` ordered
        pairs with nodes outside it. Summing over all components counts
        every unordered pair twice, hence the final ``// 2``.
        """
        from collections import Counter

        ufs = UnionFindSet(n)
        # Unioning each edge once is enough; no adjacency list is needed.
        for u, v in edges:
            ufs.union(u, v)
        component_sizes = Counter(ufs.find(node) for node in range(n))
        return sum(c * (n - c) for c in component_sizes.values()) // 2