题目描述
传送门
题意:一棵树,每条边上有一个数字(1~9),给出一个与10互质的数m,问整棵树上有多少条链满足从起点走到终点树链上形成的十进制数是m的倍数。
题解
本来是看dsu on the tree找到了这道题,但是发现用dsu on the tree写好麻烦啊=w=
不过用点分就没有那么恶心了
对于每一次分治到的子树,需要一些节点的信息:从当前点出发向上走到根形成的十进制数在模m意义下,记为up(x);以及从根出发走到当前点形成的十进制数在模m意义下,记为down(x);当前根到这个点的步数(深度)为deep(x)
如果两个点xy组成的链可以整除m的话,一定满足
up(x)∗10deep(y)+down(y)≡0(modm)
移项之后可以化成
ax≡b(modm)
的形式
枚举y之后可以用扩欧求出up(x)的值。把所有的up值存下来然后二分左右端点就可以了。
比较好的一道思路题。
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
using namespace std;
#define N 100005
int n,m,root,sum;
int tot,point[N],nxt[N*2],v[N*2],c[N*2];
long long ans;
int mi[N],big[N],size[N],pt[N],deep[N],up[N],numup[N],numdown[N];
bool vis[N];
void add(int x,int y,int z)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
}
void getroot(int x,int fa)
{
size[x]=1;big[x]=0;
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&!vis[v[i]])
{
getroot(v[i],x);
size[x]+=size[v[i]];
big[x]=max(big[x],size[v[i]]);
}
big[x]=max(big[x],sum-size[x]);
if (big[x]<big[root]) root=x;
}
void getdeep(int x,int fa,int dep)
{
pt[++pt[0]]=x;deep[x]=dep;
up[++up[0]]=numup[x];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa&&!vis[v[i]])
{
numup[v[i]]=(numup[x]+(long long)c[i]*mi[dep]%m)%m;
numdown[v[i]]=((long long)numdown[x]*10%m+c[i])%m;
getdeep(v[i],x,dep+1);
}
}
int gcd(int a,int b)
{
if (!b) return a;
else return gcd(b,a%b);
}
void exgcd(int a,int b,int &x,int &y)
{
if (!b) x=1,y=0;
else exgcd(b,a%b,y,x),y-=a/b*x;
}
int tongyu(int a,int b,int m)
{
int t=gcd(a,m);
if (b%t) return -1;
a/=t,m/=t,b/=t;
int x=0,y=0;
exgcd(a,m,x,y);
x=(((long long)b*x)%m+m)%m;
return x;
}
int findl(int x)
{
int l=1,r=up[0],mid,ans=-1;
while (l<=r)
{
mid=(l+r)>>1;
if (up[mid]==x) ans=mid,r=mid-1;
else if (up[mid]<x) l=mid+1;
else r=mid-1;
}
return ans;
}
int findr(int x)
{
int l=1,r=up[0],mid,ans=-1;
while (l<=r)
{
mid=(l+r)>>1;
if (up[mid]==x) ans=mid,l=mid+1;
else if (up[mid]<x) l=mid+1;
else r=mid-1;
}
return ans;
}
long long calc(int x,int now)
{
numup[x]=now%m;numdown[x]=now%m;
pt[0]=0;up[0]=0;
getdeep(x,0,(now)?1:0);
sort(up+1,up+up[0]+1);
long long t=0;
for (int i=1;i<=pt[0];++i)
{
int b=(-numdown[pt[i]]%m+m)%m;
int a=mi[deep[pt[i]]];
int state=tongyu(a,b,m);
if (state==-1) continue;
int l=findl(state);
int r=findr(state);
if (l==-1||r==-1) continue;
t+=(long long)(r-l+1);
}
if (!now) --t;
return t;
}
void dfs(int x)
{
ans+=calc(x,0);
vis[x]=1;
for (int i=point[x];i;i=nxt[i])
if (!vis[v[i]])
{
ans-=calc(v[i],c[i]);
root=0;sum=size[v[i]];
getroot(v[i],0);
dfs(root);
}
}
int main()
{
scanf("%d%d",&n,&m);
mi[0]=1;for (int i=1;i<=n;++i) mi[i]=(long long)mi[i-1]*10%m;
for (int i=1;i<n;++i)
{
int x,y,z;scanf("%d%d%d",&x,&y,&z);
++x,++y;
add(x,y,z),add(y,x,z);
}
big[0]=N;
root=0;sum=n;
getroot(1,0);
dfs(root);
printf("%I64d\n",ans);
}