多项式函数曲线拟合的Scala实现

代码从以下书本中的相关章节改写而来,为简化问题,这里只用来拟合三次多项式函数。
在这里插入图片描述
M3 的代码是列主元消去法解线性方程:

import scala.math._

object M3 {
	def eliminate(ac: Array[Array[Double]], bc: Array[Double], k: Int): Unit = {
		val n = bc.length - 1
		var m: Double = 0.1
		for (i <- k + 1 to n) {
			m = ac(i)(k) / ac(k)(k)
			for (j <- k + 1 to n) {
				ac(i)(j) = ac(i)(j) - m * ac(k)(j)
			}
			bc(i) = bc(i) - m * bc(k)
			//println(bc(i))
		}		
	}

	def backSubstitute(ac: Array[Array[Double]], bc: Array[Double]): Unit = {
		val n = bc.length - 1
		bc(n) = bc(n) / ac(n)(n)
		for (i <- n - 1 to 0 by -1) {
			var s = 0.0
			for (j <- i+1 to n) {
				s = s + ac(i)(j) * bc(j)
			}
			bc(i) = (bc(i) - s) / ac(i)(i)
		}
	}

	def selectColumnPivot(ac: Array[Array[Double]], k: Int): (Int, Double) = {
		var ik: Int = 0
		var acMax: Double = 0.0
		val n = ac.length - 1
		for (i <- k to n) {
			if (abs(ac(i)(k)) > acMax) {
				acMax = abs(ac(i)(k))
				ik = i
			}
		}
		(ik, acMax)
	}
	
	def columnPivotGauss(ac: Array[Array[Double]], bc: Array[Double]): Unit = {
		val n = bc.length - 1
		var ik = 0
		var acMax = 0.00
		for (k <- 0 to n-1){
			var (ik, acMax) = selectColumnPivot(ac, k)
			if (ik != k) {
				for (j <- k to n) {
					var sa = (ac(k)(j), ac(ik)(j))
					ac(k)(j) = sa._2
					ac(ik)(j) = sa._1
				}
				var sb = (bc(k), bc(ik))
				bc(k) = sb._2
				bc(ik) = sb._1
			}
			eliminate(ac, bc, k)
		}				
		backSubstitute(ac, bc)
	}	
}

M2代码用于把线性拟合问题转化为解线性方程组的问题:

import scala.math._

object M2 {
	
	def fxi(x: Double, m: Int): Double = {
		if (m == 0) 1.0 else pow(x, m)		
	}

	def vectorATimesB(ac: Array[Double], bc: Array[Double], wc: Array[Double]): Double = {		
		val n = ac.length
		(0 until n).map(i => ac(i) * bc(i) * wc(i)).sum
	}
	
	def getVI(xc: Array[Double], m: Int): Array[Double] = {
		xc.map(i => fxi(i, m)).toArray
	}

	def formAandB(xc: Array[Double], yc: Array[Double], wc: Array[Double]):
	    (Array[Array[Double]], Array[Double]) = {
			
		val xl = xc.length - 1
		val n = 4
		var ma: Array[Array[Double]] = Array.ofDim[Double](4, 4)
		var mb: Array[Double] = new Array(4)
		
		for (cm <- 0 to 3) {
			var vi: Array[Double] = getVI(xc, cm)
			ma(cm)(cm) = vectorATimesB(vi, vi, wc)
			mb(cm) = vectorATimesB(vi, yc, wc)
			
			for (j <- cm+1 to 3) {
				var vj: Array[Double] = getVI(xc, j)
				ma(cm)(j) = vectorATimesB(vi, vj, wc)
				ma(j)(cm) = ma(cm)(j)
			}
		}			
		(ma, mb)
	}

	def sqrFit(xc: Array[Double], yc: Array[Double], wc: Array[Double]): Array[Double] = {
		val AandB = formAandB(xc, yc, wc)
		val ma = AandB._1
		var mb = AandB._2
		M3.columnPivotGauss(ma, mb)
		mb					
	}		
}

Main 用来输入两列数值与设置权重数据,并求得三次函数拟合的参数:

import scala.io.StdIn.readLine

object Main {
	def main(args: Array[String]): Unit = {		
		var xc = Array[Double](0.00, 0.25, 0.50, 0.75, 1.00)
		var yc = Array[Double](0.1, 0.35, 0.81, 1.09, 1.96)
		var wc = Array[Double](1.0, 1.0, 1.0, 1.0, 1.0)
		
		val cs = M2.sqrFit(xc, yc, wc)
		
		cs.foreach(println(_))	

		readLine()
	}	
}

拟合结果如下,依次是常数、一次项参数、二次项参数、三次项参数:
在这里插入图片描述

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值