《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
“ 前缀压缩技术” ,链接: http://oj.ecustacm.cn/problem.php?id=2140
题目描述
【题目描述】 现有一个网站有 n 个用户,每个用户名均为长度为 L 个字符串,存储所有用户名需要存储 n*L 个字符。
现在有一个压缩内存的方法:不必存储每个用户完整用户名,只需存储该用户名的前缀,只要没有其他用户名具有相同的前缀即可。
例如,如果只有 james 和 jacob 这两个名称,仅存储 jam 和 jac,并仍能够识别它们两个。
利用该压缩技术,存储 n 个用户名需要多少个字符?
【输入格式】 第一行为正整数 n,L,2≤n≤10000,1≤L≤1000。输入保证 n * L ≤ 1000000。
接下来 n 行,每行一个长度为 L 的字符串,仅包含小写字母,且所有字符串互不相同。
【输出格式】 输出一个整数表示答案。
【输入样例】
样例1:
2 5
james
jacob
样例2:
4 4
xxxx
yxxx
xxyx
yxxy
【输出样例】
样例1:
6
样例2:
14
题解
用两种方法求解:模拟、字典树。
【重点】 字典树。
C++代码-1
第一种方法是模拟。题目要求统计有多少个不同的前缀,步骤如下:
(1)对所有字符串排序;
(2)连续比较相邻的2个字符串的前缀,遇到不相同的字符就停止,记录一次前缀长度,然后继续比较下两个相邻的字符串。
#include<bits/stdc++.h>
using namespace std;
string s[10001];
int ans[10001];
int main(){
int n, L; cin >> n >> L;
for(int i = 0; i < n; i++) cin >> s[i];
sort(s, s + n); //字符串排序
int sum = 0;
for(int i = 0; i < n - 1; i++){ //连续比较字符串s[i] 和 s[i+1]
for(int j = 0; j < L; j++){
ans[i] = max(ans[i], j+1);
ans[i+1] = max(ans[i+1], j+1);
if(s[i][j] != s[i+1][j]) break; //第j个字符不同
}
sum += ans[i]; //统计第i个字符串的前缀:它的前j个字符和其他字符都不同
}
cout<<sum + ans[n-1];
return 0;
}
C++代码-2
下面用字典树编码(字典树模板,见《算法竞赛》清华大学出版社,罗勇军,郭卫斌著,562页)。
首先建一棵字典树tree[]。tree[u].son[v]表示编号为u的节点的下一个字母为v的节点的编号。
然后用Insert()把所有字符串插入到字典树中。从字典树根节点开始找这个字符串中的字符是否出现,如果出现过,这个前缀的数量num++。如果num=1,说明这个前缀只出现过1次,是独一无二的。
最后用Find()统计所有不同的前缀的总长度,也就是查询num=1的前缀,累加它们的长度。
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
struct node{
int son[26]; //26个字母
int num; //这个前缀出现的次数
}t[N]; //字典树tree[u].son[v]:编号为u的节点的下一个字母为v的节点的编号
int cnt = 1; //当前分配的存储位置。把cnt=0留给根节点
void Insert(char *s) { //往字典树中插入字符串s
int now = 0; //从根节点开始找插入前缀的位置
for(int i = 0; s[i]; i++) {
int ch = s[i]-'a';
if(t[now].son[ch]==0) //这个字符还没有存储过
t[now].son[ch] = cnt++; //把cnt位置分配给这个字符
now = t[now].son[ch]; //沿着字典树往下走
t[now].num++; //统计这个前缀出现多少次
}
}
int Find(int u, int len){
if(t[u].num == 1) return len; //这个前缀只出现了1次,说明是唯一的前缀
int ans = 0;
for(int i = 0; i < 26; i++) //遍历整个字典树,统计只出现一次的前缀
if(t[u].son[i])
ans += Find(t[u].son[i], len+1);
return ans;
}
char s[1010];
int main(){
int n, L; scanf("%d%d", &n, &L);
while(n--){
scanf("%s", s);
Insert(s);
}
cout<<Find(0, 0)<<endl; //从根节点0开始找,此时字符长度为0
return 0;
}
Java代码-1
模拟
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int L = scanner.nextInt();
String[] s = new String[n];
for (int i = 0; i < n; i++) s[i] = scanner.next();
Arrays.sort(s); // 字符串排序
int[] ans = new int[n];
int sum = 0;
for (int i = 0; i < n - 1; i++) { // 连续比较字符串s[i] 和 s[i+1]
for (int j = 0; j < L; j++) {
ans[i] = Math.max(ans[i], j+1);
ans[i+1] = Math.max(ans[i+1],j+1);
if (s[i].charAt(j) != s[i + 1].charAt(j)) break; // 第j个字符不同
}
sum += ans[i]; // 统计第i个字符串的前缀:它的前j个字符和其他字符都不同
}
System.out.println(sum + ans[n - 1]);
}
}
Java代码-2
字典树
import java.util.*;
public class Main {
static class Node {
int[] son = new int[26];
int num;
}
static Node[] t = new Node[1000010];
static int cnt = 1;
public static void main(String[] args) {
Scanner input = new Scanner(System.in);
int n = input.nextInt();
int L = input.nextInt();
for (int i = 0; i < t.length; i++)
t[i] = new Node();
for (int i = 0; i < n; i++) {
String s = input.next();
Insert(s);
}
System.out.println(Find(0, 0));
}
static void Insert(String s) {
int now = 0;
for (int i = 0; i < s.length(); i++) {
int ch = s.charAt(i) - 'a';
if (t[now].son[ch] == 0) {
t[now].son[ch] = cnt++;
}
now = t[now].son[ch];
t[now].num++;
}
}
static int Find(int u, int len) {
if (t[u].num == 1) return len;
int ans = 0;
for (int i = 0; i < 26; i++)
if (t[u].son[i] != 0)
ans += Find(t[u].son[i], len + 1);
return ans;
}
}
Python代码-1
模拟
n, L = map(int, input().split())
s = []
for i in range(n): s.append(input())
s.sort() # 字符串排序
ans = [0] * n
sum = 0
for i in range(n - 1): # 连续比较字符串s[i] 和 s[i+1]
for j in range(L):
ans[i] = max(ans[i], j + 1)
ans[i+1] = max(ans[i + 1], j + 1)
if s[i][j] != s[i+1][j]: break # 第j个字符不同
sum += ans[i] # 统计第i个字符串的前缀:它的前j个字符和其他字符都不同
print(sum + ans[n - 1])
Python代码-2
字典树
import sys
sys.setrecursionlimit(1000000)
class Node:
def __init__(self):
self.son = [0] * 26
self.num = 0
t = [Node() for _ in range(1000010)]
cnt = 1
def Insert(s):
global cnt
now = 0
for c in s:
ch = ord(c) - ord('a')
if t[now].son[ch] == 0:
t[now].son[ch] = cnt
cnt += 1
now = t[now].son[ch]
t[now].num += 1
def Find(u, length):
if t[u].num == 1: return length
ans = 0
for i in range(26):
if t[u].son[i] != 0:
ans += Find(t[u].son[i], length + 1)
return ans
n, L = map(int, input().split())
for _ in range(n):
s = input().strip()
Insert(s)
print(Find(0, 0))