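// Observation: adding a back edge from x to one of its ancestors is equivalent to
// decreasing the answer by 1 for every vertex on the path between x and that
// ancestor (both endpoints excluded).
// Moreover, each vertex is decremented by 1 at most once.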
#include <bits/stdc++.h>
using namespace std;

// Inclusive ascending / descending loop helpers: i runs from a to b.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define dec(i, a, b) for (int i = (a); i >= (b); --i)
// Short aliases for std::pair helpers.
#define MP make_pair
#define fi first
#define se second

using LL = long long;

// Upper bound for vertex-indexed arrays (vertices are 1-based after the +1 shift in main).
constexpr int N = 5010;

// c[j][i] != 0 when main has marked j as lying on the original path from i up to the root.
bool c[N][N];
// used[x] == 1 while x has not yet lowered its parent's d[] value.
int used[N];
// father[x]: current (path-compressed during queries) parent of x.
// d[x]: starts as the degree of x in main and is decremented as queries are answered.
int father[N], d[N];
// Per-test input; _x/_y also hold the rolling state of the query generator.
int n, m, a, b, _x, _y;
// XOR of all current d[] values; it also feeds the query generator.
int ans;
// Adjacency lists of the tree.
vector<int> v[N];
// Scratch list of vertices visited during a single query.
vector<int> cnt;
void dfs(int x, int fa){
father[x] = fa;
for (auto u : v[x]){
if (u == fa) continue;
dfs(u, x);
}
}
// Processes test cases until input is exhausted.
// Each test case gives n, m, a, b and the initial generator state (_x, _y),
// followed by n-1 tree edges with 0-based vertex ids. Then m online queries
// are answered. Each query endpoint pair is generated from the previous pair
// and the running XOR `ans`. Finally the last generator state is printed.
int main(){
    while (scanf("%d%d%d%d%d%d", &n, &m, &a, &b, &_x, &_y) == 6){
        // Reset per-test state.
        rep(i, 0, n + 1) v[i].clear();
        rep(i, 1, n){
            rep(j, 1, n) c[i][j] = 0;
        }
        // Read the tree, shifting vertex ids to 1-based.
        rep(i, 2, n){
            int x, y;
            scanf("%d%d", &x, &y);
            ++x;
            ++y;
            v[x].push_back(y);
            v[y].push_back(x);
        }
        dfs(1, 0);
        // d[i] starts as the degree of i; ans is the XOR of all d[i].
        ans = 0;
        rep(i, 1, n){
            d[i] = (int)v[i].size();
            ans ^= d[i];
        }
        // used[x] == 1 means x may still lower its parent's d[] once.
        // The root has no real parent, so it starts with 0.
        rep(i, 1, n) used[i] = 1;
        used[1] = 0;
        // c[j][i] = 1 iff j is i itself or an ancestor of i in the original tree.
        rep(i, 1, n){
            for (int j = i; j; j = father[j]){
                c[j][i] = 1;
            }
        }
        rep(i, 1, m){
            const int px = _x, py = _y;
            // Evaluate in 64 bits: a*px + b*py can exceed int (signed overflow is UB).
            _x = (int)(((LL)a * px + (LL)b * py + ans) % n);
            _y = (int)(((LL)b * px + (LL)a * py + ans) % n);
            int x = _x + 1;
            const int y = _y + 1;
            // Climb from x while its parent is not an ancestor of y. Each vertex
            // passed lowers its parent's d[] by used[x] (1 the first time, then 0),
            // keeping ans equal to the XOR of all d[].
            cnt.clear();
            for (; x && !c[father[x]][y]; x = father[x]){
                cnt.push_back(x);
                // Defensive: never touch the root's d[]. For y in [1, n] this cannot
                // trigger, because vertex 1 is an ancestor of every vertex.
                if (father[x] == 1) continue;
                ans ^= d[father[x]];
                d[father[x]] -= used[x];
                used[x] = 0;
                ans ^= d[father[x]];
            }
            // Path compression: later queries jump straight to the stopping vertex x,
            // which is an original-tree ancestor of every vertex in cnt, so the
            // ancestor table c[][] stays valid for the new father[] links.
            for (auto u : cnt) father[u] = x;
        }
        printf("%d %d\n", _x, _y);
    }
    return 0;
}