题目大意:
定义一个序列是好的:维护一个栈,一开始为空,遍历序列,如果当前元素值与栈顶元素相同,弹出栈顶元素,否则把这个元素入栈。当遍历完后栈为空,则它是好的。
给你一个n个元素的序列,求它有多少个非空子序列是好的。
解题思路:
DP版本:
设
d
p
[
i
]
dp[i]
dp[i]为以i为左端点满足条件的子序列个数。
对于每个位置
i
i
i,设以这个位置为起点往后遍历,第一次使得栈为空的位置为
n
x
t
[
i
]
nxt[i]
nxt[i],那么容易知道以i为左端点满足条件的子序列个数
d
p
[
i
]
=
1
+
d
p
[
n
x
t
[
i
]
+
1
]
dp[i]=1+dp[nxt[i]+1]
dp[i]=1+dp[nxt[i]+1]
那么问题转换成对每个
i
i
i,如何快速的求得这个
n
x
t
[
i
]
nxt[i]
nxt[i]。
可以知道,
(
i
,
n
x
t
[
i
]
)
(i,nxt[i])
(i,nxt[i])这段子序列,去掉i和nxt[i]之后是一个好序列(假设空序列也是好的)。设
N
X
T
[
i
]
[
x
]
NXT[i][x]
NXT[i][x]表示以i为左端点,最右边为x,且除去最右边这个x之后为好序列的序列的最小右端点,那么对于i这个位置,如果
N
X
T
[
i
+
1
]
[
a
[
i
]
]
NXT[i+1][a[i]]
NXT[i+1][a[i]]存在,那么
n
x
t
[
i
]
=
N
X
T
[
i
+
1
]
[
a
[
i
]
]
nxt[i] = NXT[i+1][a[i]]
nxt[i]=NXT[i+1][a[i]]。NXT[I]可以是一个map或者啥的数据结构。
现在我们又需要快速的求得这个NXT[i]。可以发现,如果i存在nxt[i],那么它可以继承nxt[i]+1的NXT(因为好序列+好序列=好序列),所以一路求nxt[i]并继承下来就可以了。注意用swap而不是=来继承(据说swap是O(1)的)总复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 3e5 + 50;
ll dp[maxn];
map<int,int> NXT[maxn];
int nxt[maxn];
int n;
int a[maxn];
void init(){
scanf("%d", &n);
for(int i = 1; i <= n; ++i) scanf("%d", &a[i]), nxt[i] = -1, NXT[i].clear();
NXT[n+1].clear();
}
void sol(){
for(int i = n; i > 0; --i){
if(NXT[i+1].find(a[i]) != NXT[i+1].end()){
int pos = NXT[i+1][a[i]];
nxt[i] = pos;
swap(NXT[i], NXT[pos+1]);
}
NXT[i][a[i]] = i;
}
ll ans = 0;
dp[n+1] = dp[n] = 0;
for(int i = n-1; i > 0; --i){
//cout<<"i:"<<i<<" nxt:"<<nxt[i]<<endl;
if(nxt[i] != -1) dp[i] = 1 + dp[nxt[i]+1], ans += dp[i];
else dp[i] = 0;
}
printf("%I64d\n", ans);
}
int main()
{
int T; cin>>T;
while(T--){
init(); sol();
}
}
分治版本:
对于这种统计连续段的信息的问题,都可以往分治想一想,是一个好的尝试。分治后我们需要统计每个区间跨越中线的符合条件的区间。
我们在右半边从左到右模拟栈操作,从左半边从右到左模拟栈操作。可以发现,如果左半边的[i,mid]要和右半边的[mid,j]能组成一个好序列,当且仅当它们产生的栈的内容是相同的。可以利用hash和map来存储右半边的信息,然后遍历左半边来统计信息。复杂度
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ull unsigned long long
using namespace std;
const int maxn = 3e5 + 50;
const ull sed = 1e9 + 7;
int a[maxn], n;
int s[maxn], tp;
ull ha[maxn];
ll ans;
void init(){
scanf("%d", &n); for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
}
map<ll,int> mp;
void sol(int l, int r){
if(l >= r) return;
sol(l, mid); sol(mid+1, r);
mp.clear();
tp = 0; ha[0] = 0;
for(int i = mid+1; i <= r; ++i){
if(tp && s[tp] == a[i]) tp--;
else s[++tp] = a[i], ha[tp] = ha[tp-1]*sed + a[i];
mp[ha[tp]]++;
}
tp = 0; ha[0] = 0;
for(int i = mid; i >= l; --i){
if(tp && s[tp] == a[i]) tp--;
else s[++tp] = a[i], ha[tp] = ha[tp-1]*sed + a[i];
ans += mp[ha[tp]];
}return;
}
int main()
{
int T; cin>>T;
while(T--){
init();
ans = 0;
sol(1, n);
printf("%I64d\n", ans);
}
}