题解 牛客 Explorer (线段树+dfs / LCT)
题目链接:https://ac.nowcoder.com/acm/contest/888/E
题意
给你一个 $N$ 个点 $M$ 条边的图,总共有编号为 $1$ 到 $10^9$ 的点想要从点 1 到达点 N,但是每条边只会允许编号在 $[l,r]$ 内的点经过,问最后会有多少点可以到达终点。
数据范围:$1 \le N, M \le 10^5$;$1 \le l, r \le 10^9$
思路
线段树
题目中 $l,r$ 的范围非常大,但是显然有用的端点只有 $2 \times 10^5$ 个,所以我们将其离散化,每个区间我们可以将其放在线段树上,分解为 $\log N$ 个小区间。
这时我们想要枚举每一个 $[l,l]$ 能否成功让 1 和 N 连通。
然后,我们发现,当我们在探索 $[1,1]$ 的状态时,可以将大部分的状态沿用到 $[2,2]$ 状态,所以,我们需要维护一个并查集来储存建边的信息。这个并查集还必须支持删边(撤销最近的合并),因此不能使用路径压缩,按秩合并是一个好的方法。
LCT(待补充)
代码
// 线段树
#include <bits/stdc++.h>
using namespace std;
#define rep(i,j,k) for(int i = (int)j;i <= (int)k;i ++)
#define debug(x) cerr<<#x<<":"<<x<<endl
#define pb push_back
typedef long long ll;
typedef pair<int,int> pi;
const int MAXN = (int)1e5+7;
int N,M;
vector<int> vp;
// Converts a raw coordinate x into a 1-based rank within vp: the position of
// the first element of vp that is not less than x, plus one. vp is expected to
// be sorted when this is called (main() sorts and deduplicates it first).
inline int get(int x) {
    const std::vector<int>::iterator pos = std::lower_bound(vp.begin(), vp.end(), x);
    return pos - vp.begin() + 1;
}
// One input edge: endpoints u and v plus the interval [l, r] read with it.
// main() later overwrites l and r with compressed (1-based) indices.
struct Node {
    int u, v;  // endpoints
    int l, r;  // interval bounds
    // All arguments default to 0 so that the global array below can be built.
    Node(int from = 0, int to = 0, int lo = 0, int hi = 0) {
        u = from;
        v = to;
        l = lo;
        r = hi;
    }
};
// Edge storage; main() fills slots 1..M.
Node e[MAXN];
stack<int> prefa;
stack<pi> presz;
int fa[MAXN],hei[MAXN];
// Returns the representative of the set containing x by following fa[] until
// a node that is its own parent is reached. No path compression is applied:
// fa[] is only ever modified by unite() and by the rollback in segment::dfs,
// which undoes a merge by resetting a single fa[] entry.
inline int findfa(int x) {
    while (fa[x] != x) {
        x = fa[x];
    }
    return x;
}
// Merges the sets containing x and y (union by height) and records the change
// so that it can be undone later:
//   - prefa receives the root whose parent pointer was redirected;
//   - presz receives (root, old height) whenever a height is increased.
// Nothing is recorded when x and y are already in the same set.
void unite(int x, int y) {
    int lower = findfa(x);
    int upper = findfa(y);
    if (lower == upper) return;

    // Always hang the shallower tree under the deeper one.
    if (hei[lower] > hei[upper]) std::swap(lower, upper);

    prefa.push(lower);
    fa[lower] = upper;

    // The height only grows when both trees were equally tall.
    if (hei[lower] != hei[upper]) return;
    presz.push(std::make_pair(upper, hei[upper]));
    ++hei[upper];
}
// True when x and y currently have the same DSU representative.
bool isConnect(int x, int y) {
    return findfa(x) == findfa(y);
}
struct segment {
#define lson rt<<1
#define rson rt<<1|1
vector<pi> all[MAXN<<3];
void Update(int u,int v,int L,int R,int l,int r,int rt) {
if (L <= l && r <= R) {
all[rt].pb(make_pair(u,v));
return ;
}
int m = l+r>>1;
if (L <= m) Update(u,v,L,R,l,m,lson);
if (R > m) Update(u,v,L,R,m+1,r,rson);
}
ll dfs(int l,int r,int rt) {
int sz1 = prefa.size();
int sz2 = presz.size();
int res = 0;
rep(i,0,all[rt].size()-1) {
int u = all[rt][i].first;
int v = all[rt][i].second;
unite(u,v);
}
if (isConnect(1,N)) {
res += vp[r]-vp[l-1];
}else if (l != r){
int m = l+r>>1;
res += dfs(l,m,lson);
res += dfs(m+1,r,rson);
}
while (prefa.size() > sz1) {
int u = prefa.top(); prefa.pop();
fa[u] = u;
}
while (presz.size() > sz2) {
int u = presz.top().first,szt = presz.top().second; presz.pop();
hei[u] = szt;
}
return res;
}
}seg;
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin >> N >> M;
rep(i,1,M) {
int u,v,l,r;
cin >> u >> v >> l >> r;
vp.pb(l);vp.pb(r+1);
e[i] = Node(u,v,l,r);
}
sort(vp.begin(),vp.end());vp.erase(unique(vp.begin(),vp.end()),vp.end());
int num = vp.size();
rep(i,1,M) {
e[i].l = get(e[i].l);
e[i].r = get(e[i].r+1)-1;
seg.Update(e[i].u,e[i].v,e[i].l,e[i].r,1,num,1);
}
rep(i,1,N) fa[i] = i,hei[i] = 1;
ll ans = seg.dfs(1,num,1);
cout << ans << endl;
}
/*
2 1
1 2 1 10
*/
// LCT
#include <bits/stdc++.h>
using namespace std;
#ifndef ONLINE_JUDGE
#define debug(fmt, ...) fprintf(stderr, "[%s] " fmt "\n", __func__, ##__VA_ARGS__)
#else
#define debug(...)
#endif
const int maxn = 1 << 20;
// One input edge with endpoints u, v and interval [l, r]. Kept as a plain
// aggregate so it can be brace-initialised as {u, v, l, r}.
struct Edge
{
    int u, v, l, r;

    // Orders edges by r in DESCENDING order: an edge is "less" when its r is
    // larger, so std::sort puts the largest r first.
    bool operator<(const Edge& rhs) const { return rhs.r < r; }
};
// Edge storage; main() fills slots 0..m-1.
Edge e[maxn];
const int INF = 0x3f3f3f3f;
struct LCT
{
int val[maxn];
pair<int, int> minv[maxn];
int rev[maxn], ch[maxn][2], fa[maxn];
int stk[maxn];
inline void init(int n)
{
val[0] = -INF - 1, minv[0] = {-INF - 1, 0};
for (int i = 1; i <= n; i++) val[i] = -INF, minv[i] = {-INF, i};
for (int i = 1; i <= n; i++) fa[i] = ch[i][0] = ch[i][1] = rev[i] = 0;
}
inline bool isroot(int x) { return ch[fa[x]][0] != x && ch[fa[x]][1] != x; }
inline bool get(int x) { return ch[fa[x]][1] == x; }
void pushdown(int x)
{
if (!rev[x]) return;
swap(ch[x][0], ch[x][1]);
if (ch[x][0]) rev[ch[x][0]] ^= 1;
if (ch[x][1]) rev[ch[x][1]] ^= 1;
rev[x] ^= 1;
}
void pushup(int x)
{
minv[x] = max({{val[x], x}, minv[ch[x][0]], minv[ch[x][1]]});
}
void rotate(int x)
{
int y = fa[x], z = fa[fa[x]], d = get(x);
if (!isroot(y)) ch[z][get(y)] = x;
fa[x] = z;
ch[y][d] = ch[x][d ^ 1], fa[ch[y][d]] = y;
ch[x][d ^ 1] = y, fa[y] = x;
pushup(y), pushup(x);
}
void splay(int x)
{
int top = 0;
stk[++top] = x;
for (int i = x; !isroot(i); i = fa[i]) stk[++top] = fa[i];
for (int i = top; i; i--) pushdown(stk[i]);
for (int f; !isroot(x); rotate(x))
if (!isroot(f = fa[x])) rotate(get(x) == get(f) ? f : x);
}
void access(int x)
{
for (int y = 0; x; y = x, x = fa[x]) splay(x), ch[x][1] = y, pushup(x);
}
int find(int x)
{
access(x), splay(x);
while (ch[x][0]) x = ch[x][0];
return x;
}
void makeroot(int x) { access(x), splay(x), rev[x] ^= 1; }
void link(int x, int y)
{
makeroot(x), fa[x] = y, splay(x);
}
void cut(int x, int y)
{
makeroot(x), access(y), splay(y), fa[x] = ch[y][0] = 0;
}
void update(int x, int v) { val[x] = v, access(x), splay(x); }
int query(int x, int y)
{
makeroot(y), access(x), splay(x);
return minv[x].second;
}
} lct;
int n, m;
int tot;
void addedge(int u, int v, int val)
{
if (lct.find(u) == lct.find(v))
{
int d = lct.query(u, v);
lct.val[++tot] = val;
lct.minv[tot] = {lct.val[tot], tot};
if (lct.val[d] <= val) return;
assert(d - n - 1 >= 0);
lct.cut(e[d - n - 1].u, d);
lct.cut(e[d - n - 1].v, d);
lct.link(u, tot);
lct.link(v, tot);
}
else
{
lct.val[++tot] = val;
lct.minv[tot] = {lct.val[tot], tot};
lct.link(u, tot);
lct.link(v, tot);
}
}
// Returns the weight of the heaviest node on the current path between vertex
// 1 and vertex n, or -1 when they are not connected. The parameter val is not
// used by the body.
int query(int val)
{
    const bool connected = lct.find(1) == lct.find(n);
    if (!connected) return -1;

    const int d = lct.query(1, n);
    assert(d > 0);
    return lct.val[d];
}
// Entry point of the LCT solution.
//
// Edges are sorted by r in descending order and inserted in groups sharing
// the same r, each with weight l. After a group is inserted, query() gives the
// largest l on the kept path between 1 and n; values in [lmax, r] are then
// added to the answer, and `last` (the lowest value already counted) prevents
// counting anything twice.
//
// Fix: the group loop used `i < m` and then read e[i + 1], so at i == m-1 it
// read the unused slot e[m]. That only worked because the slot is a
// zero-filled global and every real r is at least 1. The bound is now
// `i + 1 < m`.
int main()
{
    scanf("%d%d", &n, &m);
    lct.init(n);
    // Edge nodes are numbered n+1, n+2, ... in the order addedge() is called,
    // which matches the sorted order of e[] below.
    tot = n;
    for (int i = 0, u, v, l, r; i < m; i++)
    {
        scanf("%d%d%d%d", &u, &v, &l, &r);
        e[i] = {u, v, l, r};
    }
    sort(e, e + m);
    int ans = 0, last = -1;
    for (int i = 0, l, r; i < m; i++)
    {
        addedge(e[i].u, e[i].v, e[i].l);
        r = e[i].r, l = e[i].l;
        // Absorb every further edge with the same r; l tracks the smallest
        // left bound within the group.
        while (i + 1 < m && e[i].r == e[i + 1].r)
        {
            i++;
            addedge(e[i].u, e[i].v, e[i].l);
            l = min(l, e[i].l);
        }
        int lmax = query(r);
        if (lmax == -1) continue;  // 1 and n not connected yet
        lmax = max(l, lmax);
        if (lmax > r) continue;    // empty range
        if (last == -1)
        {
            // First non-empty range: count it in full.
            ans += r - lmax + 1;
            last = lmax;
            continue;
        }
        if (last <= lmax) continue;  // already covered by earlier ranges
        // Count only the part strictly below what was counted before.
        r = min(r, last - 1);
        ans += r - lmax + 1;
        last = lmax;
    }
    printf("%d\n", ans);
}