题目链接:http://codeforces.com/contest/715/problem/C
题目大意:给定一棵n个点(编号0~n-1)的树和一个数M,边权为1~9的正整数,求树中有多少对有序点对(x,y)满足从x到y的路径上的边权顺次相接所组成的十进制数被M整除。
数据范围:2 ≤ n ≤ 100 000, 1 ≤ M ≤ 10^9, gcd(M,10)=1
题解:看到有关统计树上路径的问题应该可以想到树分治,关键是如何统计答案。
我们给每个点计算三个值,分别为从当前点往上到树根组成的数up,从树根往下到当前点组成的数down,以及当前点到树根的距离dep,up和down均是对M取余后的结果。
统计答案时,我们考虑以当前点为终点,那么假设起点到根组成的数为x,则x应满足(x*10^dep+down) mod M=0,即x=(M-down)/(10^dep)。于是我们可以预处理出10^i(i<=n)的逆元(由于gcd(M,10)=0,所以10^i的逆元=(10^i)^(φ(M)-1)),在进行统计时,用map记下之前的点的up值,计算的时候直接加上(M-down)乘10^dep的逆元的值的个数就可以了。由于先后顺序不确定,所以需要正着做一遍再倒着做一遍。
时间复杂度O(nlog^2n)
(ps:托cf的福,发现自己多年以来的树分治打法一直有个bug,果然当初做出刷(你确定能用这个字?)cf的决定的我真是不能更机智了hhhh)
代码如下:
#include <algorithm>
#include <cstdio>
#include <map>
const int N=100005;
int to[N*2],ne[N*2],val[N*2],fi[N],num[N],mx[N],f[N],ine[N],pow[N],
up[N],down[N],dep[N],d[N],u[N],q[N],tot=0,curn,rt,n,m;
int phi;
long long ans=0;
std::map<int,int> A;
void add(int x,int y,int z){
to[++tot]=y;val[tot]=z;ne[tot]=fi[x];fi[x]=tot;
}
void findrt(int x){
num[x]=1;mx[x]=0;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]] && to[i]!=f[x]){
f[to[i]]=x;
findrt(to[i]);
num[x]+=num[to[i]];
mx[x]=std::max(mx[x],num[to[i]]);
}
mx[x]=std::max(mx[x],curn-num[x]);
if (mx[x]<mx[rt]) rt=x;
}
int qsm(int x,int y){
int i=1;
for (;y;y>>=1,x=1ll*x*x%m)
if (y&1) i=1ll*i*x%m;
return i;
}
void work(int x,int sum){
if (sum==1) return;
curn=sum;rt=n;
findrt(x);
u[rt]=1;
int l=0,s=1,t=0;
for (int i=fi[rt];i;i=ne[i])
if (!u[to[i]]){
d[++l]=s;
f[to[i]]=rt;
up[to[i]]=down[to[i]]=val[i]%m;
dep[to[i]]=1;
for (q[++t]=to[i];s<=t;s++)
for (int j=fi[q[s]];j;j=ne[j])
if (!u[to[j]] && to[j]!=f[q[s]]){
f[q[++t]=to[j]]=q[s];
dep[to[j]]=dep[q[s]]+1;
up[to[j]]=(1ll*pow[dep[q[s]]]*val[j]+up[q[s]])%m;
down[to[j]]=(10ll*down[q[s]]+val[j])%m;
}
}
d[l+1]=s;
A.clear();A[0]=1;
for (int i=1,j=1;i<=l;i++){
for (;j<d[i+1];j++)
ans=ans+A[1ll*(m-down[q[j]])*ine[dep[q[j]]]%m];
for (j=d[i];j<d[i+1];j++) A[up[q[j]]]++;
}
A.clear();
for (int i=l,j=t;i>=1;i--){
for (;j>=d[i];j--)
ans=ans+A[1ll*(m-down[q[j]])*ine[dep[q[j]]]%m];
for (j=d[i+1]-1;j>=d[i];j--) A[up[q[j]]]++;
}
ans=ans+A[0];
x=rt;
for (int i=fi[x];i;i=ne[i])
if (!u[to[i]]){
if (num[to[i]]>num[x]) num[to[i]]=sum-num[x];
work(to[i],num[to[i]]);
}
}
int main(){
scanf("%d%d\n",&n,&m);
for (int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d\n",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
int x=m;phi=m;
for (int i=2;i*i<=x;i++)
if (x%i==0)
for (phi=phi/i*(i-1);x%i==0;x/=i);
if (x>1) phi=phi/x*(x-1);
ine[0]=1;ine[1]=qsm(10,phi-1);
for (int i=2;i<=n;i++) ine[i]=1ll*ine[i-1]*ine[1]%m;
pow[0]=1;
for (int i=1;i<=n;i++) pow[i]=10ll*pow[i-1]%m;
mx[n]=n;
if (m==1) ans=1ll*n*(n-1);
else work(0,n);
printf("%I64d\n",ans);
}