pytroch学习记录三维绘图plot_surface
绘制三维图
流程:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#必要的库
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#这个是一个线性模型y=2x,后面计算损失函数
def forward(x):
return x * w + b
#返回预测值
def loss(x, y):
y_pred =
原创
2021-10-04 22:08:58 ·
1025 阅读 ·
0 评论