【模拟赛】贸易(多项式多点求值,点分治)

22 篇文章 0 订阅

题面

M国有 n n n 座城市,城市之间有一些道路相连,使得整个M国的地图形如一个树形结构。

小Y是M国的一个商人,他打算在M国的城市之间进行一次贸易。

经过市场调查,小Y得知,现在市面上可供贸易的货物共有 m m m 种,第 i i i 种货物有三种属性 a i , b i , p i a_i,b_i,p_i ai,bi,pi 。他可以选择任意一种货物,在任意一座城市以 a i a_i ai え的价格买入,沿道路到达任意一座城市并以 b i b_i bi え的价格卖出,但贸易行为是有风险的,每当小Y试图经过一条道路时,他手上的货物都有 p i p_i pi 的概率损毁,一旦损毁小Y就无法卖出货物,只能白白承担全部损失。

为了进一步进行市场调查,并分析比较选择哪种货物更赚钱,小Y决定随机选择两座城市(可能相同)作为起点和终点,并在这两座城市之间展开贸易。

你需要帮小Y计算:对于每一种货物而言,小Y选择这种货物并按如上策略进行贸易的的期望收益是多少。

所有答案对 998244353 998244353 998244353 取模。

1 ≤ n , m ≤ 1 0 5 1\leq n,m\leq 10^5 1n,m105

题解

这道题非常的舒服啊,很快就可以推完结论,然后开始漫长的肝代码时间。


第零步,我们需要知道,货物损毁的概率仅仅取决于自身和道路的长度。

那么首先,令 f i f_i fi 表示道路长度为 i i i 的概率,一个货物 j j j 的答案就是
∑ i = 0 n − 1 f i ( 1 − p j ) i \sum_{i=0}^{n-1}f_i(1-p_j)^i i=0n1fi(1pj)i

我们怎么快速求出 f i f_i fi ?不论怎么往树形DP优化、启发式合并、树链剖分上想都是没结果的。只有点分治之类可以,点分治的时候用 N T T NTT NTT 求出每一次的贡献,这也很好想。

那么下一步,不难翻译出,我们就是要对于一个多项式
f ( x ) = ∑ i = 0 n − 1 f i x i f(x)=\sum_{i=0}^{n-1}f_ix^i f(x)=i=0n1fixi

以及 m m m 个点
( 1 − p 1 ) , ( 1 − p 2 ) , . . . , ( 1 − p m ) (1-p_1),(1-p_2),...,(1-p_m) (1p1),(1p2),...,(1pm)

进行多项式多点求值

如果是传统的多点求值,涉及这方面的知识有三个


多点求值
在这里插入图片描述
万能的却迟迟不更新的OIwiki

多项式取模
在这里插入图片描述
万能的却迟迟不更新的OIwiki

多项式求逆
在这里插入图片描述
My Blog


近年来多了一种小常数新方法:无多项式取模
在这里插入图片描述

CODE

“理论存在,实践开始”

#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <stack>
#include <random>
#include <vector>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define MAXN 100005
#define LL long long
#define ULL unsigned long long
#define ENDL putchar('\n')
#define DB double
#define lowbit(x) (-(x) & (x))
#define FI first
#define SE second
int xchar() {
    static const int maxn = 1000000;
    static char b[maxn];
    static int pos = 0, len = 0;
    if (pos == len)
        pos = 0, len = fread(b, 1, maxn, stdin);
    if (pos == len)
        return -1;
    return b[pos++];
}
//#define getchar() xchar()
LL read() {
    LL f = 1, x = 0;
    int s = getchar();
    while (s < '0' || s > '9') {
        if (s < 0)
            return -1;
        if (s == '-')
            f = -f;
        s = getchar();
    }
    while (s >= '0' && s <= '9') {
        x = (x << 1) + (x << 3) + (s ^ 48);
        s = getchar();
    }
    return f * x;
}
void putpos(LL x) {
    if (!x)
        return;
    putpos(x / 10);
    putchar((x % 10) ^ 48);
}
void putnum(LL x) {
    if (!x) {
        putchar('0');
        return;
    }
    if (x < 0)
        putchar('-'), x = -x;
    return putpos(x);
}
void AIput(LL x, int c) {
    putnum(x);
    putchar(c);
}

const int MOD = 998244353;
int n, m, s, o, k;
int qkpow(int a, int b) {
    int res;
    for (res = 1; b > 0; b >>= 1) {
        if (b & 1)
            res = res * 1ll * a % MOD;
        a = a * 1ll * a % MOD;
    }
    return res;
}
const int rm = 3;
int om, xm[MAXN << 2], rev[MAXN << 2], tg = 0;
void NTT(int *s, int n, int op) {
    if (n * op != tg) {
        for (int i = 1; i < n; i++) rev[i] = ((rev[i >> 1]) >> 1) | ((i & 1) ? (n >> 1) : 0);
        om = qkpow(rm, (MOD - 1) / n);
        if (op < 0)
            om = qkpow(om, MOD - 2);
        xm[0] = 1;
        for (int i = 1; i <= n; i++) xm[i] = xm[i - 1] * 1ll * om % MOD;
        tg = n * op;
    }
    for (int i = 1; i < n; i++)
        if (rev[i] < i)
            swap(s[rev[i]], s[i]);
    for (int k = 2, t = n >> 1; k <= n; k <<= 1, t >>= 1) {
        for (int j = 0; j < n; j += k) {
            for (int i = j, l = 0; i < j + (k >> 1); i++, l += t) {
                int A = s[i], B = s[i + (k >> 1)];
                s[i] = (xm[l] * 1ll * B + A) % MOD;
                s[i + (k >> 1)] = (MOD - xm[l] * 1ll * B % MOD + A) % MOD;
            }
        }
    }
    if (op < 0) {
        int iv = qkpow(n, MOD - 2);
        for (int i = 0; i < n; i++) s[i] = s[i] * 1ll * iv % MOD;
    }
    return;
}
int a_[MAXN << 2], b_[MAXN << 2];
void polyinv(int *s, int n) {  // mod x^n
    for (int i = 0; i < n; i++) a_[i] = s[i], s[i] = 0;
    for (int i = n; i < (n << 1); i++) s[i] = 0;
    s[0] = qkpow(a_[0], MOD - 2);
    int le = 1;
    while (le < n) {
        le <<= 1;
        for (int i = 0; i < le; i++) b_[i] = a_[i];
        NTT(b_, le << 1, 1);
        NTT(s, le << 1, 1);
        for (int i = 0; i < (le << 1); i++)
            s[i] = (2ll + MOD - b_[i] * 1ll * s[i] % MOD) * 1ll * s[i] % MOD, b_[i] = 0;
        NTT(s, le << 1, -1);
        for (int i = le; i < (le << 1); i++) s[i] = 0;
    }
    for (int i = n; i < le; i++) s[i] = 0;
    for (int i = 0; i < n; i++) a_[i] = 0;
    return;
}
void polyrev(int *s, int n) {  // mod x^{n+1}
    for (int i = 0; (i << 1) <= n; i++) swap(s[i], s[n - i]);
    return;
}
int f_[MAXN << 2], h_[MAXN << 2], g_[MAXN << 2];
void polymod(int *f, int *g, int *r, int n, int m) {
    for (int i = 0; i <= n; i++) f_[i] = f[i];
    polyrev(f_, n);
    for (int i = 0; i <= m; i++) g_[i] = g[i];
    polyrev(g_, m);
    int le = 1;
    while (le <= 2 * n) le <<= 1;
    polyinv(g_, n - m + 1);
    NTT(f_, le, 1);
    NTT(g_, le, 1);
    for (int i = 0; i < le; i++) h_[i] = f_[i] * 1ll * g_[i] % MOD, g_[i] = 0;
    NTT(h_, le, -1);
    for (int i = n - m + 1; i < le; i++) h_[i] = 0;
    polyrev(h_, n - m);
    for (int i = 0; i <= m; i++) g_[i] = g[i];
    NTT(h_, le, 1);
    NTT(g_, le, 1);
    for (int i = 0; i < le; i++) h_[i] = h_[i] * 1ll * g_[i] % MOD;
    NTT(h_, le, -1);
    for (int i = 0; i <= n; i++) r[i] = (f[i] + MOD - h_[i]) % MOD;
    for (int i = 0; i < le; i++) h_[i] = f_[i] = g_[i] = 0;
    return;
}
vector<int> G[MAXN << 2];
void polyinit(int a, int *g, int l, int r) {
    if (l > r)
        return;
    if (l == r) {
        G[a].clear();
        G[a].push_back((MOD - g[l]) % MOD);
        G[a].push_back(1);
        return;
    }
    int md = (l + r) >> 1, l1 = md - l + 1, l2 = r - md;
    polyinit(a << 1, g, l, md);
    polyinit(a << 1 | 1, g, md + 1, r);
    for (int i = 0; i <= l1; i++) a_[i] = G[a << 1][i];
    for (int i = 0; i <= l2; i++) b_[i] = G[a << 1 | 1][i];
    int le = 1;
    while (le <= r - l + 1) le <<= 1;
    NTT(a_, le, 1);
    NTT(b_, le, 1);
    for (int i = 0; i < le; i++) a_[i] = a_[i] * 1ll * b_[i] % MOD;
    NTT(a_, le, -1);
    G[a].clear();
    //	printf("[%d,%d] : \n",l,r);
    //	for(int i = 0;i <= r-l+1;i ++) printf("%d ",a_[i]); ENDL;
    for (int i = 0; i <= r - l + 1; i++) G[a].push_back(a_[i]);
    for (int i = 0; i < le; i++) a_[i] = b_[i] = 0;
    return;
}
vector<int> f2[MAXN << 2];
int ww[MAXN << 2];
void polyquick(int a, int *f, int *g, int n, int l, int r) {
    if (l > r)
        return;
    if (l == r) {
        int y = 0;
        for (int i = 0, pw = 1; i <= n; i++, pw = pw * 1ll * g[i] % MOD) (y += f[i] * 1ll * pw % MOD) %= MOD;
        g[l] = y;
        return;
    }
    int md = (l + r) >> 1, l1 = md - l + 1, l2 = r - md;
    f2[a].clear();
    for (int i = 0; i <= n; i++) f2[a].push_back(f[i]);
    for (int i = 0; i <= l1; i++) ww[i] = G[a << 1][i];
    polymod(f, ww, f, n, l1);
    for (int i = 0; i <= l1; i++) ww[i] = 0;
    polyquick(a << 1, f, g, l1, l, md);

    for (int i = 0; i <= n; i++) f[i] = f2[a][i];
    for (int i = 0; i <= l2; i++) ww[i] = G[a << 1 | 1][i];
    polymod(f, ww, f, n, l2);
    for (int i = 0; i <= l2; i++) ww[i] = 0;
    polyquick(a << 1 | 1, f, g, l2, md + 1, r);
    return;
}
// ----------------------------------------------------------
// ----------------------------------------------------------
// ----------------------------------------------------------
int hd[MAXN], nx[MAXN << 1], v[MAXN << 1], cne;
void ins(int x, int y) {
    nx[++cne] = hd[x];
    v[cne] = y;
    hd[x] = cne;
}
bool f[MAXN], swt;
int d[MAXN], siz[MAXN], SIZ, hv, mh;
int ct[MAXN << 2];
int c[MAXN << 2];
void dfs(int x, int ff) {
    d[x] = d[ff] + 1;
    if (swt)
        c[d[x]]++, mh = max(mh, d[x]);
    siz[x] = 1;
    bool tg = 1;
    for (int i = hd[x]; i; i = nx[i]) {
        if (v[i] != ff && !f[v[i]]) {
            dfs(v[i], x);
            siz[x] += siz[v[i]];
            if (siz[v[i]] > SIZ / 2)
                tg = 0;
        }
    }
    if (SIZ - siz[x] <= SIZ / 2 && tg)
        hv = x;
    return;
}
void calc(int x, int op, int ad) {
    d[0] = ad - 1;
    swt = 1;
    mh = 0;
    dfs(x, 0);
    swt = 0;
    int le = 1;
    while (le <= mh * 2) le <<= 1;
    NTT(c, le, 1);
    for (int i = 0; i < le; i++) c[i] = c[i] * 1ll * c[i] % MOD;
    NTT(c, le, -1);
    for (int i = 0; i < le; i++) {
        if (i <= mh * 2)
            (ct[i] += (MOD + op) * 1ll * c[i] % MOD) %= MOD;
        c[i] = 0;
    }
    return;
}
void solve(int x, int n) {
    calc(x, 1, 0);
    f[x] = 1;
    for (int i = hd[x]; i; i = nx[i]) {
        int y = v[i];
        if (!f[y]) {
            SIZ = siz[y];
            calc(y, -1, 1);
            solve(hv, siz[y]);
        }
    }
    return;
}
int p[MAXN], ai[MAXN], bi[MAXN];
int main() {
    freopen("trade.in", "r", stdin);
    freopen("trade.out", "w", stdout);
    n = read();
    m = read();
    for (int i = 1; i < n; i++) {
        s = read();
        o = read();
        ins(s, o);
        ins(o, s);
    }
    SIZ = n;
    dfs(1, 0);
    solve(hv, n);
    int invn = qkpow(n * 1ll * n % MOD, MOD - 2);
    //	for(int i = 0;i <= n;i ++) printf("ct[%d] = %d\n",i,ct[i]);
    for (int i = 0; i <= n; i++) ct[i] = ct[i] * 1ll * invn % MOD;
    for (int i = 1; i <= m; i++) {
        ai[i] = read();
        bi[i] = read();
        s = read();
        o = read();
        p[i] = s * 1ll * qkpow(o, MOD - 2) % MOD;
        p[i] = (MOD + 1 - p[i]) % MOD;
    }
    polyinit(1, p, 1, m);
    polyquick(1, ct, p, n, 1, m);
    for (int i = 1; i <= m; i++) {
        int ans = (bi[i] * 1ll * p[i] % MOD + MOD - ai[i]) % MOD;
        AIput(ans, '\n');
    }
    return 0;
}

🤮

还好,平时肝习惯了

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值