@description@
给定一个矩阵。求它的所有子矩阵中本质不同的行的个数之和。
input
第一行,两个正整数 n, m。
第二行,n * m 个正整数,第 i 个数表示 A[i/m][i mod m]。
保证 n * m <= 10^5, 1 <= A[i][j] <= 10^9
output
输出一个非负整数表示答案。
sample input
2 2
1 1 1 2
sample output
11
@solution@
假如我们枚举矩阵的左右边界,从上往下扫描行。
假如第 i 行上一个与它相同的行在第 j 行,则它对答案的贡献,即只考虑它这一行(因为包含其他与第 i 行相同的行已经被统计过了)的子矩阵数量,等于它到下边界的距离*它到第 j 行的距离。
我们可以只枚举左边界,再把每一行插入 trie 里面。这样我们就可以不用特意去枚举右边界(因为插入进 trie 的时候就可以顺便统计出每一列作为右边界的贡献),就可以省去繁杂的字符串比较匹配,简化时间复杂度。
其实是因为右边界移动时有些之前的信息可以被保留下来。
再细细品味,可以发现左边界移动时有些信息也可以保留下来。
具体的操作而言,可以是先固定左边界在第一列,往右移动时将根的所有子树合并成一棵 trie,同时动态维护出答案。
具体到算法细节,我们在每个结点中维护一个 set 表示包含这个结点所表示的字符串的行集合,再维护一个 val 表示这个结点对答案的贡献。
如果向右移动左边界,先减去根的所有儿子对答案的贡献 val,然后随便选中根的某一个子树,将其他的子树向它合并。
如果两个子树 A 要向 B 合并,首先要将 A, B 根结点合并成一个结点。对于 A 根结点的某一个儿子,如果 B 没有则 B 根结点接指针到这个儿子;否则再递归合并 A, B 的这一棵子树。
如果两个结点 p 和 q 合并,其实最主要的是 p 和 q 的 set 合并,我们采用启发式合并的方法(小的往大的合)。枚举 p 中的 set 中的每一个行,将这个行插入 q 中的 set,同时求出只包含这一行的子矩阵个数,即在它上面且离它最近的行到它的距离 * 在它下面且离它最近的行到它的距离。
时间复杂度看似很高,实际上总结点数 = 结点大小 = n*m,每次结点合并都会至少减少一个结点,每次子树合并实际上只有结点合并时才会遍历这个结点。而结点的合并只会合并同一深度的结点,同一深度的 set 大小之和刚好等于行数 n,又因为我们采用的是启发式合并,所以每个值最多被合并 log 次。加上 set 的维护是 log 级别的。
所以时间复杂度 O(nlog^2n)(这个 n 是矩阵大小 10^5)。
话说我感觉本题好像不需要子树的启发式合并……
@accepted code@
常数很大,本地测试过不了全部数据。可能是 STL 用得太猛了。
#include<set>
#include<map>
#include<cstdio>
#include<algorithm>
using namespace std;
struct node;
typedef set<int> Set;
typedef set<int>::iterator set_it;
typedef map<int, node*>::iterator map_it;
typedef long long ll;
const int MAXN = 500000;
Set pl1[MAXN + 5], *cnt1;
struct node{
map<int, node*>ch;
Set *s; ll val;
}pl2[MAXN + 5], *root, *cnt2, *nw;
ll nwtot; int n, m, x;
void init() {
cnt1 = &pl1[0], cnt2 = &pl2[0];
root = nw = cnt2;
nwtot = 0;
}
node *newnode() {
cnt2++, cnt2->s = (++cnt1), cnt2->s->insert(0), cnt2->s->insert(n+1);
return cnt2;
}
void insert(int id, int x) {
if( !nw->ch.count(x) ) nw->ch[x] = newnode();
nw = nw->ch[x];
set_it it1 = nw->s->lower_bound(id), it2 = it1; it1--;
ll del = 1LL*(id - (*it1))*((*it2) - id);
nw->val += del, nwtot += del;
nw->s->insert(id);
}
void node_merge(node *a, node *b) {
for(map_it it=a->ch.begin();it!=a->ch.end();it++) {
if( b->ch.count(it->first) ) {
node *tmp = b->ch[it->first];
if( tmp->s->size() < it->second->s->size() ) {
swap(tmp->s, it->second->s);
swap(tmp->val, it->second->val);
}
nwtot -= it->second->val;
for(set_it it2=it->second->s->begin();it2!=it->second->s->end();it2++) {
if( !(*it2) || (*it2) == n+1 ) continue;
set_it it3=tmp->s->lower_bound(*it2), it4 = it3; it3--;
ll del = 1LL*((*it2) - (*it3))*((*it4) - (*it2));
nwtot += del, tmp->val += del;
tmp->s->insert(*it2);
}
node_merge(it->second, tmp);
}
else b->ch[it->first] = it->second;
}
}
void trie_merge() {
node *rt = root->ch.begin()->second;
for(map_it it=root->ch.begin();it!=root->ch.end();it++) {
nwtot -= it->second->val;
if( it != root->ch.begin() )
node_merge(it->second, rt);
}
root = rt;
}
inline int read() {
int x = 0; char ch = getchar();
while( ch > '9' || ch < '0' ) ch = getchar();
while( '0' <= ch && ch <= '9' ) x = 10*x + ch-'0', ch = getchar();
return x;
}
int main() {
init(); n = read(), m = read();
for(int i=0;i<n*m;i++) {
if( i % m == 0 ) nw = root;
x = read(); insert(i/m + 1, x);
}
ll ans = nwtot;
for(int i=1;i<m;i++)
trie_merge(), ans += nwtot;
printf("%lld\n", ans);
}
@details@
trie 还能合并,是真的没想到。
话说我即使不加启发式合并也能跑得很快。随机化合并大法好啊。
STL 多起来的确很容易让人昏昏沉沉的,而且还不好调试。