洛谷传送门
题目背景
这是一道模板题
题目描述
由小学知识可知,
n
n
n 个点
(
x
i
,
y
i
)
(x_i,y_i)
(xi,yi) 可以唯一地确定一个多项式
现在,给定 n n n 个点,请你确定这个多项式,并将 k k k 代入求值
求出的值对 998244353 998244353 998244353 取模
输入输出格式
输入格式:
第一行两个正整数 n , k n,k n,k ,含义如题
接下来 n n n 行,每行两个正整数 x i , y i x_i,y_i xi,yi ,含义如题
输出格式:
一个整数表示答案
输入输出样例
输入样例#1:
3 100
1 4
2 9
3 16
输出样例#1:
10201
输入样例#2:
3 100
1 1
2 2
3 3
输出样例#2:
100
说明
n ≤ 2000        x i , y i , k ≤ 998244353 n \leq 2000 \; \; \; x_i,y_i,k \leq 998244353 n≤2000xi,yi,k≤998244353;
样例一中的三个点确定的多项式是 f ( x ) = x 2 + 2 x + 1 f(x)=x^2+2x+1 f(x)=x2+2x+1,将 100 100 100代入求值得到 10201 10201 10201
样例二中的三个点确定的多项式是 f ( x ) = x f(x)=x f(x)=x ,将 100 100 100 代入求值得到 100 100 100
如果你不会拉格朗日插值,你可以到这里去学习一下
此外,请注意算法的常数问题,建议开启O2优化
解题分析
拉格朗日插值法是什么? 就是下面这个东西:
f
(
x
)
=
∑
i
=
1
n
f
(
x
i
)
∏
j
=
1
,
j
≠
i
n
x
−
x
j
x
i
−
x
j
f(x)=\sum_{i=1}^{n}f(x_i)\prod_{j=1,j\neq i}^{n}\frac{x-x_j}{x_i-x_j}
f(x)=i=1∑nf(xi)j=1,j̸=i∏nxi−xjx−xj
那么我们代入值进去计算即可。
复杂度 O ( N 2 ) O(N^2) O(N2)(貌似有 O ( N l o g 2 N ) O(Nlog^2N) O(Nlog2N)的多项式多点插值算法??蒟蒻不会…)
代码如下:
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <cmath>
#define W while
#define R register
#define IN inline
#define MX 2005
#define ll long long
#define gc getchar()
#define MOD 998244353ll
bool neg;
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
W (!isdigit(c))
{if(c == '-') neg = true; c = gc;}
W (isdigit(c))
x = (x << 1) + (x << 3) + c - 48, c = gc;
if(neg) neg = false, x = -x;
}
int num;
ll tar;
int x[MX], y[MX];
ll qpow(ll now, ll tim)
{
ll base = now, ret = 1;
for (R int i = 0; i <= 31; ++i, base = base * base % MOD)
{if(tim & (1 << i)) ret = ret * base % MOD;}
return ret;
}
ll lagrange()
{
ll ret = 0, up, down;
for (R int i = 1; i <= num; ++i)
{
up = down = 1;
for (R int j = 1; j <= num; ++j)
{
if(i == j) continue;
up = up * (tar - x[j]) % MOD;
down = down * (x[i] - x[j]) % MOD;
}
ret = (ret + up * qpow(down, MOD - 2) % MOD * y[i] % MOD + MOD) % MOD;
}
return ret;
}
int main(void)
{
in(num), in(tar);
for (R int i = 1; i <= num; ++i) in(x[i]), in(y[i]);
printf("%lld", lagrange());
}