不枉我给C++作业出过十六进制转换、字节解析、字节序、位模式等题目,这次擦边。
二进制解析难测难调,但题目“输入数据保证是合法的”就很客气,而且CSP不罚时,可以放心用断言验证条件。
实现的技巧:
- 函数式风格,短函数实现,使其功能单一,易于测试。
- 把参数解析和数据解压分开,简化函数。
- 用断言把样例都做成单元测试,这样实现出来就是正确的,不用调试。
#include <bits/stdc++.h>
using namespace std;
using ub_t = unsigned char; // 无符号字节类型
using cp_t = const ub_t*; // 常性字节指针
using cc_t = const ub_t* const; // 常性字节指针常量
constexpr int hex_value(char c) { // 十六进制数字转值
return (c >= '0' && c <= '9') ? (c - '0') : (c - 'a' + 10);
}
constexpr ub_t parse_byte(char h, char l) { // 拼装字节
return (hex_value(h) << 4) | hex_value(l); // 1-hex ~ 4-bit
}
// 从ist读取n字节十六进制数(2*n个字符)存到pd缓冲区,返回数据结束位置。
ub_t* read_bytes(ub_t* pd, cc_t pe, istream& ist, const int n) {
assert(pd + n <= pe);
char h, l; // 十六进制高、低半字节
for(int i = 0; i < n && ist >> h >> l; ++i) {
*pd++ = parse_byte(h, l); // 及时递增
}
assert(ist); // 不应EOF
return pd;
}
// 解析导引域,获取原始数据长度,返回未解字节位置
cp_t raw_length(cp_t ps, cc_t pe, int& g) {
g = 0;
int r = 1;
do {
assert(ps < pe);
g += (*ps & 0x7F) * r; // 低7位是值,相当于128进制。
r *= 128;
} while(*ps++ & 0x80); // 长度末字节最高位是0
return ps;
}
// 解析字面量元素的长度,返回未解字节位置
cp_t literal_length(cp_t ps, cc_t pe, int& l) {
assert(ps < pe);
ub_t b = *ps++;
b >>= 2;
if(b < 60) { // l <= 60时
l = b + 1; // 第一个字节的高6位表示(l − 1)
} else {
b -= 59; // 60~63分别表示后续有1~4字节长度
assert(ps + b <= pe);
l = 0;
for(int r = 1; b > 0; --b, ++ps) {
l += *ps * r;
r *= 256;
}
++l;
}
return ps;
}
// 从(p-o)处反复提取直到l字节,填充到p处,返回数据结束位置。
ub_t* extract(ub_t* p, cc_t pe, int o, int l) {
assert(p + l <= pe); // 缓冲区够大
for(cc_t pb = p - o; l > 0;) { // 起始位置pb不变
const int d = min(o, l);
memcpy(p, pb, d);
p += d;
l -= d;
}
return p;
}
cp_t back_ref_1(cp_t ps, cc_t pe, int& o, int& l) {
assert(ps + 2 <= pe);
l = ((*ps & 0b00011100) >> 2) + 4;
o = ((*ps & 0b11100000) << 3) | ps[1];
assert(l >= 4 && l <= 11);
assert(o >= 1 && o <= 2047);
return ps + 2;
}
cp_t back_ref_2(cp_t ps, cc_t pe, int& o, int& l) {
assert(ps + 3 <= pe);
l = (*ps >> 2) + 1;
o = (ps[2] * 256) + ps[1];
assert(l >= 1 && l <= 64);
assert(o >= 1 && o <= 65535);
return ps + 3;
}
ub_t* decompress(ub_t* pd, cc_t de, cp_t ps, cc_t pe) {
for(int o, l; ps < pe;) {
switch(*ps & 3) { // 每个元素的第一个字节的最低两位表示了元素的类型。
case 0: // 当最低两位为 0 时,表示这是一个字面量
ps = literal_length(ps, pe, l);
assert(l >= 1);
assert(ps + l <= pe);
assert(pd + l <= de);
memcpy(pd, ps, l); // 字面量包含一些字节,解压时直接将其输出即可。
ps += l;
pd += l;
break;
case 1: // 回溯引用1型
ps = back_ref_1(ps, pe, o, l);
pd = extract(pd, de, o, l);
break;
case 2: // 回溯引用2型
ps = back_ref_2(ps, pe, o, l);
pd = extract(pd, de, o, l);
break;
default: // 元素的首字节的最低两位不允许是11
assert(false); // 输入数据保证是合法的
}
}
return pd;
}
void print(cp_t ps, cc_t pe, const char hex_digits[16]) {
for(int i = 0; ps != pe; ++ps, ++i) {
cout << hex_digits[*ps >> 4] << hex_digits[*ps & 0xF];
if((i % 8) == 7) {
cout << '\n';
}
}
}
void test(const char hex_digits[16]) {
#ifndef ONLINE_JUDGE
{
assert(hex_value('0') == 0);
assert(hex_value('9') == 9);
assert(hex_value('a') == 10);
assert(hex_value('f') == 15);
}
{
assert(parse_byte('0', '0') == 0x00U);
assert(parse_byte('0', '8') == 0x08U);
assert(parse_byte('8', '0') == 0x80U);
assert(parse_byte('f', 'f') == 0xffU);
}
{
istringstream iss(hex_digits);
ub_t buf[10];
cp_t pt = read_bytes(buf, buf + sizeof(buf), iss, 8);
assert(pt == buf + 8);
assert(buf[0] == 0x01);
assert(buf[1] == 0x23);
assert(buf[2] == 0x45);
assert(buf[3] == 0x67);
assert(buf[4] == 0x89);
assert(buf[5] == 0xAB);
assert(buf[6] == 0xCD);
assert(buf[7] == 0xEF);
}
{
int g = -1;
const ub_t s1[] = "\xAC\x0A";
auto pt = raw_length(s1, s1 + sizeof(s1), g);
assert(pt = s1 + 2);
assert(g == 1324);
}
{
int l;
const ub_t s[] = "\xE8";
auto pt = literal_length(s, s + sizeof(s), l);
assert(pt == s + 1);
assert(l == 58 + 1);
}
{
int l;
const ub_t s[] = "\xF4\x01\x0A";
auto pt = literal_length(s, s + sizeof(s), l);
assert(pt == s + 3);
assert(l == (0x0A01 + 1));
}
{
int o, l;
const ub_t s[] = "\x2D\x1A";
auto pt = back_ref_1(s, s + sizeof(s), o, l);
assert(pt = s + 2);
assert(o == 282);
assert(l == 7);
}
{
int o, l;
const ub_t s[] = "\x3E\x1A\x01";
auto pt = back_ref_2(s, s + sizeof(s), o, l);
assert(pt = s + 3);
assert(o == 282);
assert(l == 16);
}
#endif
}
const char hex_digits[] = "0123456789abcdef";
constexpr int N = 1024 * 1024 * 3;
ub_t vs[N], vd[N]; // 源(src)、目的(dst)数据
int main() {
ios::sync_with_stdio(false); // IO多达几十万项,要优化
cin.tie(0);
cout.tie(0);
test(hex_digits);
int s = 1e9, d = 1e9; // 源数据字节数、解压数据字节数
cin >> s >> ws;
assert(s <= N); // 源缓冲区够大
cc_t se = read_bytes(vs, vs + N, cin, s);
assert(se == vs + s);
cc_t ps = raw_length(vs, se, d);
assert(d <= N); // 源缓冲区够大
cc_t de = decompress(vd, vd + d, ps, se);
assert(de == vd + d);
print(vd, de, hex_digits);
return 0;
}