Digit Tree
Codeforce 715C
看见求路径,就是点分治,点分治,点分治!!!
1.从u走到v经过的边上的数字依次以字符串的方式连接成的数,就是x*10+a的操作。
2.dis(u,v)是P的倍数,就是dis(u,v)%P=0
3.点分治可以求出以当前重心为端点的每条树链,考虑合并两条树链成一条路径。如果两条路径要合并那么当头的那条要从下往上算,当尾的那条要从上往下算,由于在根处会重复,这里可以规定当头的那条不含根的值。定义:当头的那条的值为dis1,数字个数为sz1,当尾的那条的值为dis2,数字个数为sz2。那么合并得到的路径值为dis1*pow(10,sz2)+dis2
,数字个数为sz1+sz2。
4.由此得到路径的dis为dis1*pow(10,sz2)+dis2
。发现当尾的那条需要知道两个信息,因此只能枚举当尾的那条,算出当头的那条。已知:(dis1*pow(10,sz2)+dis2)%P=0
,dis2,sz2
。则dis1=(-dis2)*逆元(pow(10,sz2))
。
5.这里要求逆元,但是P不一定是质数,所以不能用费马小定理。注意到题目保证保证gcd(P,10)=1,所以可以用拓展欧几里得求逆元
void exgcd(int a,int b,int &x,int &y) {
if(!b) {
y=0,x=1;
return;
}
exgcd(b,a%b,y,x);
y-=(a/b)*x;
}
int ni(int a,int b) {
int x,y;
exgcd(a,b,x,y);
return (x%b+b)%b;
}
6.注意处理好模运算就可以轻松愉快的AC了
具体代码
#include<bits/stdc++.h>
using namespace std;
const int M=1e5+5;
int n,P;
int head[M],asdf;
struct edge {
int to,nxt,cost;
} G[M*2];
void add_edge(int a,int b,int c) {
G[++asdf].to=b;
G[asdf].nxt=head[a];
G[asdf].cost=c;
head[a]=asdf;
}
void exgcd(int a,int b,int &x,int &y) {
if(!b) {
y=0,x=1;
return;
}
exgcd(b,a%b,y,x);
y-=(a/b)*x;
}
int ni(int a,int b) {
int x,y;
exgcd(a,b,x,y);
return (x%b+b)%b;
}
bool mark[M];
int que_len,que[M],sz[M],son[M];
void get_root(int x,int f) {
que[++que_len]=x;
sz[x]=1;
son[x]=0;
for(int i=head[x]; i; i=G[i].nxt) {
int y=G[i].to;
if(y==f||mark[y])continue;
get_root(y,x);
sz[x]+=sz[y];
if(sz[y]>son[x])son[x]=sz[y];
}
}
map<int,int>mp;
int dis[M];
void get_dis(int x,int f,int c) {
//printf("dis[%d]=%d\n",x,dis[x]);
mp[dis[x]]++;
for(int i=head[x]; i; i=G[i].nxt) {
int y=G[i].to;
if(y==f||mark[y])continue;
dis[y]=(1ll*G[i].cost*c%P+dis[x])%P;
get_dis(y,x,1ll*c*10%P);
}
}
long long ans;
void get_ans(int x,int f,int c,int t) {
ans+=t*mp[1ll*(P-dis[x])*ni(c,P)%P];
for(int i=head[x]; i; i=G[i].nxt) {
int y=G[i].to;
if(y==f||mark[y])continue;
dis[y]=(1ll*dis[x]*10%P+G[i].cost)%P;
get_ans(y,x,1ll*c*10%P,t);
}
}
void solve(int x) {
que_len=0;
get_root(x,x);
for(int i=1; i<=que_len; i++) {
son[que[i]]=max(son[que[i]],que_len-sz[que[i]]);
if(son[que[i]]<son[x])x=que[i];
}
mark[x]=1;
//printf("x=%d\n",x);
mp.clear();//由于P可能很大,所以开mp存方便
dis[x]=0;
get_dis(x,x,1);
get_ans(x,x,1,1);
ans--;
//printf("%lld\n",ans);
for(int i=head[x]; i; i=G[i].nxt) {
int y=G[i].to;
if(mark[y])continue;
mp.clear();
dis[y]=G[i].cost%P;
get_dis(y,y,10%P);
get_ans(y,y,10%P,-1);
solve(y);
}
}
int main() {
int a,b,c;
scanf("%d %d",&n,&P);
for(int i=1; i<n; i++) {
scanf("%d %d %d",&a,&b,&c);
a++,b++;
add_edge(a,b,c);
add_edge(b,a,c);
}
solve(1);
printf("%lld\n",ans);
return 0;
}