毒瘤
题目背景:
分析:虚树 + DP
毒瘤到了直接拿“毒瘤”当题目名,不想多说什么。不过部分分给得比较多,还是非常友好的。首先,对于 80 分的部分分:设 $d = m - n + 1$ 为非树边的条数,先以 $O(2^d)$ 的代价枚举每一条非树边其中一个端点的状态,然后直接做树形 DP。
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
#include <ctime>
#include <bitset>
// Returns the next byte of stdin, served from a 1 MiB static buffer that is
// refilled with fread() whenever it runs dry. Once fread() yields nothing the
// function returns (char)-1 as an end-of-input marker.
inline char read() {
    static const int kPoolSize = 1024 * 1024;
    static char pool[kPoolSize], *head, *tail;
    if (head == tail) {
        head = pool;
        tail = pool + fread(pool, 1, kPoolSize, stdin);
        if (head == tail) return -1;  // nothing left to read
    }
    return *head++;
}
// /*
// Reads the next decimal integer from read() into x.
// Bytes before the first digit are skipped; a '-' seen among them makes the
// result negative. If input ends before any digit appears, x is left untouched.
template<class T>
inline void R(T &x) {
    // read() signals end of input with (char)-1. Compare against that exact
    // char value so the test also works where plain char is unsigned (there a
    // bare `c == -1` is never true and the loop below would spin forever).
    const char kEof = static_cast<char>(-1);
    char c;
    bool negative = false;
    // isdigit() is undefined for negative arguments other than EOF, so always
    // hand it the byte as unsigned char.
    for (c = read(); !isdigit(static_cast<unsigned char>(c)); c = read()) {
        if (c == kEof) return;
        if (c == '-') negative = true;
    }
    for (x = 0; isdigit(static_cast<unsigned char>(c)); c = read())
        x = x * 10 + (c - '0');
    if (negative) x = -x;
}
//*/
// Buffered stdout: bytes accumulate in obuf and are written with fwrite()
// when the buffer fills or flush() is called.
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN];
char *oh = obuf;  // next free slot in obuf

// Appends one byte to the output buffer, draining it first if it is full.
inline void write_char(char c) {
    if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
    *oh++ = c;
}

// Writes the integer x in decimal (with a leading '-' when negative).
// Digits are peeled off without negating x, because -x overflows (UB) for
// the minimum value of a signed T.
template<class T>
inline void W(T x) {
    static char digits[40];  // enough for any 128-bit integer
    int len = 0;
    if (x == 0) {
        write_char('0');
        return;
    }
    if (x < 0) write_char('-');
    for (; x != 0; x /= 10) {
        int d = static_cast<int>(x % 10);  // in [-9, 9]; negative iff x < 0
        digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
    }
    while (len) write_char(digits[--len]);
}

// Writes everything buffered so far to stdout and empties the buffer.
// (The previous version also echoed the first 11 buffered bytes plus a
// newline to std::cout, a debug leftover that corrupted the output.)
inline void flush() {
    fwrite(obuf, 1, oh - obuf, stdout), oh = obuf;
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
// Capacity of the vertex-indexed arrays (room for n up to 1e5 plus slack).
const int MAXN = 100000 + 9;
// The answer is printed modulo this prime.
const int mod = 998244353;
// edge[u]: adjacency list of the input graph (every edge stored both ways).
std::vector<int> edge[MAXN];
// down[u]: ancestors of u joined to u by a non-tree (back) edge, filled by dfs1.
std::vector<int> down[MAXN];
// cnt: number of vertices with at least one back edge (stored in id[0..cnt)).
// n, m: vertex / edge counts; x, y: scratch variables for reading edges.
int cnt, n, m, x, y;
// dp[u][0/1]: tree-DP value of u's DFS subtree with u unselected / selected.
// mark[u]: forced state of u in the current enumeration (-1 free, 0 forced
//          unselected, 1 forced selected); dfs zeroes dp[u][mark[u] ^ 1].
// dep[u]: depth in the DFS tree (root has depth 1).
// id[]: vertices owning back edges. NOTE(review): only 20 slots; this relies
//       on m - n + 1 (the number of non-tree edges) being small.
int dp[MAXN][2], mark[MAXN], dep[MAXN], id[20];
// vis[u]: visited flag, shared (and reset between uses) by dfs1 and dfs.
bool vis[MAXN];
// Records the undirected edge {u, v} in both endpoints' adjacency lists.
inline void add_edge(int u, int v) {
    edge[u].push_back(v);
    edge[v].push_back(u);
}
// First DFS from `cur` (tree parent `fa`): assigns depths, recurses into
// unvisited neighbours, and for each non-tree edge from `cur` up to an
// already-visited shallower vertex appends that ancestor to down[cur].
// The first time cur gains such an edge, cur is appended to id[] (cnt++).
inline void dfs1(int cur, int fa) {
    vis[cur] = true;
    dep[cur] = dep[fa] + 1;
    for (int to : edge[cur]) {
        if (to == fa) continue;
        if (!vis[to]) {
            dfs1(to, cur);
        } else if (dep[to] < dep[cur]) {
            down[cur].push_back(to);
            if (down[cur].size() == 1) id[cnt++] = cur;
        }
    }
}
// Tree DP over the DFS tree below `cur` (vis[] must be cleared beforehand;
// already-visited neighbours, i.e. the parent and back edges, are skipped).
// dp[cur][1] multiplies the children's "unselected" values; dp[cur][0]
// multiplies each child's unselected + selected total. A forced vertex
// (mark[cur] != -1) has the value of the opposite state zeroed.
inline void dfs(int cur) {
    vis[cur] = true;
    long long skip = 1;  // running value for cur unselected
    long long take = 1;  // running value for cur selected
    for (int to : edge[cur]) {
        if (vis[to]) continue;
        dfs(to);
        take = take * dp[to][0] % mod;
        skip = skip * (static_cast<long long>(dp[to][0]) + dp[to][1]) % mod;
    }
    dp[cur][0] = static_cast<int>(skip);
    dp[cur][1] = static_cast<int>(take);
    if (mark[cur] != -1) dp[cur][mark[cur] ^ 1] = 0;
}
// Brute-force solution: reads the graph, finds the back edges with dfs1, then
// enumerates all 2^cnt states of the vertices in id[]. For each assignment
// the forced states are propagated through mark[]; contradictory assignments
// are skipped, otherwise a full tree DP is run and dp[1][0] + dp[1][1] is
// added to the answer. Prints the total modulo `mod` (no trailing newline).
inline void solve() {
R(n), R(m);
long long ans = 0;
for (int i = 1; i <= m; ++i) R(x), R(y), add_edge(x, y);
dfs1(1, 0);
for (int i = 1; i <= n; ++i) mark[i] = -1;
// Bit j of i is the state of id[j]: 1 = selected, 0 = unselected.
for (int i = 0, end = (1 << cnt); i < end; ++i) {
bool flag = true;
// Undo the previous assignment: only id[] vertices and their back-edge
// ancestors are ever marked, so resetting those restores "all free".
for (int j = 0; j < cnt; ++j) {
int pos = id[j];
for (int k = 0; k < down[pos].size(); ++k) {
int v = down[pos][k];
mark[v] = -1;
}
mark[pos] = -1;
}
// Apply the assignment. A selected id[j] must not already be forced
// unselected, and it forces every back-edge ancestor to be unselected;
// any clash clears flag.
for (int j = 0; j < cnt; ++j)
if (i & (1 << j)) {
int pos = id[j];
if (mark[pos] == 0) {
flag = false;
break ;
} else mark[pos] = 1;
for (int k = 0; k < down[pos].size(); ++k) {
int v = down[pos][k];
if (mark[v] == 1) {
flag = false;
// Leaves only the k loop; flag is checked after the j loop.
break ;
} else mark[v] = 0;
}
} else mark[id[j]] = 0;
if (flag == false) continue ;
// Fresh tree DP under the current constraints.
for (int j = 1; j <= n; ++j) vis[j] = false;
dfs(1), ans += dp[1][1] + dp[1][0];
}
std::cout << ans % mod;
}
// Entry point: redirects stdin/stdout to duliu.in / duliu.out, runs solve(),
// and reports the elapsed clock ticks on stderr.
int main() {
    const clock_t started_at = clock();
    freopen("duliu.in", "r", stdin);
    freopen("duliu.out", "w", stdout);
    solve();
    const clock_t elapsed = clock() - started_at;
    std::cerr << elapsed << '\n';
    return 0;
}
考虑优化。其实,这 $2^d$ 次 DP 中很多地方的转移是重复的,有所改变的只是非树边端点处的 DP 方式。那么我们可以将这些点(非树边的端点、它们之间的分叉点以及根)建出虚树,然后直接统计虚树上每一条边的转移贡献。为此需要先在原树上 DP 出要用的系数:定义 g[x][0/1] 表示 f[x][0/1] 中的常数部分,即所有子树内不含关键点的儿子的贡献之积;再定义 k[x][0/1][0/1] 表示 x 下方(包括 x 的儿子;x 本身为关键点时即为 x)第一个关键点 v 对 x 的转移系数,用 DP 方程表示就是:
f[x][0] = k[x][0][0] * f[v][0] + k[x][0][1] * f[v][1]
f[x][1] = k[x][1][0] * f[v][0] + k[x][1][1] * f[v][1]
边界显然就是,对于一个关键点,
k[x][0][0] = k[x][1][1] = 1
k[x][1][0] = k[x][0][1] = 0
dp出k和g的值,直接在虚树转移即可。
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
#include <ctime>
#include <bitset>
// Returns the next byte of stdin, served from a 1 MiB static buffer that is
// refilled with fread() whenever it runs dry. Once fread() yields nothing the
// function returns (char)-1 as an end-of-input marker.
inline char read() {
    static const int kPoolSize = 1024 * 1024;
    static char pool[kPoolSize], *head, *tail;
    if (head == tail) {
        head = pool;
        tail = pool + fread(pool, 1, kPoolSize, stdin);
        if (head == tail) return -1;  // nothing left to read
    }
    return *head++;
}
///*
// Reads the next decimal integer from read() into x.
// Bytes before the first digit are skipped; a '-' seen among them makes the
// result negative. If input ends before any digit appears, x is left untouched.
template<class T>
inline void R(T &x) {
    // read() signals end of input with (char)-1. Compare against that exact
    // char value so the test also works where plain char is unsigned (there a
    // bare `c == -1` is never true and the loop below would spin forever).
    const char kEof = static_cast<char>(-1);
    char c;
    bool negative = false;
    // isdigit() is undefined for negative arguments other than EOF, so always
    // hand it the byte as unsigned char.
    for (c = read(); !isdigit(static_cast<unsigned char>(c)); c = read()) {
        if (c == kEof) return;
        if (c == '-') negative = true;
    }
    for (x = 0; isdigit(static_cast<unsigned char>(c)); c = read())
        x = x * 10 + (c - '0');
    if (negative) x = -x;
}
//*/
// Buffered stdout: bytes accumulate in obuf and are written with fwrite()
// when the buffer fills or flush() is called.
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN];
char *oh = obuf;  // next free slot in obuf

// Appends one byte to the output buffer, draining it first if it is full.
inline void write_char(char c) {
    if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
    *oh++ = c;
}

// Writes the integer x in decimal (with a leading '-' when negative).
// Digits are peeled off without negating x, because -x overflows (UB) for
// the minimum value of a signed T.
template<class T>
inline void W(T x) {
    static char digits[40];  // enough for any 128-bit integer
    int len = 0;
    if (x == 0) {
        write_char('0');
        return;
    }
    if (x < 0) write_char('-');
    for (; x != 0; x /= 10) {
        int d = static_cast<int>(x % 10);  // in [-9, 9]; negative iff x < 0
        digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
    }
    while (len) write_char(digits[--len]);
}

// Writes everything buffered so far to stdout and empties the buffer.
inline void flush() {
    fwrite(obuf, 1, oh - obuf, stdout), oh = obuf;
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
// Capacity of the vertex-indexed arrays (room for n up to 2e5 plus slack).
const int MAXN = 200000 + 10;
// The answer is printed modulo this prime.
const int mod = 998244353;
// n, m: vertex / edge counts; x, y: scratch variables for reading edges;
// cnt: number of back (non-tree) edges found by dfs.
int n, m, x, y, cnt;
// edge[u]: adjacency list of the input graph (every edge stored both ways).
std::vector<int> edge[MAXN];
// Back edge j joins down[j] (deeper endpoint) and up[j] (its ancestor).
// NOTE(review): only 20 slots; this relies on m - n + 1 being small.
// size[u]: after dfs, 1 iff u's DFS subtree contains a key vertex, else 0.
// dep[u]: DFS depth (root has depth 1).
int up[20], down[20], size[MAXN], dep[MAXN];
// is_key[u]: u is a vertex of the virtual tree.
// ban[u][s]: state s of u (0 = unselected, 1 = selected) is forbidden in the
//            current enumeration.
// vis[u]: visited flag shared by dfs and pre_work (cleared in between).
bool is_key[MAXN], ban[MAXN][2], vis[MAXN];
// g[u][s]: product of the contributions of u's key-free child subtrees for
//          state s of u (computed by pre_work).
// f[u][s]: virtual-tree DP value of key vertex u in state s (filled by dp).
long long g[MAXN][2], f[MAXN][2];
// A pair of residues modulo `mod`. In k[u][s] and in virtual-tree edges it is
// used as the linear form x * f[v][0] + y * f[v][1] over the DP values of the
// nearest key vertex v below u.
struct node {
    long long x, y;
    node() {}
    node(long long a, long long b) : x(a), y(b) {}
    // Component-wise sum, reduced modulo `mod`.
    inline node operator + (const node &other) const {
        const long long sum_x = (x + other.x) % mod;
        const long long sum_y = (y + other.y) % mod;
        return node(sum_x, sum_y);
    }
    // Scales both components by `factor`, reduced modulo `mod`.
    inline node operator * (const long long &factor) const {
        const long long scaled_x = x * factor % mod;
        const long long scaled_y = y * factor % mod;
        return node(scaled_x, scaled_y);
    }
} k[MAXN][2];
// Virtual-tree edge from a key vertex to the next key vertex `to` below it.
// f0 / f1 are the factors contributed to the parent's state 0 / state 1, as
// linear forms in (f[to][0], f[to][1]).
struct data {
    int to;
    node f0, f1;
    data() {}
    data(int target, node coef0, node coef1) : to(target), f0(coef0), f1(coef1) {}
};
// key[u]: virtual-tree edges leaving key vertex u (filled by pre_work).
std::vector<data> key[MAXN];
// Records the undirected edge {u, v} in both endpoints' adjacency lists.
inline void add_edge(int u, int v) {
    edge[u].push_back(v);
    edge[v].push_back(u);
}
inline void dfs(int cur, int fa) {
dep[cur] = dep[fa] + 1, vis[cur] = true;
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (v == fa) continue ;
if (vis[v] && dep[v] < dep[cur]) {
is_key[cur] = true, is_key[v] = true;
down[cnt] = cur, up[cnt++] = v;
} else if (!vis[v]) dfs(v, cur), size[cur] += size[v];
}
is_key[cur] |= (size[cur] >= 2);
size[cur] = (size[cur] || is_key[cur]);
}
// Second pass over the DFS tree (vis[] must be cleared first). For `cur` it
// computes:
//   g[cur][0/1]  product of the contributions of child subtrees that contain
//                no key vertex;
//   k[cur][0/1]  f[cur][0/1] as a linear form (node) in the two DP values of
//                the nearest key vertex at or below cur;
//   key[cur]     for key vertices, the edges to the next key vertices below.
// Returns the nearest key vertex in cur's subtree (cur itself if it is key),
// or 0 when the subtree has none.
inline int pre_work(int cur) {
int last = 0;
vis[cur] = true, g[cur][0] = g[cur][1] = 1;
for (int p = 0; p < edge[cur].size(); ++p) {
int v = edge[cur][p];
if (!vis[v]) {
int temp = pre_work(v);
// For a non-key cur at most one child subtree holds a key vertex (dfs makes
// a vertex key once two children report one), so `last` is well defined.
if (last == 0) last = temp;
if (temp == 0) {
// Key-free child: fold its contribution into the constant factors g[cur].
g[cur][0] = g[cur][0] * (g[v][1] + g[v][0]) % mod;
g[cur][1] = g[cur][1] * g[v][0] % mod;
} else if (is_key[cur])
key[cur].push_back(data(temp, k[v][0] + k[v][1], k[v][0])); // virtual-tree edge cur -> temp
else k[cur][0] = k[v][0] + k[v][1], k[cur][1] = k[v][0]; // non-key chain vertex: extend the child's linear form by one tree step
}
}
// A key vertex starts a fresh linear form (identity: f[cur][s] in terms of itself).
if (is_key[cur]) k[cur][0] = node(1, 0), k[cur][1] = node(0, 1), last = cur;
// Otherwise scale the chain's linear form by the key-free factors g[cur].
else k[cur][0] = k[cur][0] * g[cur][0], k[cur][1] = k[cur][1] * g[cur][1];
return last;
}
inline void dp(int cur) {
f[cur][0] = (ban[cur][0] ^ 1) * g[cur][0];
f[cur][1] = (ban[cur][1] ^ 1) * g[cur][1];
for (int p = 0; p < key[cur].size(); ++p) {
data *e = &key[cur][p];
dp(e->to);
f[cur][0] = f[cur][0] * (e->f0.x * f[e->to][0] % mod
+ e->f0.y * f[e->to][1] % mod) % mod;
f[cur][1] = f[cur][1] * (e->f1.x * f[e->to][0] % mod
+ e->f1.y * f[e->to][1] % mod) % mod;
}
}
// Reads n, m and then the m undirected edges of the graph.
inline void read_in() {
    R(n);
    R(m);
    for (int remaining = m; remaining > 0; --remaining) {
        R(x);
        R(y);
        add_edge(x, y);
    }
}
inline void solve() {
long long ans = 0;
dfs(1, 0), is_key[1] = true, memset(vis, false, sizeof(vis)), pre_work(1);
for (int i = 0, end = (1 << cnt); i < end; ++i) {
for (int j = 0; j < cnt; ++j) {
if (i & (1 << j)) ban[down[j]][0] = true, ban[up[j]][1] = true;
else ban[down[j]][1] = true;
}
dp(1), ans += (f[1][0] + f[1][1]);
for (int j = 0; j < cnt; ++j) {
if (i & (1 << j)) ban[down[j]][0] = false, ban[up[j]][1] = false;
else ban[down[j]][1] = false;
}
}
std::cout << ans % mod;
}
// Entry point: parses the graph, then counts and prints the answer.
int main() {
    // File redirection for local testing, intentionally disabled:
    // freopen("in.in", "r", stdin);
    // freopen("out.out", "w", stdout);
    read_in();
    solve();
    return 0;
}