题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6605
题目大意:有一颗树树上每一个点都带有一个十进制位,定义f(u,v)为 u -> v路径上所有点的数位组成的十进制数,询问有多少对点满足f(u,v) % k == 0。
题解:树上路径问题,考虑点分治:设当前分治中心为 r, 路径(u,v)可以分为(u,r) ,(r,v)两段,第二段去掉r组成,做法是枚举第二段,求符合的第一段。
设第二段(r,v) 去掉 r 组成的十进制数为 y,点 v 在分治中心 r 这棵树上的深度为 h v h_v hv,(u,r)组成的十进制数 为 x。若x,y能组成答案,必须满足式子: x ∗ 1 0 h = ( k − y ) m o d    k x*10^h = (k - y) \mod k x∗10h=(k−y)modk 。通过扩展欧几里得算法可以求解线性同余方程: x ∗ 1 0 h + n ∗ k = ( k − y ) x*10^h +n*k = (k-y) x∗10h+n∗k=(k−y),并得到一个 特解 x 0 x_0 x0,它的通解为 x 0 + t ∗ k g c d ( 1 0 h , k ) x_0 + t * \frac{k}{gcd(10^h,k)} x0+t∗gcd(10h,k)k,如果开桶暴力统计所有解的答案,显然是会T的。
注意通过扩展欧几里得解得的解 x = x 0 m o d    k g c d ( 1 0 h , k ) x = x_0 \mod \frac{k}{gcd(10^h,k)} x=x0modgcd(10h,k)k,如果每次将(u,r)路径组成的数 x 存入 x m o d    k g c d ( 1 0 h , k ) x\mod \frac{k}{gcd(10^h,k)} xmodgcd(10h,k)k的桶中,那么每次我们可以通过求解最小正整数解 x 0 x_0 x0以得到所有解的答案总和,避免了去枚举所有可行解 x x x。而 k ′ = k g c d ( 1 0 h , k ) k' = \frac{k}{gcd(10^h,k)} k′=gcd(10h,k)k至多只有 l o g k logk logk种不同的 k ′ k' k′,原因是 k k k 的因子 2 2 2 至多只有 l o g k logk logk个,因此可以预处理出所有的 k ′ k' k′,每次维护更新 l o g k logk logk个桶的信息,总复杂度为 n ∗ l o g n ∗ l o g k n*logn*logk n∗logn∗logk
(注意答案是双向的,要两次扫描子树来求得所有解)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int t,n,k;
const int maxn = 5e4 + 10;
const int maxm = 3e5 + 10;
const int mx = 1e5 + 10;
int sz[maxn],f[maxn],root,tot;
int pw[maxn];
int d1[maxn],d2[maxn],d3[maxn];
map<int,int> mp,np;
vector<int> v,tmp;
vector<int> to[maxn];
bool done[maxn];
char str[maxn];
int a[maxn];
int p[20][mx],num;
ll res = 0;
void init() {
mp.clear();np.clear();
for(int i = 1; i <= n; i++) to[i].clear();
fill(done,done + n + 1,0);
fill(d1,d1 + n + 1,0);
fill(d2,d2 + n + 1,0);
fill(d3,d3 + n + 1,0);
fill(pw,pw + n + 1,0);
for(int i = 1; i <= 19; i++) fill(p[i],p[i] + k + 1,0);
res = num = tot = 0;
}
void add(int u,int v) {
to[u].push_back(v);
to[v].push_back(u);
}
int gcd(int a,int b) {
return !b ? a : gcd(b,a % b);
}
int exgcd(int a,int b,int &x,int &y) {
if(!b) {
x = 1;y = 0;
return a;
} else {
int g = exgcd(b,a % b,y,x);
y -= x * (a / b);
return g;
}
}
int mul(int a,int b,int p) {
return (int)(1ll * a * b % p);
}
void getroot(int u,int fa) {
sz[u] = 1;f[u] = 0;
for(int i = 0; i < to[u].size(); i++) {
if(to[u][i] == fa || done[to[u][i]]) continue;
getroot(to[u][i],u);
sz[u] += sz[to[u][i]];
f[u] = max(f[u],sz[to[u][i]]);
}
f[u] = max(f[u],tot - sz[u]);
if(!root || f[root] > f[u]) root = u;
}
void dfs(int u,int fa) {
v.push_back(u);tmp.push_back(u);
for(int i = 0; i < to[u].size(); i++) {
if(to[u][i] == fa || done[to[u][i]]) continue;
d1[to[u][i]] = d1[u] + 1;
d2[to[u][i]] = (mul(d2[u],10,k) + a[to[u][i]]) % k;
d3[to[u][i]] = (mul(pw[d1[to[u][i]]],a[to[u][i]],k) + d3[u]) % k;
dfs(to[u][i],u);
}
}
ll solve(int u) {
ll ans = 0;tmp.clear();
done[u] = true;
if(a[u] % k == 0) ans++;
for(int i = 0; i < to[u].size(); i++) {
if(done[to[u][i]]) continue;
v.clear();
d1[to[u][i]] = 1;d2[to[u][i]] = a[to[u][i]];d3[to[u][i]] = (mul(a[to[u][i]],pw[d1[to[u][i]]],k) + a[u]);
dfs(to[u][i],u);
for(auto it : v) {
int x,y;
int gi = exgcd(pw[d1[it]],k,x,y);
if(d3[it] % k == 0) ans++;
if((d2[it] + mul(pw[d1[it]],a[u],k) % k) % k == 0) ans++;
if((k - d2[it]) % gi != 0) continue;
else {
x = x * ((k - d2[it]) / gi);
int ki = k / gi;
x = (x % ki + ki) % ki;
ans += p[mp[ki]][x];
}
}
for(auto it : v)
for(int i = 1; i <= num; i++) p[i][d3[it] % np[i]]++;
}
for(auto it : tmp)
for(int i = 1; i <= num; i++) p[i][d3[it] % np[i]]--;
tmp.clear();
for(int i = to[u].size() - 1; i >= 0; i--) {
if(done[to[u][i]]) continue;
v.clear();
d1[to[u][i]] = 1;d2[to[u][i]] = a[to[u][i]];d3[to[u][i]] = (mul(a[to[u][i]],pw[d1[to[u][i]]],k) + a[u]);
dfs(to[u][i],u);
for(auto it : v) {
int x,y;
int gi = exgcd(pw[d1[it]],k,x,y);
if((k - d2[it]) % gi != 0) continue;
else {
x = x * ((k - d2[it]) / gi);
int ki = k / gi;
x = (x % ki + ki) % ki;
ans += p[mp[ki]][x];
}
}
for(auto it : v)
for(int i = 1; i <= num; i++) p[i][d3[it] % np[i]]++;
}
for(auto it : tmp)
for(int i = 1; i <= num; i++) p[i][d3[it] % np[i]]--;
return ans;
}
void divide(int rt) {
res += solve(rt);
for(auto it : to[rt]) {
if(done[it]) continue;
tot = sz[it];root = 0;
getroot(it,-1);
divide(root);
}
}
int main() {
scanf("%d",&t);
while(t--) {
scanf("%d%d%s",&n,&k,str + 1);
init();
for(int i = 1; i <= n; i++) a[i] = str[i] - '0';
pw[0] = 1;mp.clear();np.clear();
for(int i = 1; i <= n; i++) pw[i] = pw[i - 1] * 10 % k;
for(int i = 1; i <= n; i++) {
int p = gcd(pw[i],k);
if(mp[k / p]) break;
mp[k / p] = ++num;
np[num] = k / p;
}
for(int i = 1; i < n; i++) {
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
}
tot = n;root = 0;
getroot(1,-1);
divide(root);
printf("%lld\n",res);
}
return 0;
}
/*
2
2 10 50
1 2
*/