错误概述
我在做Andrew Ng的deeplearning 的class2 week1 1.initialization作业中,在运行以下代码时报错:
plt.title("Model with Zeros initialization")
axes = plt.gca()
axes.set_xlim([-1.5,1.5])
axes.set_ylim([-1.5,1.5])
plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape)
报错的内容如下:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-37-8492eee47c7e> in <module>
3 axes.set_xlim([-1.5,1.5])
4 axes.set_ylim([-1.5,1.5])
----> 5 plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape)
/disk6/kurt/download/courses/deeplearn/deeplearning.ai/02-课后作业/02/01/assignment1/init_utils.py in plot_decision_boundary(model, X, y)
215 plt.ylabel('x2')
216 plt.xlabel('x1')
--> 217 plt.scatter(X[0, :], X[1, :], c=y.reshape(x[0, :].shape), cmap=plt.cm.Spectral)
218 plt.show()
219
/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, data, **kwargs)
2860 vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
2861 verts=verts, edgecolors=edgecolors, **({"data": data} if data
-> 2862 is not None else {}), **kwargs)
2863 sci(__ret)
2864 return __ret
/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
1808 "the Matplotlib list!)" % (label_namer, func.__name__),
1809 RuntimeWarning, stacklevel=2)
-> 1810 return func(ax, *args, **kwargs)
1811
1812 inner.__doc__ = _add_data_doc(inner.__doc__,
/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
4208 else:
4209 try: # First, does 'c' look suitable for value-mapping?
-> 4210 c_array = np.asanyarray(c, dtype=float)
4211 n_elem = c_array.shape[0]
4212 if c_array.shape in xy_shape:
/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/numpy/core/numeric.py in asanyarray(a, dtype, order)
542
543 """
--> 544 return array(a, dtype, copy=False, order=order, subok=True)
545
546
TypeError: float() argument must be a string or a number, not 'builtin_function_or_method'
解决方法
当然,这些报文可以先不管,看一个大概可以知道,应该是使用plot_decision_boundary函数的时候出错了,而在使用这个函数的时候,该函数调用了plt.scatter()函数
于是找到plt.scatter函数,发现在init_utils.py这个文件中关于此函数有两处使用,而本处的代码使用的是plot_decision_boundary()中的:
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
参考官方文档后可以发现,其实出错的原因就是因为plt,scatter做了微小的改动,在给参数c赋值的时候,必须保证其是一个行向量,而且要和前面的X[0,:],X[1,:]的维度保持一致
因此可以对此进行更改:
plt.scatter(X[0, :], X[1, :], c=y.reshape(X[0, :].shape), cmap=plt.cm.Spectral)
当然,也可以在调用plot_decision_boundary的时候对输入变量train_Y进行reshape如下:
plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape(train_X[0,:].shape))
总而言之,必须对plt.scatter函数的参数c赋值的时候达到两个要求
1.与前面输入的两个参数维度一样
2. 行向量
对于init_utils.py种另一处使用plt.scatter函数的地方仍然要求如此,但是由于此时输入的train_Y符合要求,就不再进行修改了
当然,关于其后续作业中类似的问题,处理方式也是一样的
By the way
当然,其他的方法也可以参考一下:
https://blog.csdn.net/czp_374/article/details/84331029
包括博客的评论区所提到的方法也是可以的
关于week1种另外几个作业出现的问题,参见博客:
https://blog.csdn.net/skylark0924/article/details/80322165
https://blog.csdn.net/weixin_43748786/article/details/90110180