题目
Toxel likes arrays. Before traveling to the Paldea region, Serval gave him an array a a a as a gift. This array has n n n pairwise distinct elements.
In order to get more arrays, Toxel performed m m m operations with the initial array. In the i i i-th operation, he modified the pi-th element of the ( i − 1 i−1 i−1)-th array to v i v_{i} vi, resulting in the i i i-th array (the initial array a is numbered as 0 0 0). During modifications, Toxel guaranteed that the elements of each array are still pairwise distinct after each operation.
Finally, Toxel got m + 1 m+1 m+1 arrays and denoted them as A 0 = a , A 1 , … , A m A_{0}=a,A_{1},…,A_{m} A0=a,A1,…,Am. For each pair ( i , j ) ( 0 ≤ i < j ≤ m ) (i,j) (0≤i<j≤m) (i,j)(0≤i<j≤m), Toxel defines its value as the number of distinct elements of the concatenation of A i A_{i} Ai and A j A_{j} Aj. Now Toxel wonders, what is the sum of the values of all pairs? Please help him to calculate the answer.
It is guaranteed that the sum of n n n and the sum of m m m over all test cases do not exceed 2 ⋅ 1 0 5 2⋅10^{5} 2⋅105.
思路
①初始化所有在原始数组中的数的数量为
m
+
1
m+1
m+1。因为一共有
m
+
1
m+1
m+1个数组,那么我一开始假定所有数组都与原始数组保持相同。
②对于第
i
(
0
≤
i
<
m
)
i(0≤i<m)
i(0≤i<m)个数组,也就是给出一个位置的修改,那么这个位置上原来的数对应的数量就要相应地减少
m
−
i
m-i
m−i,因为在当前包括后面一共
m
−
i
m-i
m−i个数组中这个位置上的数都要修改,并将其变成
v
v
v,然后
v
v
v的数量相应地增加
m
−
i
m-i
m−i。这样,我们就统计出了这
m
+
1
m+1
m+1个数组中所有不同的数分别出现的次数。
③现在统计这
m
+
1
m+1
m+1个数组中不同的数分别对答案产生的贡献之和。设某个数的数量是
c
n
t
cnt
cnt,那么含有这个数的
c
n
t
cnt
cnt个数组与其它
m
+
1
−
c
n
t
m+1-cnt
m+1−cnt个数组组合时这个数产生的贡献是
c
n
t
∗
(
m
+
1
−
c
n
t
)
cnt*(m+1-cnt)
cnt∗(m+1−cnt);此外,含有这个数的
c
n
t
cnt
cnt个数组两两组合时产生的
C
c
n
t
2
=
c
n
t
∗
(
c
n
t
−
1
)
/
2
C^{2}_{cnt}=cnt*(cnt-1)/2
Ccnt2=cnt∗(cnt−1)/2个组合分别产生了
1
1
1的贡献。因此,答案就是:
a
n
s
=
∑
1
n
+
m
(
c
n
t
[
i
]
∗
(
m
+
1
−
c
n
t
[
i
]
)
+
c
n
t
[
i
]
∗
(
c
n
t
[
i
]
−
1
)
/
2
)
ans = \sum^{n+m}_{1}(cnt[i]*(m+1-cnt[i])+cnt[i]*(cnt[i]-1)/2)
ans=∑1n+m(cnt[i]∗(m+1−cnt[i])+cnt[i]∗(cnt[i]−1)/2)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll maxn = 2e5+5;
ll n,m,a[maxn],cnt[maxn<<1];
void solve(){
cin>>n>>m;
memset(cnt,0,sizeof(cnt));
for(ll i=1;i<=n;i++)cin>>a[i], cnt[a[i]] = m+1;
for(ll i=0;i<m;i++){
ll p,v;
cin>>p>>v;
cnt[a[p]] -= (m-i);
a[p] = v;
cnt[v] += (m-i);
}
ll ans = 0;
for(ll i=1;i<=n+m;i++){
if(cnt[i]==0)continue;
ans += cnt[i]*(m+1-cnt[i]);
ans += cnt[i]*(cnt[i]-1)/2;
}
cout<<ans<<endl;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
ll tests;
cin>>tests;
while(tests--){
solve();
}
return 0;
}