方法 3:拉格朗日插值法
根据多项式除法那么我们会有:
f ( x ) ≡ f ( a ) ( m o d ( x − a ) ) f(x)\equiv f(a)\pmod{(x-a)} f(x)≡f(a)(mod(x−a))
这是显然的,因为 f ( x ) − f ( a ) = ( a 0 − a 0 ) + a 1 ( x 1 − a 1 ) + a 1 ( x 2 − a 2 ) + ⋯ + a n ( x n − a n ) f(x)-f(a)=(a_0-a_0)+a_1(x^1-a^1)+a_1(x^2-a^2)+\cdots +a_n(x^n-a^n) f(x)−f(a)=(a0−a0)+a1(x1−a1)+a1(x2−a2)+⋯+an(xn−an),这个式子显然有 ( x − a ) (x-a) (x−a) 这个因式,所以得证。
这样我们就可以列一个关于 f ( x ) f(x) f(x) 的多项式线性同余方程组:
{ f ( x ) ≡ y 1 ( m o d ( x − x 1 ) ) f ( x ) ≡ y n ( m o d ( x − x 2 ) ) ⋯ f ( x ) ≡ y n ( m o d ( x − x n ) ) \begin{cases} f(x)\equiv y_1\pmod{(x-x_1)}\\ f(x)\equiv y_n\pmod{(x-x_2)}\\ \cdots\\ f(x)\equiv y_n\pmod{(x-x_n)} \end{cases} ⎩⎪⎪⎪⎨⎪⎪⎪⎧f(x)≡y1(mod(x−x1))f(x)≡yn(mod(x−x2))⋯f(x)≡yn(mod(x−xn))
我们根据中国剩余定理,有:
M = ∏ i = 1 n ( x − x i ) , m i = M x − x i = ∏ j ≠ i ( x − x j ) M=\prod_{i=1}^n{(x-x_i)},m_i=\dfrac M{x-x_i}=\prod_{j\ne i}{(x-x_j)} M=i=1∏n(x−xi),mi=x−xiM=j=i∏(x−xj)
则 m i m_i mi 模 ( x − x i ) (x-x_i) (x−xi) 意义下的逆元就是:
m i − 1 = ∏ j ≠ i 1 x i − x j m_i^{-1}=\prod_{j\ne i}{\dfrac 1{x_i-x_j}} mi−1=j=i∏xi−xj1
所以就有:
f ( x ) ≡ ∑ i = 1 n y i m i m i − 1 ≡ ∑ i = 1 n y i ∏ j ≠ i x − x j x i − x j ( m o d M ) f(x)\equiv\sum_{i=1}^n{y_im_im_i^{-1}}\equiv\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}}\pmod M f(x)≡i=1∑nyimimi−1≡i=1∑nyij=i∏xi−xjx−xj(modM)
所以在模意义下 f ( x ) f(x) f(x) 就是唯一的,即:
f ( x ) = ∑ i = 1 n y i ∏ j ≠ i x − x j x i − x j f(x)=\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}} f(x)=i=1∑nyij=i∏xi−xjx−xj
这就是拉格朗日插值的表达式。
如果要将每一项的系数都算出来,时间复杂度仍为 O ( n 2 ) O(n^2) O(n2),但是本题中只用求出 f ( k ) f(k) f(k) 的值,所以在计算上式的过程中直接将 k k k 代入即可。
f ( k ) = ∑ i = 1 n y i ∏ j ≠ i k − x j x i − x j f(k)=\sum_{i=1}^{n} y_i\prod_{j\neq i }\frac{k-x_j}{x_i-x_j} f(k)=i=1∑nyij=i∏xi−xjk−xj
本题中,还需要求解逆元。如果先分别计算出分子和分母,再将分子乘进分母的逆元,累加进最后的答案,时间复杂度的瓶颈就不会在求逆元上,时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
代码实现
#include<bits/stdc++.h>
#include <unordered_map>
using namespace std;
template<class...Args>
void debug(Args... args) {//Parameter pack
auto tmp = { (cout << args << ' ', 0)... };
cout << "\n";
}
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
const ll N = 1e5 + 5;
const ll INF = 0x7fffffff;
const ll MOD = 1e9 + 7;
ll qpow(ll x, ll n, ll mod) {
ll ans = 1;
while (n) {
if (n & 1)ans = (ans * x) % mod;
x = (x * x) % mod;
n >>= 1;
}
return ans % mod;
}
ll get_inv(ll num,ll mod) {
return qpow(num, mod - 2, mod);
}
ll lagrange(vector<ll> x,vector<ll> y,int k ,int mod) {
int n = x.size();
ll ans = 0;
ll s1, s2;
for (int i = 0; i < n; i++) {
s1 = y[i] % mod;
s2 = 1;
for (int j = 0; j < n; j++) {
if (i != j) {
s1 = s1 * (k - x[j]) % mod;
s2 = s2 * ((x[i] - x[j] % mod) % mod) % mod;
}
}
ans += s1 * get_inv(s2, mod) % mod;
ans = (ans + mod) % mod;
}
return ans;
}
int main() {
ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int n, k;
cin >> n >> k;
vector<ll>x(n),y(n);
for (int i = 0; i < n; i++) cin >> x[i] >> y[i];
cout << lagrange(x,y,k, 998244353);
return 0;
}