解法:树形dp,dp[k][j][i]表示模k意义下以i为终点长度为i的路径是否存在。
转移:因为每次最多走一条边,所以只需要枚举k转移即可,方程见代码,很容易想到。边界条件就是能直接到的点的dp距离值=1。代码中为了简化,使用了三个数组,循环使用,来表示当前点状态,后续状态,总状态。这样就能使写代码快很多,还有就是注意,后续状态转移完之后,还要再枚举一次,得到完整的后续状态,不然会WA!
ps:从标程的写法上涨姿势了,终于懂得如何在一堆数组中简化自己的代码了!
代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#define register int int
using namespace std;
vector<pair<int,int> >e[3005];
int n,q,sta[3005],cost[3005],uu,vv,ww,u[100005],k[100005],maxk;
bool useing[3005],nex[3005][105],dp[105][3005][105];
inline int read()
{
int res=0,f=1;
char c=getchar();
while(c<'0'||c>'9')
{
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9')
{
res=res*10+c-'0';
c=getchar();
}
return res*f;
}
void init(int now,int fa,int k,bool f[3005][105])
{
f[now][0]=1;
for(int i=0,len=e[now].size();i<len;i++)
{
int to=e[now][i].first,pay=e[now][i].second;
if(to==fa) continue;
init(to,now,k,f);
for(int j=0;j<k;j++)
if(f[to][j])
f[now][(j+pay)%k]=1;
}
}
void DP(int now,int fa,int k,bool f[3005][105])
{
int head=0;
for(int i=0,len=e[now].size();i<len;i++)
{
int to=e[now][i].first,pay=e[now][i].second;//把除了父亲之外的出边保存到一个栈里,之后更新
head++;
sta[head]=to;
cost[head]=pay;
}
for(int i=0;i<k;i++)
{
useing[i]=nex[now][i];//useing用来更新nex,nex用来更新dp,这一轮的nex就是下一轮的dp,所以再次更新,循环往复
f[now][i]|=nex[now][i];//useing是当前状态,nex是后续状态,dp是整个状态,所以分别是1,2,3维
}
useing[0]|=1;
for(int i=1;i<=head;i++)
{
int v=sta[i],pay=cost[i];
for(int j=0;j<k;++j)
nex[v][(j+pay)%k]|=useing[j];//模k意义下能到达的边,更新
for(int j=0;j<k;++j)
useing[(j+pay)%k]|=f[v][j];
}
for(int i=0;i<k;i++)
useing[i]=nex[now][i];
useing[0]|=1;
for(int i=head;i>=1;i--)//刚才是从前往后推,现在nex发生了变化,所以再次从后向前推一遍,以至于可以转移所有状态
{
int v=sta[i],pay=cost[i];
for(int j=0;j<k;++j)
nex[v][(j+pay)%k]|=useing[j];
for(int j=0;j<k;++j)
useing[(j+pay)%k]|=f[v][j];
}
for(int i=0,len=e[now].size();i<len;i++)//遍历出边
{
int to=e[now][i].first;
if(to!=fa) DP(to,now,k,f);
}
}
void solve(int k,bool f[3005][105])
{
init(1,0,k,f);
memset(nex,0,sizeof(nex));
DP(1,0,k,f);
}
int main()
{
freopen("mtree.in","r",stdin); freopen("mtree.out","w",stdout);
n=read();
for(int i=1;i<n;i++)
{
uu=read();vv=read();ww=read();
e[uu].push_back(make_pair(vv,ww));
e[vv].push_back(make_pair(uu,ww));
}
q=read();
for(int i=1;i<=q;i++)
{
u[i]=read();k[i]=read();
maxk=max(maxk,k[i]);
}
for(int i=1;i<=maxk;i++)
solve(i,dp[i]);
for(int i=1;i<=q;i++)
for(int j=k[i]-1;j>=0;j--)
if(dp[k[i]][u[i]][j])
{
printf("%d\n",j);
break;
}
}