题目大意:有N个字符串,可以对字符串使用三种操作,增加一个字符,删除一个字符,替换一个字符,使其变成另一个字符串,如果这个字符串是所给的字符串的其中一个,且这个字符串再被操作字符串的后面的话就表示阶数加1,现给出N个按字典序排列的字符串,问最大的阶数是多少
解题思路:本来用暴力的,TLE了,去看了别人的题解,发现这题可以转换成一个有向无环图,只要找到这个有向无环图的最长边,就可以求出最大的阶数,参考了staginner大神的方法,由一个点向外扩张,先不建图,而是通过判断进行三次操作的其中一种后所变成的字符串是否存在,如果存在的话,两个点就可以连接起来了,就继续搜索下去,由于是从前往后扩张的且是按字典序排序的,所以后面的字符串要大于前面的字符串,再通过记忆化搜索的方法,就不会TLE了。
staginner的思路:点击打开链接
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 25010
#define HASH 1000003
#define INF 0x7fffffff
using namespace std;
char str[maxn][20],temp[20];
int N,head[HASH],next[maxn],f[maxn];
int hash(char* s){
int seed = 131,h = 0;
while(*s)
h = h * seed + *(s++);
return (h & INF ) % HASH;
}
void insert(int n) {
int h;
h = hash(str[n]);
next[n] = head[h];
head[h] = n;
}
void init() {
N = 0;
memset(head,-1,sizeof(head));
memset(f,-1,sizeof(f));
while(scanf("%s",str[N]) == 1)
insert(N++);
}
int search() {
int h = hash(temp),i;
for(i = head[h]; i != -1; i = next[i])
if(strcmp(str[i],temp) == 0)
break;
return i;
}
void add(int cur,int pos,int n) {
int i,j;
for(i = 0, j = 0; i < pos; i++,j++)
temp[j] = str[cur][i];
temp[j++] = 'a' + n;
for(; str[cur][i]; i++,j++)
temp[j] = str[cur][i];
temp[j] = '\0';
}
void del(int cur,int pos) {
int i , j;
for(i = 0,j = 0; i < pos; i++,j++)
temp[j] = str[cur][i];
i++;
for(;str[cur][i] ; i++,j++)
temp[j] = str[cur][i];
temp[j] = '\0';
}
void change(int cur,int pos,int n) {
strcpy(temp,str[cur]);
temp[pos] = 'a' + n;
}
int dp(int cur) {
if(f[cur] != -1)
return f[cur];
int len = strlen(str[cur]);
int s,t,max = 0;
for(int i = 0; i <= len; i++)
for(int j = 0; j < 26; j++) {
add(cur,i,j);
s = search();
if(s != -1 && strcmp(str[cur],temp) < 0) {
t = dp(s);
if(t + 1 > max)
max = t + 1;
}
}
for(int i = 0; i < len; i++) {
del(cur,i);
s = search();
if(s != -1 && strcmp(str[cur],temp) < 0) {
t = dp(s);
if(t + 1 > max)
max = t + 1;
}
}
for(int i = 0; i < len; i++)
for(int j = 0; j < 26; j++) {
change(cur,i,j);
s = search();
if(s != -1 && strcmp(str[cur],temp) < 0) {
t = dp(s);
if( t + 1 > max)
max = t + 1;
}
}
return f[cur] = max;
}
void solve() {
int ans = 0;
for(int i = 0; i < N; i++) {
ans = max(ans,dp(i));
}
printf("%d\n",ans + 1);
}
int main() {
init();
solve();
return 0;
}