bzoj4231 回忆树(AC自动机+fail树+KMP(+树状数组))

83 篇文章 0 订阅

bzoj4231 回忆树

原题地址http://www.lydsy.com/JudgeOnline/problem.php?id=3881

题意:
回忆树是树。
具体来说,是n个点n-1条边的无向连通图,点标号为1~n,每条边上有一个字符(出于简化目的,我们认为只有小写字母)。
对一棵回忆树来说,回忆当然是少不了的。
一次回忆是这样的:你想起过往,触及心底…唔,不对,我们要说题目。
这题中我们认为回忆是这样的:给定2个点u,v(u可能等于v)和一个非空字符串s,问从u到v的简单路径上的所有边按照到u的距离从小到大的顺序排列后,边上的字符依次拼接形成的字符串中给定的串s出现了多少次。

数据范围
n<=100000,m<=100000,询问串的总长<=300000

题解:
好题…
要求字符串在一条链上匹配多少次,对回忆树建树是没有办法的,不能像bzoj3926那样每个叶子提出来建一棵Trie。
于是这道题是离线,对要查询的串(的正串和反串)建AC自动机。

同样,A包含串B多少次,就是A在AC自动机上的每个节点,有多少在B结尾节点的fail树子树中。
于是,DFS原树的同时,在AC自动机上匹配,

由于链是要拐弯的,同时还有方向,这个不好处理,于是把一条链拆成三部分:
拐弯处:长度为2|T|,用KMP暴力匹配
剩下的两段一个是正着匹配,一个是倒着,于是询问串需要把正反串都插入AC自动机,

于是询问都变成了从根到某个点路径的一部分(直了…),就可以一边dfs一边处理了。

就如同天天爱跑步的处理方式,在对应的起始端点push进这个询问,以及是需要加还是减,查询的是fail树的哪个子树。

然后在DFS原树同时在AC自动机上转移时,进入这个点把AC自动机上对应点+1,离开时-1,
询问就是查询fail树子树权值和。

代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=110005;
const int M=300015;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
namespace AC
{
    int ch[2*M][26],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M];
    int in[2*N],out[2*M],inc,C[2*M];
    inline void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;}
    inline void init(){root=1,tail=1,num=0;}
    inline int insert(int opt)
    {
        int len=strlen(str); int tmp=root;
        for(int j=0,i;j<len;j++)
        {
            if(opt==1) i=j; else i=len-j-1;
            int c=str[i]-'a';
            if(!ch[tmp][c]) ch[tmp][c]=++tail;
            tmp=ch[tmp][c];
        }
        return tmp;
    }
    inline void getfail()
    {
        for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root;
        while(!Q.empty())
        {
            int top=Q.front(); Q.pop();
            for(int i=0;i<26;i++)
            {
                if(!ch[top][i]) ch[top][i]=ch[fail[top]][i];
                else
                {
                    int u=ch[top][i];
                    fail[u]=ch[fail[top]][i];
                    build(fail[u],u);
                    Q.push(u);
                }
            }
        }       
    }
    inline void dfs(int u)
    {
        inc++; in[u]=inc;
        for(int i=head[u];i;i=nxt[i]) dfs(to[i]);
        out[u]=inc;
    }
    inline void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;}
    inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;}
}

int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],fa[N],dep[N],_w[N],size[N],son[N],top[N],dfn=0,seq[N],loc[N];
int nx[M],s[M],t[M],pos[N][2];
inline void build(int u,int v,int ww)
{
    num++;
    w[num]=ww;
    to[num]=v;
    nxt[num]=head[u];
    head[u]=num;
}
inline void dfs(int u,int f)
{
    dep[u]=dep[f]+1; fa[u]=f; size[u]=1;
    for(int i=head[u];i;i=nxt[i]) 
    {
        if(to[i]==f) continue;
        _w[to[i]]=w[i];
        dfs(to[i],u);
        size[u]+=size[to[i]];
        if(size[son[u]]<size[to[i]]) son[u]=to[i];
    }
}
inline void dfs1(int u,int f,int tp)
{
    loc[u]=++dfn,seq[dfn]=u; top[u]=tp;
    if(son[u]) dfs1(son[u],u,tp);
    for(int i=head[u];i;i=nxt[i])
    {
        if(to[i]==f||to[i]==son[u]) continue;
        dfs1(to[i],u,to[i]);
    }
}
inline int getlca(int u,int v)
{
    while(top[u]!=top[v]) 
    {
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    return dep[u]<dep[v]?u:v;
}
inline int getpoint(int u,int d)
{
    while(dep[u]-dep[top[u]]<d){d-=(dep[u]-dep[top[u]]+1),u=fa[top[u]];}
    return seq[loc[u]-d];
}
inline void kmp(int u,int v,int lca,int id)
{
    int lent=strlen(str),lens=0;
    int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
    int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));   
    lens=dep[x]-dep[lca]+dep[y]-dep[lca];
    int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=fa[tmp];
    tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=fa[tmp],i++;
    for(int i=0;i<lent;i++) t[i]=str[i]-'a';
    nx[0]=-1; i=0,j=-1;
    while(i<lent) 
    {
        if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;}
        else j=nx[j];
    }
    i=0,j=0; int ret=0;
    while(i<lens)
    {
        if(j==-1||s[i]==t[j]) 
        {
            i++; j++;
            if(j==lent){ret++; j=nx[j];}
        }
        else j=nx[j];
    }
    pos[id][0]=AC::insert(1); pos[id][1]=AC::insert(-1);
    ans[id]=ret;
    if(u!=x)
    {
        ID[x].push_back(-id); ID[u].push_back(id);
        V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
    }
    if(v!=y)
    {
        ID[y].push_back(-id); ID[v].push_back(id);
        V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
    }
}
inline void dfs2(int u,int f,int x)
{
    AC::add(AC::in[x],1); 
    int sz=V[u].size(); 
    for(int i=0;i<sz;i++)
    {

        int ret=AC::query(AC::out[V[u][i]])-AC::query(AC::in[V[u][i]]-1);
        if(ID[u][i]>0) ans[ID[u][i]]+=ret;
        else ans[-ID[u][i]]-=ret;
    }
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f) continue;
        dfs2(v,u,AC::ch[x][w[i]]);
    }
    AC::add(AC::in[x],-1);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
        build(u,v,str[0]-'a'); build(v,u,str[0]-'a');
    }
    dfs(1,1); dfs1(1,1,1); AC::init();
    for(int i=1;i<=m;i++)
    {
        int u,v; scanf("%d%d",&u,&v);  scanf("%s",str);

        if(u==v) continue; int lca=getlca(u,v); 
        kmp(u,v,lca,i);
    }
    AC::getfail();  
    AC::dfs(1);
    dfs2(1,1,1);
    for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
    return 0;
}

然后这是我原来写的倍增版本的代码(改正后),没有namespace套namespace这种鬼畜玩意儿:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
const int M=300015;
const int P=17;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
struct AC_
{
    int ch[2*M][26],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M];
    int in[2*N],out[2*M],inc,C[2*M];
    void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;}
    void init(){root=1,tail=1,num=0;}
    int insert(int opt)
    {
        int len=strlen(str); int tmp=root;
        for(int j=0,i;j<len;j++)
        {
            if(opt==1) i=j; else i=len-j-1;
            int c=str[i]-'a';
            if(!ch[tmp][c]) ch[tmp][c]=++tail;
            tmp=ch[tmp][c];
        }
        return tmp;
    }
    void getfail()
    {
        for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root;
        while(!Q.empty())
        {
            int top=Q.front(); Q.pop();
            for(int i=0;i<26;i++)
            {
                if(!ch[top][i]) ch[top][i]=ch[fail[top]][i];
                else
                {
                    int u=ch[top][i];
                    fail[u]=ch[fail[top]][i];
                    build(fail[u],u);
                    Q.push(u);
                }
            }
        }       
    }
    void dfs(int u)
    {
        inc++; in[u]=inc;
        for(int i=head[u];i;i=nxt[i]) dfs(to[i]);
        out[u]=inc;
    }
    void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;}
    inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;}
}AC;
int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],anc[N][P+3],dep[N],_w[N];
int nx[M],s[M],t[M],pos[N][2];
void build(int u,int v,int ww)
{
    num++;
    w[num]=ww;
    to[num]=v;
    nxt[num]=head[u];
    head[u]=num;
}
void dfs1(int u,int f)
{
    dep[u]=dep[f]+1; anc[u][0]=f;
    for(int i=1;i<P;i++) anc[u][i]=anc[anc[u][i-1]][i-1];
    for(int i=head[u];i;i=nxt[i]) if(to[i]!=f) _w[to[i]]=w[i],dfs1(to[i],u);
}
inline int getlca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    int d=dep[u]-dep[v];
    for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i];
    if(u==v) return u;
    for(int i=P-1;i>=0;i--)
    if(anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
    return anc[u][0];
}
inline int getpoint(int u,int d) {for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i]; return u;}
void kmp(int u,int v,int lca,int id)
{
    int lent=strlen(str),lens=0;
    int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
    int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));
    lens=dep[x]-dep[lca]+dep[y]-dep[lca];
    int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=anc[tmp][0];
    tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=anc[tmp][0],i++;
    for(int i=0;i<lent;i++) t[i]=str[i]-'a';
    nx[0]=-1; i=0,j=-1;
    while(i<lent) 
    {
        if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;}
        else j=nx[j];
    }
    i=0,j=0; int ret=0;
    while(i<lens)
    {
        if(j==-1||s[i]==t[j]) 
        {
            i++; j++;
            if(j==lent){ret++; j=nx[j];}
        }
        else j=nx[j];
    }
    pos[id][0]=AC.insert(1); pos[id][1]=AC.insert(-1);
    ans[id]=ret;
    if(u!=x)
    {
        ID[x].push_back(-id); ID[u].push_back(id);
        V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
    }
    if(v!=y)
    {
        ID[y].push_back(-id); ID[v].push_back(id);
        V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
    }
}
void dfs2(int u,int f,int x)
{
    AC.add(AC.in[x],1); 
    int sz=V[u].size(); 
    for(int i=0;i<sz;i++)
    {   
        int ret=AC.query(AC.out[V[u][i]])-AC.query(AC.in[V[u][i]]-1);
        if(ID[u][i]>0) ans[ID[u][i]]+=ret;
        else ans[-ID[u][i]]-=ret;
    }
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f) continue;
        dfs2(v,u,AC.ch[x][w[i]]);
    }
    AC.add(AC.in[x],-1);
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v; scanf("%d%d%s",&u,&v,str);
        build(u,v,str[0]-'a'); build(v,u,str[0]-'a');
    }
    dfs1(1,1); AC.init();
    for(int i=1;i<=m;i++)
    {
        int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
        if(u==v) continue; int lca=getlca(u,v); 
        kmp(u,v,lca,i);
    }
    AC.getfail();   
    AC.dfs(1);
    dfs2(1,1,1);
    for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
    return 0;
}

昨晚上脑子有点不清楚这里写图片描述

自己没写对拍就一直交,卡了一晚上评测,非常抱歉,感谢大家的不杀之恩。

最后查出来的错:
1.求lca时 swap(u,u);
2.kmp忘记 j=nxt[j];
3.自己nxt[N]nx[N]两个数组搞混了。

最初以为是倍增慢了,没有去查无限循环的错,改成链剖还是T,才想到是不是哪里写挂了,
长代码一定要静态查错+对拍。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
题目描述 有一个 $n$ 个点的棋盘,每个点上有一个数字 $a_i$,你需要从 $(1,1)$ 走到 $(n,n)$,每次只能往右或往下走,每个格子只能经过一次,路径上的数字和为 $S$。定义一个点 $(x,y)$ 的权值为 $a_x+a_y$,求所有满足条件的路径中,所有点的权值和的最小值。 输入格式 第一行一个整数 $n$。 接下来 $n$ 行,每行 $n$ 个整数,表示棋盘上每个点的数字。 输出格式 输出一个整数,表示所有满足条件的路径中,所有点的权值和的最小值。 数据范围 $1\leq n\leq 300$ 输入样例 3 1 2 3 4 5 6 7 8 9 输出样例 25 算法1 (形dp) $O(n^3)$ 我们可以先将所有点的权值求出来,然后将其看作是一个有权值的图,问题就转化为了在这个图中求从 $(1,1)$ 到 $(n,n)$ 的所有路径中,所有点的权值和的最小值。 我们可以使用形dp来解决这个问题,具体来说,我们可以将这个图看作是一棵,每个点的父节点是它的前驱或者后继,然后我们从根节点开始,依次向下遍历,对于每个节点,我们可以考虑它的两个儿子,如果它的两个儿子都被遍历过了,那么我们就可以计算出从它的左儿子到它的右儿子的路径中,所有点的权值和的最小值,然后再将这个值加上当前节点的权值,就可以得到从根节点到当前节点的路径中,所有点的权值和的最小值。 时间复杂度 形dp的时间复杂度是 $O(n^3)$。 C++ 代码 算法2 (动态规划) $O(n^3)$ 我们可以使用动态规划来解决这个问题,具体来说,我们可以定义 $f(i,j,s)$ 表示从 $(1,1)$ 到 $(i,j)$ 的所有路径中,所有点的权值和为 $s$ 的最小值,那么我们就可以得到如下的状态转移方程: $$ f(i,j,s)=\min\{f(i-1,j,s-a_{i,j}),f(i,j-1,s-a_{i,j})\} $$ 其中 $a_{i,j}$ 表示点 $(i,j)$ 的权值。 时间复杂度 动态规划的时间复杂度是 $O(n^3)$。 C++ 代码

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值