如果输入数据长度为2,上一章的方程就无法满足需求了,需要修改方程:
z
=
w
1
x
+
w
2
y
+
b
z=w_1x+w_2y+b
z=w1x+w2y+b
数据产生器:
import matplotlib.pyplot as plt
import numpy as np
class DataGenerator2Input:
"""
线性回归数据产生器, 方程:z = w1 * x + w2 * y + b
"""
def __init__(self, w1, w2, b):
self.w1 = w1
self.w2 = w2
self.b = b
def __call__(self, data_len):
input_data = np.random.uniform(-50, 50, [data_len, 2]) # 生成 x, y
labels = self.w1 * input_data[:, 0] + self.w2 * input_data[:, 1] + self.b # 生成 z
# 加随机误差
noise = np.random.uniform(-20, 20, data_len)
labels += noise
return input_data, labels
w1, w2, b = 3.5, 7.1, 17
input_datas, labels = DataGenerator2Input(w1, w2, b)(5000)
# 可视化
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(labels, input_datas[:, 0], input_datas[:, 0])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
分段函数问题