题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3222
这题写得我巨tm难受。很自然的思路就是把两个字符串的哈希值求出比较,如果不同,再二分位置。但是真的很难写。
先定义一个哈希函数 p p 和模数:
const int p=1e9+7;
const int md=998244353;
首先是字符串解析,我们考虑递归地做,遇到‘[’就去找对应的‘]’,然后递归地进入,否则直接读出串。这样我们就可以建出两棵树,这部分代码如下:
void work(char *s,int cur,int be,int en)
{
cnt++;
G[cur].push_back(cnt);
if(s[be]=='[')
{
int sum=0;
int i=be;
for(;i<=en;i++)
{
if(s[i]=='[')sum++;else if(s[i]==']')sum--;
if(!sum)break;
}
int k=i+1;
long long num=0;
while(k<=en && s[k]>='0' && s[k]<='9')
{
num=num*10LL+(s[k]-'0');
k++;
}
tme[cnt]=num;
work(s,cnt,be+1,i-1);
if(k<=en)work(s,cur,k,en);
}
else
{
int i=be;
val[cnt].l=1;
val[cnt].x=0;
ln[cnt]=0;
str[cnt]="";
while(i<=en && s[i]!='[')
{
val[cnt].x=(1LL*val[cnt].x*p%md+s[i])%md;
val[cnt].l=1LL*val[cnt].l*p%md;
ln[cnt]++;
str[cnt]+=s[i];
i++;
}
if(i<=en)work(s,cur,i,en);
}
}
建完树以后我们对于每一个节点算出它们的长度和哈希值,这会用到快速幂,我是用了一个struct表示一个串并重载了乘法运算,写起来很自然:
struct st{
int x,l;
st operator*(st u)const
{
return (st){(1LL*x*u.l%md+u.x)%md,1LL*l*u.l%md};
}
};
st qmul(st k,long long pw)//快速幂
{
if(pw==0)return (st){0,1};
if(pw==1)return k;
st res=qmul(k,pw>>1LL);
res=res*res;
if(pw&1LL)res=res*k;
return res;
}
void dfs(int x)//计算每个节点的长度及哈希值
{
if(!G[x].size())return;
ln[x]=0;
val[x]=(st){0,1};
for(int i=0;i<G[x].size();i++)
{
dfs(G[x][i]);
ln[x]+=ln[G[x][i]];
val[x]=val[x]*val[G[x][i]];
}
ln[x]*=tme[x];
val[x]=qmul(val[x],tme[x]);
}
最后就是二分长度加求出哈希值比较了,求出一定长度的哈希值的代码:
st solve(int x,long long pw)
{
if(!pw)return (st){0,1};
if(pw>=ln[x])return val[x];
if(!G[x].size())
{
int v1=0,v2=1;
for(int i=0;i<pw;i++)
{
v1=(1LL*v1*p%md+str[x][i])%md;
v2=1LL*v2*p%md;
}
return (st){v1,v2};
}
st res=(st){0,1};
long long L=0;
for(int i=0;i<G[x].size();i++)L+=ln[G[x][i]],res=res*val[G[x][i]];
res=qmul(res,pw/L);
pw%=L;
for(int i=0;i<G[x].size();i++)
{
if(pw>=ln[G[x][i]])
{
res=res*val[G[x][i]];
pw-=ln[G[x][i]];
}
else
{
res=res*solve(G[x][i],pw);
break;
}
}
return res;
}
现在大家应该都懂怎么做了吧,那么最重要的当然是算出总的复杂度喽,设 |s1|,|s2| | s 1 | , | s 2 | 为实际串长,则二分的复杂度为 Θ(log2max(|s1|,|s2|)) Θ ( l o g 2 m a x ( | s 1 | , | s 2 | ) ) ,很容易证明实际字符串长在 long long l o n g l o n g 范围内,所以大约为 64 64 ,快速幂套在二分的 log l o g 内,为 Θ(log2(|s1|+|s2|)) Θ ( l o g 2 ( | s 1 | + | s 2 | ) ) ,也当做 64 64 ,所以总的复杂度约为 Θ(642⋅2⋅20⋅T) Θ ( 64 2 · 2 · 20 · T ) 。实际跑得飞快,只需 0ms 0 m s 。
全部代码如下:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
using namespace std;
const int p=1e9+7;
const int md=998244353;
struct st{
int x,l;
st operator*(st u)const
{
return (st){(1LL*x*u.l%md+u.x)%md,1LL*l*u.l%md};
}
};
int t;
char s1[21],s2[21];
long long tme[41];
long long ln[41];
st val[41];
vector<int>G[41];
string str[41];
int cnt;
st qmul(st k,long long pw)
{
if(pw==0)return (st){0,1};
if(pw==1)return k;
st res=qmul(k,pw>>1LL);
res=res*res;
if(pw&1LL)res=res*k;
return res;
}
void work(char *s,int cur,int be,int en)
{
cnt++;
G[cur].push_back(cnt);
if(s[be]=='[')
{
int sum=0;
int i=be;
for(;i<=en;i++)
{
if(s[i]=='[')sum++;else if(s[i]==']')sum--;
if(!sum)break;
}
int k=i+1;
long long num=0;
while(k<=en && s[k]>='0' && s[k]<='9')
{
num=num*10LL+(s[k]-'0');
k++;
}
tme[cnt]=num;
work(s,cnt,be+1,i-1);
if(k<=en)work(s,cur,k,en);
}
else
{
int i=be;
val[cnt].l=1;
val[cnt].x=0;
ln[cnt]=0;
str[cnt]="";
while(i<=en && s[i]!='[')
{
val[cnt].x=(1LL*val[cnt].x*p%md+s[i])%md;
val[cnt].l=1LL*val[cnt].l*p%md;
ln[cnt]++;
str[cnt]+=s[i];
i++;
}
if(i<=en)work(s,cur,i,en);
}
}
void dfs(int x)
{
if(!G[x].size())return;
ln[x]=0;
val[x]=(st){0,1};
for(int i=0;i<G[x].size();i++)
{
dfs(G[x][i]);
ln[x]+=ln[G[x][i]];
val[x]=val[x]*val[G[x][i]];
}
ln[x]*=tme[x];
val[x]=qmul(val[x],tme[x]);
}
st solve(int x,long long pw)
{
if(!pw)return (st){0,1};
if(pw>=ln[x])return val[x];
if(!G[x].size())
{
int v1=0,v2=1;
for(int i=0;i<pw;i++)
{
v1=(1LL*v1*p%md+str[x][i])%md;
v2=1LL*v2*p%md;
}
return (st){v1,v2};
}
st res=(st){0,1};
long long L=0;
for(int i=0;i<G[x].size();i++)L+=ln[G[x][i]],res=res*val[G[x][i]];
res=qmul(res,pw/L);
pw%=L;
for(int i=0;i<G[x].size();i++)
{
if(pw>=ln[G[x][i]])
{
res=res*val[G[x][i]];
pw-=ln[G[x][i]];
}
else
{
res=res*solve(G[x][i],pw);
break;
}
}
return res;
}
int main()
{
scanf("%d",&t);
for(int d=1;d<=t;d++)
{
scanf("%s%s",s1,s2);
int l1=strlen(s1),l2=strlen(s2);
cnt=1;
int rt1=cnt;
tme[rt1]=1;
work(s1,cnt,0,l1-1);
dfs(rt1);
cnt++;
int rt2=cnt;
tme[rt2]=1;
work(s2,cnt,0,l2-1);
dfs(rt2);
long long l=0,r=1e18;
while(l<=r)
{
long long mid=(l+r)>>1LL;
if(solve(rt1,mid).x!=solve(rt2,mid).x)r=mid-1;else l=mid+1;
}
printf("Case #%d: ",d);
if(r==1e18)puts("YES");
else
{
printf("NO %lld\n",r+1);
}
for(int i=1;i<=cnt;i++)G[i].clear();
}
return 0;
}