题意:
给你一棵树,每条边有一个权值[0,9],让你找出所有点对(u,v)使得u到v路径组成的数能被m整除(像字符串一样组成)
题解:
这题的思路,对这棵树进行点分治,每次分治算经过根节点的满足条件的点对数有几个,每次分治时在dfs时记录下每个节点的d1(从该节点到根组成的数),d2(从根到该节点组成的数),并且我们用map来存d1的个数,用pair<int,int>来存d2和deep(该节点到根的距离即深度,这样我们可以通过d1*(10^deep)+d2来表示一个数,然后对每一个pair<d2,deep>计算map中有多少个d1使得d1*(10^deep)+d2%m==0,也即map<-d2/(10^deep)>,这里要用到逆元,最好用exgcd求,如果有费马小定理求得话记得特判1(不然会TLE),处理数字的时候要对从根出发的数字特殊处理下
#include<map>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll __int64
#define MS(x,y) memset(x,y,sizeof(x))
#define pii pair<int,int>
#define MP make_pair
using namespace std;
const int N=1e5+10;
const int M=1e6+10;
void in(){freopen("in.in","r",stdin);}
void out(){freopen("out.out","w",stdout);}
struct node
{
int v,w,next;
}edge[M];
ll ans;
map<int,ll>mp;
pii dig[N*2];
int head[N],size[N],mx[N],a[N];
int NE,max_sub,root,num,m,n;
bool vis[N];
void init()
{
MS(head,-1);
NE=0;
a[0]=1;
for(int i=1;i<=n;i++)a[i]=a[i-1]*10ll%m;
}
void add(int u,int v,int w)
{
edge[NE].v=v;
edge[NE].w=w;
edge[NE].next=head[u];
head[u]=NE++;
}
void dfssize(int u,int fa)
{
size[u]=1;
mx[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v])continue;
if(v==fa)continue;
dfssize(v,u);
size[u]+=size[v];
mx[u]=max(mx[u],size[v]);
}
}
void dfsroot(int r,int u,int fa)
{
if(size[r]-size[u]>mx[u])mx[u]=size[r]-size[u];
if(mx[u]<max_sub){max_sub=mx[u];root=u;}
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa)continue;
if(vis[v])continue;
dfsroot(r,v,u);
}
}
void dfsdis(int u,int fa,ll d1,ll d2,int deep)
{
if(deep>=0){
mp[d1]++;
dig[num++]=MP(d2,deep);
}
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v==fa)continue;
if(vis[v])continue;
ll d3=(d1+edge[i].w*1ll*a[deep+1])%m;
ll d4=(0ll+d2*10+edge[i].w)%m;
dfsdis(v,u,d3,d4,deep+1);
}
}
void gcd(ll a,ll b,ll &d,ll &x,ll &y){
if(!b){
d=a;
x=1;
y=0;
return ;
}
gcd(b,a%b,d,y,x);
y-=x*(a/b);
}
ll inv(ll a,ll n){
ll d,x,y;
gcd(a,n,d,x,y);
return d==1?(x%n+n)%n:-1;
}
ll cal(int u,int d)
{
ll res=0;
mp.clear();num=0;
if(d)dfsdis(u,-1,d%m,d%m,0);
else dfsdis(u,-1,0,0,-1);
for(int i=0;i<num;i++){
ll tmp=((-dig[i].first*1ll*inv(a[dig[i].second+1],m))%m+m)%m;
if(mp.find(tmp)!=mp.end())res+=mp[tmp];
if(d==0)res+=(dig[i].first==0);
}
if(d==0)res+=mp[0];
return res;
}
void solve(int u)
{
max_sub=n;
dfssize(u,-1);
dfsroot(u,u,-1);
ans+=cal(root,0);
vis[root]=1;
for(int i=head[root];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v])continue;
ans-=cal(v,edge[i].w);
solve(v);
}
}
int main()
{
while(~scanf("%d%d",&n,&m)){
init();
for(int i=0;i<n-1;i++){
int u,v,w; scanf("%d%d%d",&u,&v,&w);
add(u,v,w); add(v,u,w);
}
ans=0;MS(vis,0);
solve(1);
printf("%I64d\n",ans);
}
return 0;
}