Description
给出一个 $n \times m$ 的网格图,每个格子上有一个数,所有格子上的数形成一个 $n \times m$ 阶排列(即 $1 \sim n \times m$ 各出现恰好一次)。
求有多少个区间 $[l,r]$,使得数值在这个区间内的所有格子在网格图上(相邻的两个格子之间有一条边)构成一棵树。
$n \times m \leq 200000$
Solution
考虑把"构成一棵树"的限制拆为两部分:没有环,且连通块个数为 $1$。
先只考虑没有环的限制。显然,若 $[l,r]$ 对应的格子没有环,则其任意子区间对应的格子也没有环,所以当右端点固定时,合法的左端点是一段区间,并且右端点右移时最小的合法左端点单调不减。所以这部分可以用 LCT + 双指针维护。
森林的连通块个数可以表示为点数 $-$ 边数。设 $f_i$ 表示加入 $[i,r]$ 内的所有数后"点数 $-$ 边数"的值,则满足 $f_i=1$ 的 $i$ 即为答案。加入数 $r$ 时,对所有 $i \le r$ 令 $f_i$ 加 $1$;对 $r$ 与相邻格子中的数 $u$($u<r$)之间的每条边,对所有 $i \le u$ 令 $f_i$ 减 $1$。由于非空森林满足 $f_i \geq 1$,用线段树维护 $f_i$ 时只需支持区间加和区间求最小值 $\min$ 及其个数。
#include <bits/stdc++.h>
using namespace std;
// Reads the next non-negative decimal integer from `in` (stdin by default).
// Any non-digit characters before the number are skipped, so a leading '-'
// is ignored (negative numbers are not supported).
// Returns 0 if end-of-file is reached before any digit. The character is kept
// in an int so EOF is distinguishable from every real byte; a plain char
// made the skip loop spin forever on truncated input.
inline int gi(FILE* in = stdin)
{
    int c = getc(in);
    while (c != EOF && (c < '0' || c > '9')) c = getc(in);
    int sum = 0;
    while ('0' <= c && c <= '9') sum = sum * 10 + (c - '0'), c = getc(in);
    return sum;
}
const int maxn = 200005;
// Offsets {dx, dy} of the four orthogonally adjacent cells.
const int d[4][2] = {
    {0, 1}, {0, -1}, {1, 0}, {-1, 0}
};
int n, m, N;
// Arrays indexed by cell value (1..N):
//   f, ch, rev : link-cut tree parent / splay children / pending-reversal flag;
//   x, y       : row and column of the cell holding that value.
int f[maxn], ch[maxn][2], rev[maxn], x[maxn], y[maxn];
// p[i][j] = value stored in cell (i, j) for 1 <= i <= n, 1 <= j <= m.
// Rows 0 and n+1 and columns 0 and m+1 form a zero border, so neighbour
// lookups never need bounds checks (0 is never a valid value).
// Only n * m <= 200000 is guaranteed, so either side may be as large as
// 200000. A fixed [1005][1005] array overflows for e.g. 1 x 200000. The cells
// are therefore stored flat with row stride m + 2.
// Since n + m <= n * m + 1, (n + 2)(m + 2) <= 3 * n * m + 6 <= 3 * maxn.
struct Grid
{
    int cells[3 * maxn];
    int* operator[](int row) { return cells + row * (m + 2); }
} p;
#define get(x) (ch[f[x]][1] == x)
#define is_root(x) (ch[f[x]][0] != x && ch[f[x]][1] != x)
inline void rotate(int x)
{
int fa = f[x], gfa = f[fa], k = get(x);
ch[fa][k] = ch[x][k ^ 1]; f[ch[x][k ^ 1]] = fa;
ch[x][k ^ 1] = fa;
if (!is_root(fa)) ch[gfa][get(fa)] = x;
f[fa] = x; f[x] = gfa;
}
#define push_rev(x) (swap(ch[x][0], ch[x][1]), rev[x] ^= 1)
// Propagates a pending reversal of x to its children. push_rev swaps a node's
// children eagerly when it is marked, so x's own children are already in the
// right order; only their subtrees still need flipping.
inline void pushdown(int x)
{
    if (rev[x]) {
        for (int side = 0; side < 2; ++side)
            if (ch[x][side]) push_rev(ch[x][side]);
        rev[x] = 0;
    }
}
inline void splay(int x)
{
static int stk[maxn], top, fa;
stk[top = 1] = x;
while (!is_root(x)) stk[++top] = x = f[x];
while (top) pushdown(stk[top--]);
x = stk[1];
while (!is_root(x)) {
fa = f[x];
if (!is_root(fa))
get(x) ^ get(fa) ? rotate(x) : rotate(fa);
rotate(x);
}
}
// Makes the path from the represented root to x preferred. Each splay tree met
// while climbing path-parent pointers gets the previous one as its right child.
inline void access(int x)
{
    int prev = 0;
    while (x) {
        splay(x);
        ch[x][1] = prev;
        prev = x;
        x = f[x];
    }
}
// Re-roots x's represented tree at x.
inline void make_root(int x)
{
    access(x);    // x's splay tree now holds the root-to-x path
    splay(x);     // x becomes the root of that splay tree
    push_rev(x);  // reverse the path so x becomes the shallowest node
}
// Adds the edge (x, y); x and y are expected to be in different trees.
inline void link(int x, int y)
{
    make_root(x);
    make_root(y);
    f[y] = x;  // y now roots its own tree: attach it below x via a path-parent pointer
}
// Removes the edge (x, y); assumes (x, y) is an existing tree edge.
inline void cut(int x, int y)
{
    make_root(x);
    access(y);
    splay(y);
    // With x as root and y adjacent to it, x is y's entire left subtree.
    f[x] = 0;
    ch[y][0] = 0;
}
inline int find(int x) {access(x); splay(x); while (ch[x][0]) x = ch[x][0]; splay(x); return x;}
// Segment tree over left endpoints i in [1, N]. For the current right end r it
// stores f_i = (#cells) - (#edges) among the cells whose values lie in [i, r].
// Tags are never pushed down: tag[s] is a pending add for s's whole range.
// Min[s] already includes tag[s], but not the tags of s's ancestors.
//   Min[s] : minimum over s's range;   sum[s] : number of positions attaining it.
int Min[maxn << 2], sum[maxn << 2], tag[maxn << 2];
#define mid ((l + r) >> 1)
#define lch (s << 1)
#define rch (s << 1 | 1)
// (Re)initializes node s covering [l, r] so that every f_i is 0.
void build(int s, int l, int r)
{
    tag[s] = 0;
    if (l == r) return Min[s] = 0, sum[s] = 1, void();
    build(lch, l, mid);
    build(rch, mid + 1, r);
    // Pull the minimum and its multiplicity up from the children. Min[s] was
    // never assigned here before; it only worked because globals start at 0.
    Min[s] = min(Min[lch], Min[rch]);
    sum[s] = (Min[lch] == Min[s] ? sum[lch] : 0) + (Min[rch] == Min[s] ? sum[rch] : 0);
}
// Adds v to every f_i with i in [ql, qr], restricted to node s's range [l, r].
void insert(int s, int l, int r, int ql, int qr, int v)
{
    if (ql <= l && r <= qr) {
        // Fully covered: keep the add as this node's permanent tag.
        tag[s] += v;
        Min[s] += v;
        return;
    }
    if (ql <= mid) insert(lch, l, mid, ql, qr, v);
    if (mid < qr) insert(rch, mid + 1, r, ql, qr, v);
    const int lo = min(Min[lch], Min[rch]);
    int cnt = 0;
    if (Min[lch] == lo) cnt += sum[lch];
    if (Min[rch] == lo) cnt += sum[rch];
    sum[s] = cnt;
    Min[s] = lo + tag[s];
}
// Merges two (minimum, count) summaries: the smaller minimum wins, and counts
// are added when the minima tie.
pair<int, int> operator + (pair<int, int> a, pair<int, int> b)
{
    if (a.first < b.first) return a;
    if (b.first < a.first) return b;
    return make_pair(a.first, a.second + b.second);
}
// Returns (minimum f_i, number of i attaining it) over i in [ql, qr]
// within node s's range [l, r].
pair<int, int> query(int s, int l, int r, int ql, int qr)
{
    // Fully covered: the stored pair already includes tag[s].
    if (ql <= l && r <= qr) return make_pair(Min[s], sum[s]);
    const bool goLeft = ql <= mid;
    const bool goRight = mid < qr;
    pair<int, int> acc = goLeft ? query(lch, l, mid, ql, qr) : pair<int, int>(1e9, 0);
    if (goRight) acc = acc + query(rch, mid + 1, r, ql, qr);
    // Children's answers do not include this node's permanent tag.
    acc.first += tag[s];
    return acc;
}
// Returns 1 if adding the cell holding value r would close a cycle in the
// forest formed by the cells with values in [l, r). That happens iff two of
// its in-range neighbours are already connected.
int check(int l, int r)
{
    int nb[4], cnt = 0;
    for (int i = 0; i < 4; ++i) {
        const int v = p[x[r] + d[i][0]][y[r] + d[i][1]];
        if (l <= v && v < r) nb[cnt++] = v;
    }
    for (int i = 0; i < cnt; ++i)
        for (int j = i + 1; j < cnt; ++j)
            if (find(nb[i]) == find(nb[j])) return 1;
    return 0;
}
void add(int t, int v, int l, int r)
{
insert(1, 1, N, 1, t, v);
for (int x1, y1, i = 0; i < 4; ++i) {
x1 = x[t] + d[i][0]; y1 = y[t] + d[i][1];
if (p[x1][y1] < l || p[x1][y1] >= r) continue;
insert(1, 1, N, 1, min(t, p[x1][y1]), -v);
if (v > 0) assert(find(p[x1][y1]) != find(t)), link(p[x1][y1], t);
else assert(find(p[x1][y1]) == find(t)), cut(p[x1][y1], t);
}
}
// Reads the n x m grid from f.in and writes to f.out the number of value
// intervals [l, r] whose cells form a tree.
int main()
{
    // freopen closes the original stream even when opening fails. An unchecked
    // failure left stdin closed and the reader spinning; report it instead.
    if (!freopen("f.in", "r", stdin) || !freopen("f.out", "w", stdout)) {
        fprintf(stderr, "cannot open f.in / f.out\n");
        return 1;
    }
    n = gi(); m = gi(); N = n * m;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j) {
            const int v = gi();
            p[i][j] = v;
            x[v] = i;
            y[v] = j;
        }
    build(1, 1, N);
    long long ans = 0;
    // Two pointers: l is the smallest left end for which the cells valued
    // [l, r] are acyclic; it never decreases as r grows.
    for (int r = 1, l = 1; r <= N; ++r) {
        // Adding r would close a cycle: drop cells from the left until it would not.
        while (l < r && check(l, r)) {
            add(l, -1, l, r);
            ++l;
        }
        add(r, 1, l, r);
        // Every f_i in [l, r] is >= 1; count those equal to 1 (one component).
        const pair<int, int> res = query(1, 1, N, l, r);
        if (res.first == 1) ans += res.second;
    }
    printf("%lld\n", ans);
    return 0;
}