【知识建设】线性插值——代码导向

目标

client端生成预测的方式,考虑将网络的输入(start, end)先转化为该点对应的prob,然后通过线性插值的方法,将N个样本扩充为server端聚合的标准维度

线性插值

因为要实现的维度看起来无法直接掉包(torch.nn.functional.interpolate),这里找到一个博主的python实现,进行注释与理解

import matplotlib.pyplot as plt

"""
@brief:   计算n阶差商 f[x0, x1, x2 ... xn]
@param:   xi   所有插值节点的横坐标集合 在vmr任务中就是N个(start, end)对
@param:   fi   所有插值节点的纵坐标集合 在vmr任务中就是N个经过fc转化的prob值
@return:  返回xi的i阶差商(i为xi长度减1) 暂时不太明白
@notice:  a. 必须确保xi与fi长度相等
          b. 由于用到了递归,所以留意不要爆栈了.
          c. 递归减递归(每层递归包含两个递归函数), 每层递归次数呈二次幂增长,总次数是一个满二叉树的所有节点数量(所以极易栈溢出)                 
"""
def get_order_diff_quot(xi = [], fi = []):
    if len(xi) > 2 and len(fi) > 2:
        return (get_order_diff_quot(xi[:len(xi) - 1], fi[:len(fi) - 1]) - get_order_diff_quot(xi[1:len(xi)], fi[1:len(fi)])) / float(xi[0] - xi[-1])
    return (fi[0] - fi[1]) / float(xi[0] - xi[1])




"""
@brief:  获得Wi(x)函数;
         Wi的含义举例 W1 = (x - x0); W2 = (x - x0)(x - x1); W3 = (x - x0)(x - x1)(x - x2)
@param:  i  i阶(i次多项式)
@param:  xi  所有插值节点的横坐标集合
@return: 返回Wi(x)函数
"""
def get_Wi(i = 0, xi = []):
    def Wi(x):
        result = 1.0
        for each in range(i):
            result *= (x - xi[each])
        return result
    return Wi
"""
@brief: 获得牛顿插值函数
@
"""
def get_Newton_inter(xi = [], fi = []):  
# 这个函数比较难理解,第一次调用的时候,并不会执行下面的Newton_inter
# 而是把Newton_inter这个函数作为一个对象赋予了声明的变量
# 之后再调用这个变量并传入Newton_inter需要的参数x才是真正执行了下面了过程
    def Newton_inter(x):
        result = fi[0]
        for i in range(2, len(xi)):
            result += (get_order_diff_quot(xi[:i], fi[:i]) * get_Wi(i-1, xi)(x))
        return result
    return Newton_inter
"""
demo:
"""
if __name__ == '__main__':

    ''' 插值节点, 这里用二次函数生成插值节点,每两个节点x轴距离位10 '''
    sr_x = [i for i in range(-50, 51, 10)]
    # 在[-50, 50]间按照10为step取值赋予x,则sr_x维度为(11,)
    # 这里需要给定client端的自变量(start, end)对,长度为N
    sr_fx = [i**2 for i in sr_x]                
    # sr_x中值的平方,维度为(11,)
    # 这里需要给定client端的预测值prob,长度为N

    Nx = get_Newton_inter(sr_x, sr_fx)
    # 获得插值函数

    tmp_x = [i for i in range(-50, 51)]
    # 要取样的位置(X, Y)对,或者positions,大小为100*100
    tmp_y = [Nx(i) for i in tmp_x]               
    # 根据插值函数获得要预测的分布prob,大小为10000

    ''' 画图 '''
    plt.figure("I love china")
    ax1 = plt.subplot(111)
    plt.sca(ax1)
    plt.plot(sr_x, sr_fx, linestyle = '', marker='o', color='b')
    plt.plot(tmp_x, tmp_y, linestyle = '--', color='r')
    plt.show()

但是目前这是单变量的线性插值,我需要的是双变量的线性插值,于是简单查了一下资料,发现插值法是数学分析中的一大类方法,按照插值是否在提供的边界之内,可以分为内插和外插;按照插值作用的函数维度,可以分为一维插值法和二维插值法:

一维插值法二维插值法
牛顿插值法最近邻插值法
拉格朗日插值法双线性插值法
三次多项式插值法

以上仅是简单列举,还有很多方法

插值法在python中已经有很多实现方式了,比较常用的是scipy包和torch

torch包提供的方法是torch.nn.functional.interpolate,比较局限,大致看了一眼,需要提供scale_factor参数(放缩倍数),而且不知道是否支持外插,不太符合我的使用场景

scipy包提供的方法就比较多了,总能找到适合自己的一款,这里简单总结一下:

interp1dinterp2dgriddate
一维插值二维插值非结构化数据插值

导入包:

from scipy.interpolate import interp1d, interp2d, griddata

当然,scipy提供的插值包不止上面这三种,大家可以按需选取

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值