大意: 给定一棵树, 每个点上有一个十进制数字 (0~9). 求有多少条有向路径 $(u, v)$ (允许 $u = v$), 使得从 $u$ 走到 $v$ 依次读出的数字组成的十进制数能被 $k$ 整除.
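点分治, 可以参考 CF715C (Digit Tree). 对经过分治中心的路径, 设前半段 (起点读到分治中心, 含中心) 组成的数为 $x$, 后半段有 $a$ 位、组成的数为 $b$, 则问题转化为统计满足 $10^a x + b \equiv 0 \pmod{k}$ 的 $x$ 的个数. 与 CF715C 不同, 这里 $k$ 不一定与 $10$ 互质, 所以要先用扩展欧几里得把它化成 $x \equiv c \pmod{k/\gcd(10^a, k)}$, 再按这个模数分别计数 (代码中模数记为 $m$).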
要注意
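- `tmp` 不要设成全局!! `dfs2` 是递归的, 每一层都要用自己的 `tmp` 保存并恢复 `up[]`, 设成全局会被更深层的调用覆盖.
- 如果 $y \bmod z = 0$ (即 $z \mid y$), 那么 $(x \bmod y) \bmod z = x \bmod z$. 所以可以直接用 $10^d \bmod k$ 再对 $b_i$ 取模, 因为每个 $b_i = k/\gcd(10^h, k)$ 都整除 $k$.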
#include <algorithm>
#include <bitset>
#include <cstdio>
#include <iostream>
#include <map>
#include <math.h>
#include <queue>
#include <set>
#include <string>
#include <string.h>
#include <vector>
// ---- Competitive-programming shorthand macros ----
// (PER, hr, lc, rc, mid, ls, rs, io, endl and pii appear unused in this file.)
#define REP(i,a,n) for(int i=a;i<=n;++i)
#define PER(i,a,n) for(int i=n;i>=a;--i)
#define hr putchar(10)
#define pb push_back
#define lc (o<<1)
#define rc (lc|1)
#define mid ((l+r)>>1)
#define ls lc,l,mid
#define rs rc,mid+1,r
// NOTE: the next two macros make the preprocessor rewrite EVERY identifier token `x` / `y`
// in the rest of this file to `first` / `second`. The code still compiles because the rewrite
// is applied consistently, but never declare a local named `first` or `second` in a function
// that also uses `x` / `y`, and do not use `x` / `y` as member names.
#define x first
#define y second
#define io std::ios::sync_with_stdio(false)
#define endl '\n'
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
// Capacity of the node-indexed arrays; presumably n <= 1e5 in the problem -- TODO confirm.
const int N = 1e5+10;
// n: number of nodes; m: the modulus (called k in the problem statement);
// rt: best centroid candidate found by getrt(); sum: size of the component getrt() scans;
// p10[i] = 10^i mod m for i >= 1 (filled in work()); p10[0] = 1 is set once in main().
int sum, n, rt, m, p10[N];
// sz[v]: subtree size from the last getrt() pass; mx[v]: largest remaining piece if v were
// removed (mx[0] is reset to n as a sentinel before each search); vis[v] != 0 once v has been
// used as a centroid; b[0] = count and b[1..b[0]] = sorted distinct moduli m / gcd(10^i, m).
int sz[N], mx[N], vis[N], b[N];
// s[1..n]: node digits, converted from '0'..'9' to 0..9 in work().
char s[N];
// g[v]: adjacency list of the tree.
vector<int> g[N];
// ans:  ordered paths whose two endpoints lie in different subtrees of a centroid.
// ans1: paths with a centroid as an endpoint (plus single-node paths), each counted twice
//       because calc() runs twice per centroid -- work() prints ans + ans1/2.
// Phi:  not referenced anywhere in this file.
ll ans, ans1, Phi;
// Greatest common divisor via the iterative Euclidean algorithm; gcd(a, 0) == a.
int gcd(int a, int b) {
    while (b != 0) {
        const int r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Extended Euclid: returns d = gcd(a, b) and stores in x, y a pair with a*x + b*y == d.
int exgcd(int a, int b, int &x, int &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    // Solve for (b, a mod b) with the roles of x and y swapped, then correct y.
    const int d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}
// Solves the linear congruence a*t ≡ b (mod p) in place.
// Returns true when a solution exists; (a, b, p) is then rewritten to the equivalent
// normalised congruence t ≡ b (mod p): a becomes 1, p is divided by d = gcd(a, p), and
// 0 <= b < p. Returns false, leaving a, b and p untouched, when d does not divide b.
// Assumes p >= 1.
bool chk(int &a, int &b, int &p) {
    int inv = 0, unused = 0;
    const int d = exgcd(a, p, inv, unused);  // a*inv + p*unused == d
    if (b % d != 0) return false;
    p /= d;
    // b/d and inv can each be about as large as p (i.e. about m), so their product must be
    // formed in 64 bits; an int product is signed overflow (undefined behaviour) for large m.
    b = static_cast<int>((static_cast<long long>(b / d) * inv % p + p) % p);
    a = 1;
    return true;
}
// Centroid search over the component containing x (nodes with vis[] == 0), entered from parent fa.
// Fills sz[] with subtree sizes and mx[v] with the size of the largest piece left if v were
// removed (using `sum` as the component size), and moves the global rt to any node whose mx is
// strictly smaller than mx[rt]. Callers reset rt = 0 with mx[0] as a sentinel beforehand.
void getrt(int x, int fa) {
    sz[x] = 1;
    int heaviest = 0;
    for (int to : g[x]) {
        if (vis[to] || to == fa) continue;
        getrt(to, x);
        sz[x] += sz[to];
        heaviest = max(heaviest, sz[to]);
    }
    // The part "above" x (everything outside its subtree) is also a piece.
    mx[x] = max(heaviest, sum - sz[x]);
    if (mx[x] < mx[rt]) rt = x;
}
// Returns the 1-based index of modulus value x within the sorted, de-duplicated list b[1..b[0]].
// Assumes x is present; callers pass m / gcd(10^dep, m), which work() is expected to have put
// into b -- NOTE(review): relies on the gcd sequence stabilising within the first 30 powers.
int ID(int x) {
    const int *lo = b + 1;
    const int *hi = lo + b[0];
    return int(std::lower_bound(lo, hi, x) - b);
}
// mp[i] is a histogram over the start nodes u registered so far (by dfs2) for the current centroid:
// mp[i][j] = number of such u whose upward value U(u) -- u's digits read up to and including the
// centroid -- satisfies U(u) ≡ j (mod b[i]), where b[i] = m / gcd(10^h, m) for some h.
// dfs1 reduces 10^h * U ≡ -down (mod m) to U ≡ c (mod b[i]) and then looks up mp[i][c].
// Indexed 1..b[0]; work() produces at most 31 distinct moduli, so 40 slots suffice.
map<int,int> mp[40];
void dfs1(int x, int fa, int dep, int down) {
//求10^dep*x=(m-down)%m
int a = p10[dep], b = (m-down)%m, p = m;
if (chk(a,b,p)) {
auto &u = mp[ID(p)];
if (u.count(b)) ans += u[b];
}
for (int y:g[x]) if (!vis[y]&&y!=fa) {
dfs1(y,x,dep+1,((ll)down*10ll+s[y])%m);
}
}
int up[40];
void dfs2(int x, int fa, int dep) {
REP(i,1,*b) {
++mp[i][up[i]];
}
int tmp[40];
for (int y:g[x]) if (!vis[y]&&y!=fa) {
REP(i,1,*b) tmp[i]=up[i],up[i]=((ll)s[y]*p10[dep]+up[i])%b[i];
dfs2(y,x,dep+1);
REP(i,1,*b) up[i]=tmp[i];
}
}
void dfs3(int x, int fa, int down, int dep, int up) {
ans1 += !up+!down;
for (int y:g[x]) if (!vis[y]&&y!=fa) {
dfs3(y,x,((ll)down*10+s[y])%m,dep+1,((ll)s[y]*p10[dep]+up)%m);
}
}
vector<int> q;
void calc(int x) {
REP(i,1,*b) mp[i].clear();
if (s[x]%m==0) ++ans1;
for (int y:q) {
dfs1(y,x,1,s[y]%m);
REP(i,1,*b) up[i] = (s[x]+10ll*s[y])%b[i];
dfs2(y,x,2);
dfs3(y,x,(10ll*s[x]+s[y])%m,2,(s[x]+10ll*s[y])%m);
}
}
// Centroid-decomposition driver: x is the centroid of its current component.
void solve(int x) {
    vis[x] = 1;
    // Neighbours of x that are still inside the component.
    q.clear();
    for (int to : g[x]) {
        if (!vis[to]) q.push_back(to);
    }
    // Paths through x in both directions: neighbours in order, then reversed.
    calc(x);
    reverse(q.begin(), q.end());
    calc(x);
    // Recurse into every remaining sub-component.
    for (int to : g[x]) {
        if (vis[to]) continue;
        rt = 0;
        mx[0] = n;  // sentinel so that the first scanned node replaces rt
        // NOTE(review): sz[to] comes from an earlier getrt() pass and may over-count if `to` was an
        // ancestor of x in that pass; it only influences which node is picked as the next root
        // (and hence running time), not the count.
        sum = sz[to];
        getrt(to, 0);
        solve(rt);
    }
}
void work() {
scanf("%d%d%s", &n, &m, s+1);
REP(i,1,n) p10[i]=p10[i-1]*10ll%m;
REP(i,1,n) s[i]-='0';
ans = ans1 = 0;
REP(i,1,n) g[i].clear(),vis[i]=0;
REP(i,2,n) {
int u, v;
scanf("%d%d", &u, &v);
g[u].pb(v);
g[v].pb(u);
}
if (m==1) return printf("%lld\n", (ll)n*n),void();
*b = 0;
REP(i,0,min(n,30)) b[++*b]=m/gcd(p10[i],m);
sort(b+1,b+1+*b),*b=unique(b+1,b+1+*b)-b-1;
sum=mx[rt=0]=n,getrt(1,0),solve(rt);
printf("%lld\n", ans+ans1/2);
}
// Reads the number of test cases and solves each one with work().
int main() {
    // 10^0; work() derives p10[1..n] from this entry for every test case.
    p10[0] = 1;
    int cases;
    scanf("%d", &cases);
    for (; cases != 0; --cases) work();
}