拉格朗日插值学习笔记

  • 问题

    给定 n n n个点 ( x i , y i ) (x_i,y_i) (xi,yi),其中 x i x_i xi互不相同,用这些点确定一个多项式 f ( x ) f(x) f(x)

  • 求解

    用这 n n n个点可以构造出一个 n × n n \times n n×n的线性方程组,显然系数行列式是一个范德蒙德行列式,又 x i x_i xi互不相同,所以这 n n n个点可以唯一确定一个小于等于 n − 1 n-1 n1次的多项式。

    我们可以构造出 g i ( x ) , 1 ≤ i ≤ n g_i(x),1 \le i \le n gi(x),1in,满足 g i ( x i ) = y i g_i(x_i)=y_i gi(xi)=yi g i ( x j ) = 0 , j ≠ i g_i(x_j)=0,j \ne i gi(xj)=0,j=i,那么 f ( x ) = ∑ i = 1 n g i ( x ) \displaystyle f(x)=\sum_{i=1}^ng_i(x) f(x)=i=1ngi(x),因为把这 n n n个点代进去都成立,又由于 f ( x ) f(x) f(x)的唯一性,所以可以这样确定 f ( x ) f(x) f(x)

    构造 g i ( x ) = y i ∏ j ≠ i x − x j x i − x j \displaystyle g_i(x)=y_i \prod_{j \ne i}\frac{x-x_j}{x_i-x_j} gi(x)=yij=ixixjxxj即可,那么 f ( x ) = ∑ i = 1 n y i ∏ j ≠ i x − x j x i − x j \displaystyle f(x)=\sum_{i=1}^ny_i \prod_{j \ne i}\frac{x-x_j}{x_i-x_j} f(x)=i=1nyij=ixixjxxj

  • 例题
    • P4781 【模板】拉格朗日插值

      k k k代入得 f ( k ) = ∑ i = 1 n y i ∏ j ≠ i k − x j x i − x j \displaystyle f(k)=\sum_{i=1}^ny_i \prod_{j \ne i}\frac{k-x_j}{x_i-x_j} f(k)=i=1nyij=ixixjkxj

      复杂度: O ( n 2 ) O(n^2) O(n2)
      代码:
      #include<cstdio>
      #include<iostream>
      #include<algorithm>
      #include<cstring>
      #include<cmath>
      #include<vector>
      #include<queue>
      #include<stack>
      #include<map>
      #include<set>
      #include<string>
      #include<bitset>
      #include<sstream>
      #include<ctime>
      //#include<chrono>
      //#include<random>
      //#include<unordered_map>
      using namespace std;
      
      #define ll long long
      #define ls o<<1
      #define rs o<<1|1
      #define pii pair<int,int>
      #define fi first
      #define se second
      #define pb push_back
      #define mp make_pair
      #define sz(x) (int)(x).size()
      #define all(x) (x).begin(),(x).end()
      const double pi=acos(-1.0);
      const double eps=1e-6;
      const int mod=998244353;
      const int INF=0x3f3f3f3f;
      const int maxn=2e3+5;
      ll read(){
      	ll x=0,f=1;
      	char ch=getchar();
      	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
      	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
      	return x*f;
      }
      int n,k;
      ll x[maxn],y[maxn];
      ll f[maxn];
      ll qpow(ll a,ll p=mod-2){
      	ll res=1;
      	while(p){
      		if(p&1)res=res*a%mod;
      		a=a*a%mod;
      		p>>=1;
      	}
      	return res;
      }
      //f(k)
      ll Lagrange_Interpolation1(ll *x,ll *y,int n,ll k){
      	ll res=0;
      	for(int i=1;i<=n;i++){
      		ll up=y[i],down=1;
      		for(int j=1;j<=n;j++){
      			if(j==i)continue;
      			down=down*(x[i]-x[j])%mod;
      			up=up*(k-x[j])%mod;
      		}
      		up=(up+mod)%mod;
      		down=(down+mod)%mod;
      		res+=up*qpow(down)%mod;
      		res%=mod;
      	}
      	return res;
      }
      int main(void){
      	// freopen("in.txt","r",stdin);
      	scanf("%d%d",&n,&k);
      	for(int i=1;i<=n;i++){
      		scanf("%lld%lld",&x[i],&y[i]);
      	}
      	printf("%lld\n",Lagrange_Interpolation1(x,y,n,k));
      	return 0;
      }
      
    • P5158 【模板】多项式快速插值的前30分

      上一个题目是求解一个 f ( k ) f(k) f(k),现在我们需要得到 f ( x ) f(x) f(x)的各项系数。

      模拟多项式的运算即可,具体来说,令 g i ( x ) = ∏ j ≠ i ( x − x j ) , h ( x ) = ∏ j = 1 n ( x − x j ) , c i = y i ∏ j ≠ i 1 x i − x j \displaystyle g_i(x)=\prod_{j \ne i}(x-x_j),h(x)=\prod_{j=1}^n(x-x_j),c_i=y_i\prod_{j \ne i}\frac{1}{x_i-x_j} gi(x)=j=i(xxj),h(x)=j=1n(xxj),ci=yij=ixixj1,那么 g i ( x ) = h ( x ) x − x i , f ( x ) = ∑ i = 1 n c i g i ( x ) \displaystyle g_i(x)=\frac{h(x)}{x-x_i},f(x)=\sum_{i=1}^nc_ig_i(x) gi(x)=xxih(x),f(x)=i=1ncigi(x)

      计算 h ( x ) h(x) h(x)的时候利用 d p dp dp,令 h i , j h_{i,j} hi,j为前 i i i个式子拼成的多项式中 x j x^j xj的系数,转移方程为 h i , j = h i − 1 , j − 1 + h i − 1 , j ⋅ ( − x i ) h_{i,j}=h_{i-1,j-1}+h_{i-1,j} \cdot (-x_i) hi,j=hi1,j1+hi1,j(xi),这里可以用滚动数组降维。

      计算 g i ( x ) g_i(x) gi(x)的时候也是利用 d p dp dp,令 g i , j g_{i,j} gi,j表示 g i ( x ) g_i(x) gi(x) x j x^j xj的系数,转移方程为 g i , j = ( h n , j − g i , j − 1 ) ⋅ ( − x i ) g_{i,j}=(h_{n,j}-g_{i,j-1})\cdot (-x_i) gi,j=(hn,jgi,j1)(xi)

      复杂度: O ( n 2 ) O(n^2) O(n2)
      代码:
      #include<cstdio>
      #include<iostream>
      #include<algorithm>
      #include<cstring>
      #include<cmath>
      #include<vector>
      #include<queue>
      #include<stack>
      #include<map>
      #include<set>
      #include<string>
      #include<bitset>
      #include<sstream>
      #include<ctime>
      //#include<chrono>
      //#include<random>
      //#include<unordered_map>
      using namespace std;
      
      #define ll long long
      #define ls o<<1
      #define rs o<<1|1
      #define pii pair<int,int>
      #define fi first
      #define se second
      #define pb push_back
      #define mp make_pair
      #define sz(x) (int)(x).size()
      #define all(x) (x).begin(),(x).end()
      const double pi=acos(-1.0);
      const double eps=1e-6;
      const int mod=998244353;
      const int INF=0x3f3f3f3f;
      const int maxn=100005;
      ll read(){
      	ll x=0,f=1;
      	char ch=getchar();
      	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
      	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
      	return x*f;
      }
      int n;
      ll f[maxn],x[maxn],y[maxn];
      ll g[maxn],c[maxn],h[maxn];
      ll qpow(ll a,ll p=mod-2){
      	ll res=1;
      	while(p){
      		if(p&1)res=res*a%mod;
      		a=a*a%mod;
      		p>>=1;
      	}
      	return res;
      }
      void Lagrange_Interpolation2(ll *x,ll *y,int n,ll *f){
      	fill(f,f+n,0);
      	fill(h,h+n+1,0);
      	for(int i=1;i<=n;i++){
      		c[i]=y[i];
      		ll tmp=1;
      		for(int j=1;j<=n;j++){
      			if(j==i)continue;
      			tmp=tmp*(x[i]-x[j])%mod;
      		}
      		tmp=(tmp+mod)%mod;
      		c[i]=c[i]*qpow(tmp)%mod;
      	}
      	h[0]=1;
      	for(int i=1;i<=n;i++){
      		for(int j=n;j>0;j--){
      			h[j]=(h[j-1]+h[j]*(mod-x[i])%mod)%mod;
      		}
      		h[0]=h[0]*(mod-x[i])%mod;
      	}
      	for(int i=1;i<=n;i++){
      		ll inv=qpow(mod-x[i]);
      		g[0]=h[0]*inv%mod;
      		for(int j=1;j<n;j++){
      			g[j]=(h[j]-g[j-1])*inv%mod;
      		}
      		for(int j=0;j<n;j++){
      			f[j]+=c[i]*g[j]%mod;
      			f[j]=(f[j]+mod)%mod;
      		}
      	}
      }
      int main(void){
      	// freopen("in.txt","r",stdin);
      	scanf("%d",&n);
      	for(int i=1;i<=n;i++){
      		scanf("%lld%lld",&x[i],&y[i]);
      	}
      	Lagrange_Interpolation2(x,y,n,f);
      	for(int i=0;i<n;i++){
      		printf("%lld ",f[i]);
      	}
      	puts("");
      	return 0;
      }
      
    • 练习题
      • codeforces622F The Sum of the k-th Powers

        ​ 令 f ( x ) = ∑ i = 1 x i k \displaystyle f(x)=\sum_{i=1}^xi^k f(x)=i=1xik,那么我们要求的就是 f ( n ) f(n) f(n)。易知 f ( x ) f(x) f(x)是一个 k + 1 k+1 k+1次的多项式,所以我们可以用1,2,…… k + 2 k+2 k+2进行插值,但直接插的复杂度为 O ( n 2 ) O(n^2) O(n2),不能通过此题。因为插值用的点是连续的,所以可以不用每次重新算,可以递推出来。

        复杂度: O ( k l o g k ) O(klogk) O(klogk)

      #include<cstdio>
      #include<iostream>
      #include<algorithm>
      #include<cstring>
      #include<cmath>
      #include<vector>
      #include<queue>
      #include<stack>
      #include<map>
      #include<set>
      #include<string>
      #include<bitset>
      #include<sstream>
      #include<ctime>
      //#include<chrono>
      //#include<random>
      //#include<unordered_map>
      using namespace std;
      
      #define ll long long
      #define ls o<<1
      #define rs o<<1|1
      #define pii pair<int,int>
      #define fi first
      #define se second
      #define pb push_back
      #define mp make_pair
      #define sz(x) (int)(x).size()
      #define all(x) (x).begin(),(x).end()
      const double pi=acos(-1.0);
      const double eps=1e-6;
      const int mod=1e9+7;
      const int INF=0x3f3f3f3f;
      const int maxn=1e6+5;
      ll read(){
      	ll x=0,f=1;
      	char ch=getchar();
      	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
      	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
      	return x*f;
      }
      int n,k;
      ll qpow(ll a,ll p=mod-2){
      	ll res=1;
      	while(p){
      		if(p&1)res=res*a%mod;
      		a=a*a%mod;
      		p>>=1;
      	}
      	return res;
      }
      //f(k)
      ll x[maxn],y[maxn];
      ll Lagrange_Interpolation1(ll *x,ll *y,int n,ll k){
      	ll res=0;
      	ll p1=1;
      	for(int i=1;i<=n;i++){
      		p1=p1*(k-x[i])%mod;
      	}
      	ll p2=1;
      	for(int i=2;i<=n-1;i++){
      		p2=p2*i%mod;
      	}
      	for(int i=1;i<=n;i++){
      		ll up=y[i]*p1%mod*qpow(k-x[i])%mod;
      		ll down=qpow(((n-i)%2==0)?p2:(mod-p2));
      		res+=up*down%mod;
      		res%=mod;
      		if(i!=n){
      			p2=p2*qpow(n-i)%mod*i%mod;
      		}
      	}
      	return res;
      }
      int main(void){
      	// freopen("in.txt","r",stdin);
      	scanf("%d%d",&n,&k);
      	if(n<=k+2){
      		ll ans=0;
      		for(int i=1;i<=n;i++){
      			ans=(ans+qpow(i,k))%mod;
      		}
      		printf("%lld\n",ans);
      	}
      	else{
      		ll tmp=0;
      		for(int i=1;i<=k+2;i++){
      			x[i]=i;
      			tmp=(tmp+qpow(i,k))%mod;
      			y[i]=tmp;
      		}
      		printf("%lld\n",Lagrange_Interpolation1(x,y,k+2,n));
      	}
      	return 0;
      }
      
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值