复读数组
先考虑
k
=
1
k=1
k=1的情况,最暴力的方法就是枚举所有区间,但是由于每个元素的贡献是独立的,我们单独考虑每个元素的贡献。
发现整个区间被相同的元素断成了若干个小区间,而这些小区间都是对答案没有贡献的,我们可以
拿
所
有
的
区
间
数
−
小
区
间
数
拿所有的区间数-小区间数
拿所有的区间数−小区间数得到一个元素的贡献,对于
k
≠
1
k\not=1
k=1的情况,考虑重复的出现次数即可,中间的各出现了
k
k
k,首尾相接的出现了
k
−
1
k-1
k−1次,还有一个首段和一个尾段。
时间复杂度
O
(
n
)
O(n)
O(n)
#include <cstdio>
#include <algorithm>
using namespace std;
#define int __int128
const int MAXN = 100005;
const int MOD = 1e9+7;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,k,ans,a[MAXN],b[MAXN];
int sum[MAXN],la[MAXN],fi[MAXN];
int get(int x)
{
x%=MOD;
return x*(x+1)/2%MOD;
}
signed main()
{
//freopen("a.in","r",std);
n=read();k=read();
for(int i=1;i<=n;i++)
a[i]=b[i]=read();
sort(b+1,b+1+n);
int len=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++)
{
a[i]=lower_bound(b+1,b+1+len,a[i])-b;
}
for(int i=1;i<=n;i++)
{
if(la[a[i]])
sum[a[i]]=(sum[a[i]]+k*get(i-la[a[i]]-1))%MOD;
else
fi[a[i]]=i;
la[a[i]]=i;
}
for(int i=1;i<=len;i++)
{
sum[i]=(sum[i]+(k-1)*get(n-la[i]+fi[i]-1))%MOD;
sum[i]=(sum[i]+get(fi[i]-1)+get(n-la[i]))%MOD;
ans=(ans+get(n*k)-sum[i])%MOD;
}
printf("%d\n",(ans%MOD+MOD)%MOD);
}
路径计数机
无脑枚举然后暴力检查,时间复杂度
O
(
n
5
)
O(n^5)
O(n5),但是实际时间复杂度要小得多,可以过掉40分。
考虑正解,我们先枚举
(
u
,
v
)
(u,v)
(u,v)在快速算出符合的
(
x
,
y
)
(x,y)
(x,y),难点在于判相交,从
l
c
a
lca
lca的角度考虑:
- 如果 l c a x , y lca_{x,y} lcax,y在 l c a u , v lca_{u,v} lcau,v的子树内,那么当且仅当 l c a x , y lca_{x,y} lcax,y不在 ( u , v ) (u,v) (u,v)的树上路径上满足条件。
- 如果 l c a x , y lca_{x,y} lcax,y在 l c a u , v lca_{u,v} lcau,v的子树外,那么 ( x , y ) (x,y) (x,y)的树上路径不经过 l c a u , v lca_{u,v} lcau,v的返祖边。
可以考虑那
全
部
−
不
满
足
条
件
的
全部-不满足条件的
全部−不满足条件的,算出不满足第一个条件的可以预处理每个点为
l
c
a
lca
lca有多少距离为
Q
Q
Q的点对,然后算出路径上所有点为
l
c
a
lca
lca的情况,用树上差分最后一次性算。
对于不满足第二个条件的,可以预处理算出路径经过每条边的点对有多少,往上往下强行硬搜,用经过边数为
j
j
j和
Q
−
j
−
1
Q-j-1
Q−j−1的情况相乘然后求和即可。
由于
n
≤
3000
n\leq 3000
n≤3000,可以用
t
a
r
j
a
n
tarjan
tarjan预处理每两个点的
l
c
a
lca
lca(这是我第一次写
t
a
r
j
a
n
tarjan
tarjan qwq)。
时间复杂度
O
(
n
2
)
O(n^2)
O(n2)。
#include <cstdio>
#define LL long long
const int MAXN = 3005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^'0'),c=getchar();
return x*flag;
}
int n,P,Q,tot,lca[MAXN][MAXN],t1[MAXN],t2[MAXN];
int f[MAXN],fa[MAXN],p[MAXN],dep[MAXN],c[MAXN];
int cnt,cnt1[MAXN],cnt2[MAXN];LL ans;
struct edge
{
int v,next;
}e[MAXN*2];
int findSet(int x)
{
if(x^p[x]) p[x]=findSet(p[x]);
return p[x];
}
void tarjan(int u,int par)
{
p[u]=u;
fa[u]=par;
dep[u]=dep[par]+1;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==par) continue;
tarjan(v,u);
p[v]=u;
}
for(int i=1;i<=n;i++)
if(p[i])
lca[u][i]=lca[i][u]=findSet(i);
}
int dist(int x,int y)
{
return dep[x]+dep[y]-2*dep[lca[x][y]];
}
void up(int u,int len,int from)
{
t1[len]++;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==from) continue;
up(v,len+1,u);
}
}
void down(int u,int len)
{
t2[len]++;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa[u]) continue;
down(v,len+1);
}
}
void dfs(int u)
{
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa[u]) continue;
dfs(v);
c[u]+=c[v];
}
ans-=c[u]*cnt1[u];
}
int main()
{
n=read();P=read();Q=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
e[++tot]=edge{v,f[u]},f[u]=tot;
e[++tot]=edge{u,f[v]},f[v]=tot;
}
tarjan(1,0);
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
if(dist(i,j)==Q)
{
cnt++;
cnt1[lca[i][j]]++;
}
for(int i=2;i<=n;i++)
{
for(int j=0;j<=Q;j++)
t1[j]=t2[j]=0;
up(fa[i],0,i);
down(i,0);
for(int j=0;j<Q;j++)
cnt2[i]+=t1[j]*t2[Q-j-1];
cnt2[i]*=2;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
if(dist(i,j)==P)
{
ans+=cnt-cnt2[lca[i][j]];
c[i]++;c[j]++;c[lca[i][j]]--;c[fa[lca[i][j]]]--;
}
dfs(1);
printf("%lld\n",ans);
}