C. Code a Trie
大佬题解,代码基本就是抄的
对于每一个值计算所有串的LCA,也就是最长公共前缀,将该节点(Trie树的节点)标记,对于这些字符串在LCA下面的点一定不存在(如果存在他们不会返回相同的值)
每个Trie树中的节点只能被标记一次,并且从跟到LCA路径上的变必须存在
dfs贪心计算每个子树中最少的节点
插入时统计cnt[u]
表示它的子树中被标记为LCA的点的数量
- 如果cnt[u]>1,这个点必选,如果说该节点没被标记为LCA,那么它可以替代它一个儿子称为那个值的LCA,如果被标记为LCA,它的儿子被标记那就必须选。
- 如果cnt[u]=1,贪心选择该点,儿子不选
- 如果cnt[u]=0,贪心不选
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N=500010;
int n,a[N],b[N],cnt[N],ans;
string s[N];
vector<string> g[N];
int t[N][27],idx;
bool lca[N];
void init()
{
cin>>n;
int m=0;
for(int i=1;i<=n;i++) {
cin>>s[i]>>a[i];
m+=s[i].size();
}
idx=0;
for(int i=1;i<=n;i++) g[i].clear();
for(int i=0;i<=m;i++) memset(t[i],0,sizeof t[i]);
for(int i=0;i<=m;i++) lca[i]=0,cnt[i]=0;
}
bool check(vector<string> &v) {
//暴力寻找LCA
sort(v.begin(),v.end(),[&](string a, string b) {
return a.size()<b.size();
});
int len=0;
for(int i=0;i<v[0].size();i++) {
int ok=1;
for(int j=0;j<v.size()&&ok;j++)
if(v[j][i]!=v[0][i]) ok=0;
if(ok) len++;
else break;
}
// Trie树插入
int p=0;
for(int i=0;i<len;i++) {
cnt[p]++;//子树中的lca
int c=v[0][i]-'a';
if(t[p][c]==-1) return 0; //节点不存在
if(!t[p][c]) t[p][c]=++idx;
p=t[p][c];
}
if(lca[p]) return 0;
lca[p]=1;
cnt[p]++;
// lca 后面的一定不存在 打上标记
for(int i=0;i<v.size();i++) {
if(v[i].size()<=len) continue;
int c=v[i][len]-'a';
if(t[p][c]>0) return 0;// 不存在的点存在了
t[p][c]=-1;
}
return 1;
}
void dfs(int u) {
if(cnt[u]>1) ans++;
bool fl=lca[u]==0;
for(int i=0;i<26;i++) {
int v=t[u][i];
if(!v||v==-1) continue;
if(cnt[v]==1) {
if(!fl)
ans++;
else
fl=0;
}
else dfs(v);
}
}
void co(int cid, int x) {
cout << "Case #" << cid << ": " << x<<'\n';
}
void work(int cid) {
int m=0;
for(int i=1;i<=n;i++)
b[++m]=a[i];
// 离散化
sort(b+1,b+1+m);
m=unique(b+1,b+1+m)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
// 统计相同值的字符串
for(int i=1;i<=n;i++) g[a[i]].push_back(s[i]);
// 判断进行插入
for(int i=1;i<=m;i++)
if(!check(g[i])) {
co(cid, -1);
return;
}
ans=0;
cnt[0]++; //根节点必须选
dfs(0);
co(cid, ans);
}
int main() {
ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
int T=1;
cin>>T;
for(int i=1;i<=T;i++) {
init();
work(i);
}
return 0;
}