ZYB's Tree
Time Limit: 3000/1500 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)
Total Submission(s): 795 Accepted Submission(s): 266
Problem Description
ZYB has a tree with N nodes,now he wants you to solve the numbers of nodes distanced no more than K for each node.
the distance between two nodes(x,y) is defined the number of edges on their shortest path in the tree.
To save the time of reading and printing,we use the following way:
For reading:we have two numbers A and B,let fai be the father of node i,fa1=0,fai=(A∗i+B)%(i−1)+1 for i∈[2,N] .
For printing:let ansi be the answer of node i,you only need to print the xor sum of all ansi.
Input
In the first line there is the number of testcases T.
For each teatcase:
In the first line there are four numbers N,K,A,B
1≤T≤5,1≤N≤500000,1≤K≤10,1≤A,B≤1000000
Output
For T lines,each line print the ans.
Please open the stack by yourself.
N≥100000 are only for two tests finally.
Sample Input
1 3 1 1 1
Sample Output
3
Source
Recommend
hujie
【思路】
题目要求树上每一点的与其他点距离不超过K的点的个数。由于K很小,可以作为状态划分,我们可以处理出所有子树中到子树根的距离为某个值的点数,也就是用dp[i][j]表示以i为根的子树里,与i距离为j的点的个数。然后所有点的答案可以通过对每个点上溯0到K层,加上往侧链走若干步的所有点,减去当前链上走若干步的所有点得到。详见代码。
【代码】
//******************************************************************************
// File Name: HDU_5593.cpp
// Author: Shili_Xu
// E-Mail: shili_xu@qq.com
// Created Time: 2018年08月11日 星期六 11时14分07秒
//******************************************************************************
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;
const int MAXN = 5e5 + 5;
int t, n, k, a, b;
int fa[MAXN], sum[11];
int dp[MAXN][11];
int get(int u)
{
int ans = 0;
memset(sum, 0, sizeof(sum));
for (int i = 0, son = 0; u && i <= k; son = u, u = fa[u], i++) {
sum[i]++;
for (int j = 1; j <= k - i; j++) sum[i + j] += dp[u][j] - dp[son][j - 1];
}
for (int i = 0; i <= k; i++) ans += sum[i];
return ans;
}
int main()
{
scanf("%d", &t);
while (t--) {
scanf("%d %d %d %d", &n, &k, &a, &b);
memset(dp, 0, sizeof(dp));
fa[1] = 0;
dp[1][0] = 1;
for (int i = 2; i <= n; i++) {
fa[i] = ((ll)a * i + b) % (i - 1) + 1;
for (int u = i, j = 0; u && j <= k; u = fa[u], j++) dp[u][j]++;
}
int ans = 0;
for (int i = 1; i <= n; i++) ans ^= get(i);
printf("%d\n", ans);
}
return 0;
}