ZYB's Tree
Accepts: 77
Submissions: 513
Time Limit: 3000/1500 MS (Java/Others)
Memory Limit: 131072/131072 K (Java/Others)
问题描述
ZYB有一颗N个节点的树,现在他希望你对于每一个点,求出离每个点距离不超过K的点的个数. 两个点(x,y)在树上的距离定义为两个点树上最短路径经过的边数, 为了节约读入和输出的时间,我们采用如下方式进行读入输出: 读入:读入两个数A,B,令fai为节点i的父亲,fa1=0;fai=(A∗i+B)%(i−1)+1 i∈[2,N] . 输出:输出时只需输出N个点的答案的xor和即可。
输入描述
第一行一个整数T表示数据组数。 接下来每组数据: 一行四个正整数N,K,A,B. 最终数据中只有两组N≥100000。 1≤T≤5,1≤N≤500000,1≤K≤10,1≤A,B≤1000000
输出描述
T行每行一个整数表示答案.
输入样例
1 3 1 1 1
输出样例
3
思路:
观察到K的范围不大,那么我们考虑对每个点进行Dp计数。
这类Dp基本上都是,分两个方向去Dp。很套路。
①设定Dp【i】【j】表示以点i为根的子树中,到点i距离为j的点的个数。那么不难写出有:Dp【i】【j】+=Dp【son【i】】【j-1】;
②再设定F【i】【j】表示以点i为中心,非子树方向,到点i距离为j的点的个数,其实这部分的转移也很好想,除了根方向以外的部分,都要转移过来即可。
那么只有两种可能,一种是从父亲节点转移过来:F【i】【j】+=F【fa】【j-1】;
另一种可能就是从兄弟节点转移过来:F【i】【j】+=Dp【brother】【j-2】;
很显然直接去转移兄弟节点会TLE掉,因为一个节点的兄弟节点会很多,那么每个兄弟节点都处理一次的话,任务量实在是太大了。所以我们优化一下,设定Sum【i】【j】表示ΣDp【son【i】】【j】;那么我们就可以优化最后一个转移方程为:F【i】【j】+=Sum【fa【i】】【j-2】-Dp【i】【j-2】;
过程维护统计一下即可。
#include<stdio.h>
#include<vector>
#include<string.h>
using namespace std;
#define ll long long int
vector<int>mp[550000];
int dp[550000][12];
int sum[550000][12];
int F[550000][12];
int n,k,A,B;
void Dfs(int u,int from)
{
for(int i=0; i<mp[u].size(); i++)
{
int v=mp[u][i];
if(v==from)continue;
Dfs(v,u);
for(int j=1; j<=k; j++)
{
dp[u][j]+=dp[v][j-1];
}
}
for(int i=0; i<mp[u].size(); i++)
{
int v=mp[u][i];
if(v==from)continue;
for(int j=0; j<=k; j++)
{
sum[u][j]+=dp[v][j];
}
}
}
void dfs(int u,int from)
{
if(from!=-1)for(int j=1; j<=k; j++)F[u][j]+=F[from][j-1];
for(int j=2; j<=k; j++)
{
if(j-2>=0&&from!=-1)
{
F[u][j]+=sum[from][j-2]-dp[u][j-2];
}
}
for(int i=0; i<mp[u].size(); i++)
{
int v=mp[u][i];
if(v==from)continue;
dfs(v,u);
}
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
memset(sum,0,sizeof(sum));
memset(dp,0,sizeof(dp));
memset(F,0,sizeof(F));
scanf("%d%d%d%d",&n,&k,&A,&B);
for(int i=1; i<=n; i++)mp[i].clear();
for(int i=2; i<=n; i++)
{
ll fa=(ll)((ll)A*(ll)i+B)%(ll)(i-1)+1;
mp[i].push_back(fa);
mp[fa].push_back(i);
}
for(int i=1; i<=n; i++)dp[i][0]=1,F[i][0]=1;
Dfs(1,-1);
dfs(1,-1);
int ans=0;
for(int i=1; i<=n; i++)
{
int sum=0;
for(int j=0; j<=k; j++)
{
sum+=F[i][j]+dp[i][j];
}
ans^=(sum-1);
}
printf("%d\n",ans);
}
}