pytorch基础知识十三【2D函数优化实例】

1. 原函数

在这里插入图片描述

2. 求最小值

在这里插入图片描述

3. 代码

import  numpy as np
from    mpl_toolkits.mplot3d import Axes3D
from    matplotlib import pyplot as plt
import  torch

# 安装了多个版本的numpy会报错,使用以下语句含义为允许多个副本存在
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# 忽略警告。matplotlib版本有更新,旧的版本被弃用,会报警告
import warnings
warnings.filterwarnings('ignore')



def himmelblau(x):
    return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2


x = np.arange(-6, 6, 0.1)
y = np.arange(-6, 6, 0.1)
print('x,y range:', x.shape, y.shape)
X, Y = np.meshgrid(x, y)
print('X,Y maps:', X.shape, Y.shape)
Z = himmelblau([X, Y])

fig = plt.figure('himmelblau')
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, Z)
ax.view_init(60, -30)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()


# [1., 0.], [-4, 0.], [4, 0.]
x = torch.tensor([-4., 0.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)
for step in range(20000):

    pred = himmelblau(x)	# 获得预测值

    optimizer.zero_grad()   # 梯度信息清零
    pred.backward()			# 求得x和y的梯度信息
    optimizer.step()		# x和y梯度更新

    if step % 2000 == 0:
        print ('step {}: x = {}, f(x) = {}'
               .format(step, x.tolist(), pred.item()))



输出结果:
	x,y range: (120,) (120,)
	X,Y maps: (120, 120) (120, 120)
	step 0: x = [-3.999000072479248, -0.0009999999310821295], f(x) = 146.0
	step 2000: x = [-3.526559829711914, -2.5002429485321045], f(x) = 19.4503231048584
	step 4000: x = [-3.777446746826172, -3.2777843475341797], f(x) = 0.0012130826944485307
	step 6000: x = [-3.7793045043945312, -3.283174753189087], f(x) = 5.636138666886836e-09
	step 8000: x = [-3.779308319091797, -3.28318190574646], f(x) = 7.248672773130238e-10
	step 10000: x = [-3.7793095111846924, -3.28318452835083], f(x) = 8.822098607197404e-11
	step 12000: x = [-3.7793102264404297, -3.2831854820251465], f(x) = 8.185452315956354e-12
	step 14000: x = [-3.7793102264404297, -3.2831859588623047], f(x) = 0.0
	step 16000: x = [-3.7793102264404297, -3.2831859588623047], f(x) = 0.0
	step 18000: x = [-3.7793102264404297, -3.2831859588623047], f(x) = 0.0


在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值