原题链接:https://ac.nowcoder.com/acm/contest/11258/F
题意
给定两棵树,要求找一个点集使得两棵树上都满足:
- 第一棵树上点集是相互连通的,而且任意两点互为祖先节点
- 第二棵树上任意两点都不能互为祖先节点
分析
首先这道题我们必须知道一个性质,一棵树上的dfs序入点和出点一定是连续的,且子树内的dfs序一定大于根节点的dfs序。这样其实就变成了我们在dfs序上找最多的连续区间,使得他们不相交。但要满足第一棵树上连续成链的性质,其实不难想到固定一个端点然后二分去找最远取到的祖先节点。
至于在树上维护一条链,我们在遍历点的时候将点压入栈里,回溯的时候出栈就可以了,这时这个栈内的点一定是一条链上且连续,我们在每个节点上建立一个主席树,每次选择一个右端点,将右端点贡献加入主席树内,至于贡献怎么算?只要将当前点的入点dfs序和出点dfs序的区间+1,这样之后每次查询当前端点子树内是否被染色过来进行二分。
光这样还不够,每次二分的区间也必须有限制,我们要记录每次符合要求的最右端,并将二分的最左端设置为这个值,因为答案具有单调性。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned int ul;
typedef pair<int, int> PII;
const ll inf = 1e18;
const int N = 3e5 + 10;
const int M = 1e6 + 10;
const ll mod = 1e9 + 7;
const double eps = 1e-8;
#define lowbit(i) (i & -i)
#define Debug(x) cout << (x) << endl
#define fi first
#define se second
#define mem memset
#define endl '\n'
namespace StandardIO {
template<typename T>
inline void read(T &x) {
x = 0; T f = 1;
char c = getchar();
for (; c < '0' || c > '9'; c = getchar()) if (c == '-') f = -1;
for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + c - '0';
x *= f;
}
template<typename T>
inline void write(T x) {
if (x < 0) putchar('-'), x *= -1;
if (x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
}
namespace Comb {
ll f[N];
ll ksm(ll a, ll b) {
ll res = 1, base = a;
while (b) {
if (b & 1) res = res * base % mod;
base = base * base % mod;
b >>= 1;
}
return res;
}
void init() {
f[0] = 1;
for (ll i = 1; i < N; i++) f[i] = f[i - 1] * i % mod;
}
ll C(ll a, ll b) {
if (a < 0 || b < 0 || b > a) return 0;
return f[a] * ksm(f[a - b], mod - 2) % mod * ksm(f[b], mod - 2) % mod;
}
}
struct node {
int ls, rs;
int sum, tag;
}hjt[N*50];
int cnt, rt[N];
void modify(int &now, int pre, int ql, int qr, int l, int r, int val) {
now = ++cnt;
hjt[now] = hjt[pre];
hjt[now].sum += (min(qr, r) - max(ql, l) + 1) * val;
if (ql <= l && qr >= r) {
hjt[now].tag += val;
return;
}
int mid = (l + r) >> 1;
if (ql <= mid) modify(hjt[now].ls, hjt[pre].ls, ql, qr, l, mid, val);
if (qr > mid) modify(hjt[now].rs, hjt[pre].rs, ql, qr, mid+1, r, val);
}
int query(int now, int pre, int ql, int qr, int l, int r) {
if (ql <= l && qr >= r) return hjt[now].sum - hjt[pre].sum;
int mid = (l + r) >> 1;
int ans = (min(qr, r) - max(ql, l) + 1) * (hjt[now].tag - hjt[pre].tag);
if (ql <= mid) ans += query(hjt[now].ls, hjt[pre].ls, ql, qr, l, mid);
if (qr > mid) ans += query(hjt[now].rs, hjt[pre].rs, ql, qr, mid+1, r);
return ans;
}
vector<int> G1[N], G2[N];
int n, in[N], out[N], tot;
void dfs1(int x, int fa) {
in[x] = ++tot;
for (auto v : G2[x]) {
if (v == fa) continue;
dfs1(v, x);
}
out[x] = tot;
}
int stk[N], top, ans, pre[N];
void dfs2(int x, int fa) {
stk[++top] = x; rt[x] = rt[fa];
int l = pre[fa], r = top;
while (l <= r) {
int mid = (l + r) >> 1;
//cout << x << ' ' << stk[mid] << ' ' << query(rt[x], rt[stk[mid]], in[x], out[x], 1, n) << endl;
if (query(rt[x], rt[stk[mid]], in[x], out[x], 1, n)) l = mid + 1;
else r = mid - 1;
}
//cout << x << ' ' << top << ' ' << top - l << ' ' << pre[fa] << endl;
pre[x] = l;
ans = max(ans, top - l);
modify(rt[x], rt[x], in[x], out[x], 1, n, 1);
for (auto v : G1[x]) {
if (v == fa) continue;
dfs2(v, x);
}
top--;
}
void init() {
for (int i = 1; i <= n; i++) {
G1[i].clear();
G2[i].clear();
in[i] = out[i] = 0;
rt[i] = 0;
}
for (int i = 1; i <= cnt; i++) hjt[i].ls = hjt[i].rs = hjt[i].sum = hjt[i].tag = 0;
cnt = tot = top = 0;
ans = 1;
}
inline void solve() {
int T; cin >> T; while (T--) {
cin >> n;
init();
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
G1[u].push_back(v);
G1[v].push_back(u);
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
G2[u].push_back(v);
G2[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0);
cout << ans << endl;
}
}
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
signed test_index_for_debug = 1;
char acm_local_for_debug = 0;
do {
if (acm_local_for_debug == '$') exit(0);
if (test_index_for_debug > 20)
throw runtime_error("Check the stdin!!!");
auto start_clock_for_debug = clock();
solve();
auto end_clock_for_debug = clock();
cout << "Test " << test_index_for_debug << " successful" << endl;
cerr << "Test " << test_index_for_debug++ << " Run Time: "
<< double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
cout << "--------------------------------------------------" << endl;
} while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
solve();
#endif
return 0;
}