Whale Optimization Algorithm

Whale Optimization Algorithm

Author:雾雨霜星

Time:2023-02-27

前言

新学期开始了,考完试之后没什么任务,无聊之下,于是想着找些东西来看看,尝试使用Python复现各类优化算法。鲸鱼优化算法(Whale Optimization Algorithm)就是第一个尝试。大约三天,复现了鲸鱼优化算法、基于柯西分布变异和自适应权重的鲸鱼优化算法、增强型鲸鱼优化算法。

MATLAB源码:WOA (seyedalimirjalili.com)

Python源码:WhaleOptimizationAlgorithm(gitee.com)

WOA原理及其源码构建

作为一种生物启发式优化算法,WAO通过类比鲸鱼群体寻找食物的过程,来构建优化算法流程。

基本步骤包括:种群初始化、确认座头鲸、搜索觅食阶段(随机游走、螺旋行走)、搜索包围阶段。

  1. 问题描述

    对于待优化最小值问题f,解是一个d维向量,且每个维度取值范围为lb[i]<di<ub[i]。

    类比为鲸鱼种群寻找食物,每个解看作一个位置,最优解就是食物所在的位置。

  2. 初始化

    根据范围列表ub和lb随机生成n个d维向量,作为问题f的初始解。

    类比为生成一个鲸鱼种群X,每个个体具有一个位置向量X[i]。

    X = [[random.uniform(lb[i], ub[i]) for i in range(d)] for j in range(n)]
    X = numpy.asarray(X)
    
  3. 开始迭代

  4. 计算适应度以及更新种群最优解

    首先对解向量范围检查,越界的值设置回边界值。然后计算适应度,与当前最优适应度进行比较。

  5. 计算当前轮次的计算参数

    参数a与a2都是随着迭代衰减的参数。

    # param
    a = 2 * (1 - t / T)
    a2 = -1 - t / T
    

    个体位置更新

    • 生成随机数计算行动步长参数A和随机参数C

      # random seed
      r1 = random.random()
      r2 = random.random()
      # Param of "Surrounding predation stage"
      A = 2 * a * r1 - a
      C = 2 * r2
      
    • 生成行动随机概率

      用于模拟鲸鱼个体的行动随机性。

      p = random.random()
      
    • 若随机行动参数小于0.5

      此时,认为鲸鱼进入搜索包围阶段,根据行动步长系数A判断是随机游走还是朝着最优位置游走。

      if p < 0.5:  # Shrink Surround
          if abs(A) >= 1:  # Random walk
              rand_index = random.randint(0, n - 1)
              D_rand = abs(C * X[rand_index][j] - x[j])
              X[i][j] = X[rand_index][j] - A * D_rand
          else:  # Walk towards the best position
              D = abs(C * X[best_index][j] - x[j])
              X[i][j] = X[best_index][j] - A * D
      
    • 若行动参数大于0.5

      此时,认为鲸鱼进入螺旋行走阶段

      else:  # Spiral update position
          b = 1
          Dl = abs(X[best_index][j] - x[j])
          X[i][j] = Dl * math.exp(b * g) * math.cos(2 * math.pi * g) + X[best_index][j].copy()
      
  6. 记录当前轮次最优值。

WOAWC原理及其源码构建

参考文献:[1]郭振洲,王平,马云峰等.基于自适应权重和柯西变异的鲸鱼优化算法[J].微电子学与计算机,2017,34(09):20-25.

WOAWC:基于自适应权重和柯西变异的鲸鱼优化算法。

在WOA的基础上,在随机游走阶段,引入柯西变异(采用柯西逆累积分布函数)进行位置更新;在朝最优位置游走阶段内引入自适应权重。

  • 柯西变异更新位置

    通过柯西分布有很长的尾巴的特点,让鲸鱼个体朝更广的范围变异。
    X ⃗ ( t + 1 ) = X ⃗ ( t ) + A ⋅ tan ⁡ ( π ⋅ ( r − 1 2 ) ) \vec{X}(t+1)=\vec{X}(t)+A\cdot\tan(\pi\cdot(r-\frac{1}{2})) X (t+1)=X (t)+Atan(π(r21))

  • 自适应权重计算式

    根据当前迭代轮次计算权重,基于最优位置更新个体位置。
    ω = sin ⁡ ( π ⋅ t 2 T + π ) + 1 X ⃗ ( t + 1 ) = ω X ∗ ⃗ ( t ) − A ⋅ D \omega=\sin(\frac{\pi\cdot{t}}{2T}+\pi)+1\\\vec{X}(t+1)=\omega\vec{X^{*}}(t)-A\cdot{D} ω=sin(2Tπt+π)+1X (t+1)=ωX (t)AD

其中,r是0~1范围内的随机数,T是最大迭代次数,A,D参数计算与WOA原理一致。

因此,在原WOA源码的基础上进行改进:

首先是在Random walk部分:

if abs(A) >= 1:  # Random walk
    # Random number analogy is random probability
    r = random.random()
    # Introduction of Cauchy inverse cumulative distribution for variation
    X[i][j] = X[i][j] + A * math.tan(math.pi * (r - 1/2))

然后是Walk towards the best position部分:

else:  # Walk towards the best position
    D = abs(C * X[best_index][j] - x[j])
    # Adaptive weight when introducing local optimization
    w = math.sin(math.pi * (t / T) / 2 + math.pi) + 1
    X[i][j] = w * X[best_index][j] - A * D

Mohammad.H-EWOA原理及其源码构建

参考文献:Nadimi-Shahraki Mohammad H. and Zamani Hoda and Mirjalili Seyedali. Enhanced whale optimization algorithm for medical feature selection: A COVID-19 case study[J]. Computers in Biology and Medicine, 2022, 148 : 105858-105858.

目前鲸鱼优化算法领域最新的研究进展。据文中所言,该算法有一种二进制形式的变体在医药领域有极好的表现。

Mohammad.H-EWOA在WOAWC的基础上(但是A参数的计算和螺旋更新位置的计算上存在差异),引入了池化机制(Pool mechanism)、迁移搜索策略(Migrating search strategy)、优先选择的搜索策略(Preferential selecting search strategy),对包围捕食阶段进行了改进。

Pool mechanism

将表现较差的解与表现最好的解进行交叉生成新的解,放入Pool中,在新一轮迭代中使用。

  1. 基础数据

    包括计算Pool剩余未填充空间、获取最优解的max和min值。

    pool_length = len(pool)
    free_space = pool_size - pool_length
    best_max = max(X_best)
    best_min = min(X_best)
    
  2. 生成当前最优解附近的随机解

    # Random location near X_best
    X_brnd = [random.random() * (best_max - best_min) + best_min for j in range(len(X_best))]
    
  3. 生成随机交叉向量

    B = [(0 if random.random() < 0.5 else 1) for j in range(len(X_best))]  # random binary vector
    
  4. 交叉产生新解

    # cross
    Pi = [X_brnd[j] if B[j] == 1 else X_worst[i][j] for j in range(len(B))]
    P_row.append(Pi)
    
  5. 对新生成解进行重复解排查

    # Exclude recurring elements
    P = []
    [P.append(pt) for pt in P_row if pt not in P]
    
  6. 随机覆盖历史记录以更新Pool

    若剩余空间充足,则将新产生解填入Pool,否则则进行填充直至空间不足,将余下未填充解随机选取Pool历史记录解进行覆盖。

    # Updata pool
    pool_row = []
    if free_space > 0:
        if free_space >= len(P):  # Enough space left for direct storage
            pool_row = pool + P
        else:  # Put in part, and the rest will randomly cover the history
            pool_row = pool + P[:free_space]
            rm = len(P) - free_space
            random_index_list = [random.randint(0, pool_length - 1) for i in range(rm)]
            for i, rix in enumerate(random_index_list):
                pool_row[rix] = P[free_space + i]
    else:
        pool_row = pool.copy()
        random_index_list = [random.randint(0, pool_length - 1) for i in range(len(P))]
        for i, rix in enumerate(random_index_list):
            pool_row[rix] = P[free_space + i]
    
  7. 排除Pool中重复解

    # Exclude recurring elements
    pool_res = []
    [pool_res.append(pt) for pt in pool_row if pt not in pool_res]
    

Pool mechanism被封装为一个函数,只需要提供当前Pool、较差解集合、当前最优解即可使用。

在EWOA的每一轮次迭代的最后都会进行Pool更新。

Migrating search strategy

每一轮迭代中,随机选取一部分种群内个体,进行迁移搜索,以扩大搜索范围,保持种群多样性。

迁移搜索策略中位置更新的计算公式如下:
X i t + 1 = X r n d t − X b r n d t X r n d t = r 1 ∗ ( δ m a x − δ m i n ) + δ m i n X b r n d t = r 2 ∗ ( δ b e s t _ m a x − δ b e s t _ m a x ) + δ b e s t _ m a x X_{i}^{t+1}=X_{rnd}^{t}-X_{brnd}^{t}\\X_{rnd}^{t}=r1*(\delta_{max}-\delta_{min})+\delta_{min}\\X_{brnd}^{t}=r2*(\delta_{best\_max}-\delta_{best\_max})+\delta_{best\_max} Xit+1=XrndtXbrndtXrndt=r1(δmaxδmin)+δminXbrndt=r2(δbest_maxδbest_max)+δbest_max
其中,r1与r2是0~1随机数,δ代表了取值范围以及最优解的max和min值。

随机选取种群个体:

# Select a part of the population at random
selected_index = random.sample([i for i in range(n)], int(ro * n))

迁移搜索更新位置:

if i in selected_index:
    X_rnd = [random.random() * (ub[k] - lb[k]) + lb[k] for k in range(d)]
    X_brnd = [random.random() * (max(X_best) - min(X_best)) + min(X_best) for k in range(d)]
    X[i] = [X_rnd[k] - X_brnd[k] for k in range(d)]
Preferential selecting search strategy

此步骤与WOAWC中Shrink Surround阶段的Random walk步骤类似。

区别在于,此处采用Pool中的两个随机解代替了解集中的随机解和当前个体解。

  1. 计算行动步长参数

    据论文中所言,此优先搜索策略也采用了柯西变异,而论文的MATLAB源码中,将柯西变异相关计算放到了该A参数处,但论文内却没有做出详细解释。
    A = 0.5 + 0.1 tan ⁡ ( π ⋅ ( r − 1 2 ) ) A=0.5+0.1\tan(\pi\cdot(r-\frac{1}{2})) A=0.5+0.1tan(π(r21))
    代码如下:

    A = 0.5 + 0.1 * math.tan(math.pi * (random.random() - 0.5))
    while A < 0:
        A = 0.5 + 0.1 * math.tan(math.pi * (random.random() - 0.5))
    A = min(A, 1)
    
  2. 随机选取Pool池解

    参考MATLAB源码,设计者规避了index上的重复。

    r1 = random.randint(0, len(pool) - 1)
    while r1 == i: r1 = random.randint(0, len(pool) - 1)
    r2 = random.randint(0, len(pool) - 1)
    while r2 == i or r2 == r1: r2 = random.randint(0, len(pool) - 1)
    
  3. 位置更新

    for j in range(d):
        X[i][j] = Positions[i][j] + A * (C * pool[r1][j] - pool[r2][j])
    
Enriched encircling prey search strategy

丰富的包围捕食策略,类似于WOA中Shrink Surround的Walk towards the best position。

区别在于,使用Pool中随机解代替了当前个体解。
X i t + 1 = X b e s t t − A i t ⋅ D D = ∣ C i t ⋅ X b e s t t − P r n d t ∣ X_{i}^{t+1}=X_{best}^{t}-A_{i}^{t}\cdot{D}\\D=|C_{i}^{t}\cdot{X_{best}^{t}}-P_{rnd}^{t}| Xit+1=XbesttAitDD=CitXbesttPrndt
代码如下:

 for j in range(d):
        r3 = random.randint(0, len(pool) - 1)
        X[i][j] = X_best[j] - A * abs(C * X_best[j] - pool[r3][j])
整体思路
  1. 初始化

    首先随机生成解向量

    # Position initialization
    Positions = [[random.uniform(lb[i], ub[i]) for i in range(d)] for j in range(n)]
    Positions = numpy.asarray(Positions)
    X_best = Positions[0].copy()
    

    然后计算适应度并排序,选取后面几个作为较差位置并更新当前最优解

    # Record fitness
    fm = [0 for i in range(n)]
    for i, x in enumerate(Positions):  fm[i] = f(Positions[i].copy())  # Calculate fitness
    # Sort the fitness from small to large and return the sorted index list(find min value use "reverse=False")
    sorted_index = sorted(range(len(fm)), key=lambda fmx: fm[fmx], reverse=False)
    # Set best position as leader(find min value use <)
    if fm[sorted_index[0]] < best_score:
        best_score = fm[sorted_index[0]]
        X_best = Positions[sorted_index[0]].copy()
    # Find the worst position
    worst_size = int(n - 0.3 * poolSize + 1)
    X_worst = [Positions[i] for i in sorted_index[-worst_size:]]
    

    接着使用池机制生成初始pool

    # pool initialization
    pool = []
    poolSize = int(1.5 * n)
    # Pool Mechanism
    pool = poolMechanism(pool=pool, X_best=X_best.copy(), X_worst=X_worst, pool_size=poolSize)
    
  2. 开始迭代

  3. 随机选取index确定参与迁移策略个体

    # Select a part of the population at random
    selected_index = random.sample([i for i in range(n)], int(ro * n))
    
  4. 随机生成行动概率p

  5. 判断当前解个体index是否在selected_index内,若在则进行迁移搜索策略

  6. 对于不在selected_index内个体,根据行动概率p,若p>0.5则进行螺旋更新位置

  7. 对于上述个体若p<0.5,根据当前参数A的绝对值若|A|>=0.5则采用Preferential selecting search strategy,否则采用Enriched encircling prey search strategy

  8. 对于新生成的解,检查边界并将越界值返回到边界

  9. 对新生成的解计算适应度,若较上一轮次解表现更好,则将上一轮次解加入X_worst并更新解记录

    # Updata best position
    X_worst = []
    for i in range(n):
        y = f(X[i].copy())  # Calculate fitness
        if y < best_score:
            best_score = y
            X_best = X[i].copy()
        # Compare individual fitness before and after iteration to update the position
        if y < fm[i]:
            # Put the poor position in the last iteration result into the X_worst
            X_worst.append(Positions[i].copy())
            # Updata position and fitness record
            Positions[i] = X[i].copy()
            fm[i] = y
    
  10. 采用池机制更新pool

    # Pool Mechanism
    pool = poolMechanism(pool=pool, X_best=X_best.copy(), X_worst=X_worst, pool_size=poolSize)
    
  11. 记录当前轮次最优适应度

EWOA

参考文献:[1]冯文涛,宋科康.一种增强型鲸鱼优化算法[J].计算机仿真,2020,37(11):275-279+357.

注:下文称此算法为EWOA2!

与Mohammad.H的EWOA算法不同,该EWOA是在WOA的基础上改进的,主要在以下几个方面:

  • 非线性时变的自适应权重
    ω = { 1 2 [ 1 + cos ⁡ ( π t T ) ] 1 / k , t ≤ T / 2 1 2 [ 1 − cos ⁡ ( π + π t T ) ] 1 / k , x > T / 2 \omega=\begin{cases}\frac{1}{2}[1+\cos(\frac{\pi{t}}{T})]^{1/k},\quad t \leq T/2 \\\frac{1}{2}[1-\cos(\pi+\frac{\pi{t}}{T})]^{1/k},\quad x > T/2\end{cases} ω={21[1+cos(Tπt)]1/k,tT/221[1cos(π+Tπt)]1/k,x>T/2
    其中,T是最大迭代次数,k是调整参数(论文中取定值2)。

    代码如下:

    # Calculate Weught
    if t <= T / 2:
        w = 0.5 * pow(1 + math.cos(math.pi * t / T), 1 / k)
    else:
        w = 0.5 * pow(1 - math.cos(math.pi + math.pi * t / T), 1 / k)
    

    将此权重应用于Random walk和Spiral update position阶段。(后者的代码详见后面)

    if abs(A) >= 1:  # Random walk
        rand_index = random.randint(0, n - 1)
        D_rand = abs(C * X[rand_index][j] - x[j])
        positions[i][j] = X[rand_index][j] - w * A * D_rand
    
  • 差分变异微扰因子
    λ = F ( X b e s t ( t ) − X ( t ) ) \lambda=F(X_{best}(t)-X(t)) λ=F(Xbest(t)X(t))
    其中F是差异尺度系数,论文中取定值0.6。

    将此差分变异因子应用到Walk towards the best position阶段。

    # Walk towards the best position
    lamda = F * (X[best_index][j] - X[i][j])  # Differential perturbation factor
    D = abs(C * X[best_index][j] - x[j])
    positions[i][j] = X[best_index][j] - A * D + lamda
    
  • 改进的螺旋更新方式

    原本WOA中螺旋更新位置部分,其中的指数计算部分去除指数,保留指数的幂直接乘上。

    代码如下:

    b = 1
    Dl = abs(X[best_index][j] - x[j])
    positions[i][j] = w * Dl * (b * g) * math.cos(2 * math.pi * g) + X[best_index][j].copy()
    

性能对比

选取两个测试函数对上述WOA及其变体算法进行测试比较。

测试函数

f 1 ( X ) = min ⁡ ∑ i = 0 k − 1 [ 100 ∗ ( x i + 1 − x i 2 ) 2 + ( x i − 1 ) 2 ] x i ∈ [ − 30 , 30 ] , k = 20 f 2 ( X ) = min ⁡ ∑ i = 0 k ∣ x i ∣ + ∏ i = 0 k ∣ x i ∣ x i ∈ [ − 10 , 10 ] , k = 10 f_{1}(X)=\min\sum^{k-1}_{i=0}[100*(x_{i+1}-x_{i}^{2})^{2}+(x_{i}-1)^{2}]\quad x_{i}\in[-30,30],k=20\\ f_{2}(X)=\min\sum^{k}_{i=0}|x_{i}|+\prod^{k}_{i=0}|x_{i}|\quad x_{i}\in[-10,10],k=10 f1(X)=mini=0k1[100(xi+1xi2)2+(xi1)2]xi[30,30],k=20f2(X)=mini=0kxi+i=0kxixi[10,10],k=10

f1与f2的理论最优值均为0。

算法参数
# 对于f1
population_numbers1 = 100
dimension1 = 20
iterations1 = 100
# 对于f2
population_numbers2 = 100
dimension2 = 10
iterations2 = 100
结果对比
  • 测试函数1结果对比(此处EWOA为Mohammad.H’s EWOA):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nNZvWLHe-1682955776190)(http://marisa-kirisame.gitee.io/bin/pictures/WOA算法及其变体性能对比测试函数1.PNG)]

  • 测试函数2结果对比(此处EWOA为Mohammad.H’s EWOA):

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dAGIMHj2-1682955776190)(http://marisa-kirisame.gitee.io/bin/pictures/WOA算法及其变体性能对比测试函数2.PNG)]

  • 补充EWOA后的结果对比:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gDOKecJB-1682955776191)(http://marisa-kirisame.gitee.io/bin/pictures/WOA算法及其变体性能对比测试函数(补充EWOA2)].PNG)

  • 算法性能测试结果:

    参数配置:种群数量为100,迭代次数为100。

    重复100次计算求解结果和求解时间的平均值。以下记录均为平均值结果。

    Algorithmf1 resf1 timef2 resf2 time
    WOA18.64720.62400.0013420.2459
    WOAWC18.58090.63104.4×10-320.2415
    M’sEWOA17291.80.87815.99×10-60.3761
    EWOA17.83610.67541.4×10-200.2540

插曲

在复现Mohammad.H-EWOA时,我用Python写完后,运行MATLAB源码内提供的测试函数进行检验,发现MATLAB的计算结果总是可以去到1×10-68左右,而我在Python上运行的结果总是只有1×10-54。检查了很久还是没能想明白为什么会出现这个问题,以至于我得出了Python运行结果可能由于精度等问题不及MATLAB的结果。直到今天开始写记录总结,层层分析检查下,终于发现,我在最后一步中X_worst=[]的操作放在了for循环里,这意味着pool机制可能一直没发挥作用(因为没有提供X_worst),我马上修改将其移植for循环前,终于运行结果也达到了与MATLAB运行结果类似的水平。

再一次让我感受到写总结的作用。(其实就是对着代码检查没能察觉这些细微的问题…)

新BUG的发现

部分代码中,在对越界变量的设置时没有考虑原解集的重新赋值,该问题需要将check bounds部分修改为:

# Check bounds
            for j in range(len(x)):
                if x[j] > ub[j]:
                    x[j] = ub[j]
                    X[i][j] = ub[j]
                elif x[j] < lb[j]:
                    x[j] = lb[j]
                    X[i][j] = lb[j]
                else:
                    continue

rst=[]的操作放在了for循环里,这意味着pool机制可能一直没发挥作用(因为没有提供X_worst),我马上修改将其移植for循环前,终于运行结果也达到了与MATLAB运行结果类似的水平。

再一次让我感受到写总结的作用。(其实就是对着代码检查没能察觉这些细微的问题…)

新BUG的发现

部分代码中,在对越界变量的设置时没有考虑原解集的重新赋值,该问题需要将check bounds部分修改为:

# Check bounds
            for j in range(len(x)):
                if x[j] > ub[j]:
                    x[j] = ub[j]
                    X[i][j] = ub[j]
                elif x[j] < lb[j]:
                    x[j] = lb[j]
                    X[i][j] = lb[j]
                else:
                    continue
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
1. whale_optimization 在 Matlab 中实现鲸鱼优化算法(Whale Optimization Algorithm)需要先定义目标函数。以下是一个简单的示例: ``` function f = obj_fun(x) f = x(1)^2 + x(2)^2; end ``` 然后,可以使用以下代码来实现鲸鱼优化算法: ``` function [bestSol, bestFitness] = whale_optimization(obj_fun, nVar, lb, ub) % 参数说明: % obj_fun:目标函数句柄 % nVar:变量个数 % lb:每个变量的下限 % ub:每个变量的上限 % 初始化种群 popSize = 10; maxIter = 100; emptyWhale.Position = []; emptyWhale.Fitness = []; pop = repmat(emptyWhale, popSize, 1); for i = 1:popSize pop(i).Position = unifrnd(lb, ub, 1, nVar); pop(i).Fitness = obj_fun(pop(i).Position); end % 迭代优化 for it = 1:maxIter for i = 1:popSize % 更新位置 A = 2 * rand(1, nVar) - 1; C = 2 * rand(1, nVar); l = rand(); p = rand(); for j = 1:nVar if p < 0.5 if abs(A(j)) >= 1 rand_leader_index = floor(popSize * rand() + 1); X_rand = pop(rand_leader_index).Position; D_X_rand = abs(C(j) * X_rand(j) - pop(i).Position(j)); pop(i).Position(j) = X_rand(j) - A(j) * D_X_rand; else D_Leader = abs(C(j) * bestSol.Position(j) - pop(i).Position(j)); pop(i).Position(j) = bestSol.Position(j) - A(j) * D_Leader; end else dist = abs(bestSol.Position(j) - pop(i).Position(j)); pop(i).Position(j) = dist * exp(b * l) * cos(2 * pi * l) + bestSol.Position(j); end end % 对位置进行限制 pop(i).Position = max(pop(i).Position, lb); pop(i).Position = min(pop(i).Position, ub); % 更新适应度 pop(i).Fitness = obj_fun(pop(i).Position); % 更新最优解 if pop(i).Fitness < bestSol.Fitness bestSol = pop(i); end end end % 返回最优解及其适应度 bestFitness = bestSol.Fitness; bestSol = bestSol.Position; end ``` 2. gru_loss 在 Matlab 中实现 GRU 模型的损失函数需要使用交叉熵损失函数。以下是一个简单的示例: ``` function loss = gru_loss(y_pred, y_true) % 参数说明: % y_pred:模型预测结果,大小为 [batch_size, num_classes] % y_true:真实标签,大小为 [batch_size, num_classes] eps = 1e-10; y_pred = max(min(y_pred, 1 - eps), eps); % 防止出现 log(0) 的情况 loss = -sum(y_true .* log(y_pred), 2); end ``` 3. gru_predict 在 Matlab 中使用 GRU 模型进行预测需要先定义模型。以下是一个示例: ``` function model = gru_model(num_classes) % 参数说明: % num_classes:分类数 inputSize = 100; hiddenSize = 64; outputSize = num_classes; model = struct(); model.Wx = randn(inputSize, hiddenSize); model.Wh = randn(hiddenSize, hiddenSize); model.b = zeros(1, hiddenSize); model.Wy = randn(hiddenSize, outputSize); model.by = zeros(1, outputSize); end ``` 然后,可以使用以下代码进行预测: ``` function y_pred = gru_predict(model, X) % 参数说明: % model:GRU 模型 % X:输入数据,大小为 [batch_size, inputSize, sequence_length] [batch_size, inputSize, sequence_length] = size(X); hiddenSize = size(model.Wx, 2); outputSize = size(model.Wy, 2); h = zeros(batch_size, hiddenSize); for t = 1:sequence_length x_t = reshape(X(:, :, t), [batch_size, inputSize]); z_t = sigmoid(x_t * model.Wx + h * model.Wh + model.b); r_t = sigmoid(x_t * model.Wxr + h * model.Whr + model.br); h_tilde_t = tanh(x_t * model.Wxh + (r_t .* h) * model.Whh + model.bh); h = (1 - z_t) .* h + z_t .* h_tilde_t; end y_pred = softmax(h * model.Wy + model.by); end ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值