处理plt.scatter报错的方法

错误概述

我在做Andrew Ng的deeplearning 的class2 week1 1.initialization作业中,在运行以下代码时报错:

plt.title("Model with Zeros initialization")
axes = plt.gca()
axes.set_xlim([-1.5,1.5])
axes.set_ylim([-1.5,1.5])
plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape)

报错的内容如下:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-37-8492eee47c7e> in <module>
      3 axes.set_xlim([-1.5,1.5])
      4 axes.set_ylim([-1.5,1.5])
----> 5 plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape)

/disk6/kurt/download/courses/deeplearn/deeplearning.ai/02-课后作业/02/01/assignment1/init_utils.py in plot_decision_boundary(model, X, y)
    215     plt.ylabel('x2')
    216     plt.xlabel('x1')
--> 217     plt.scatter(X[0, :], X[1, :], c=y.reshape(x[0, :].shape), cmap=plt.cm.Spectral)
    218     plt.show()
    219 

/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, data, **kwargs)
   2860         vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
   2861         verts=verts, edgecolors=edgecolors, **({"data": data} if data
-> 2862         is not None else {}), **kwargs)
   2863     sci(__ret)
   2864     return __ret

/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1808                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1809                         RuntimeWarning, stacklevel=2)
-> 1810             return func(ax, *args, **kwargs)
   1811 
   1812         inner.__doc__ = _add_data_doc(inner.__doc__,

/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, **kwargs)
   4208         else:
   4209             try:  # First, does 'c' look suitable for value-mapping?
-> 4210                 c_array = np.asanyarray(c, dtype=float)
   4211                 n_elem = c_array.shape[0]
   4212                 if c_array.shape in xy_shape:

/disk6/kurt/anaconda3-5/envs/mytorch/lib/python3.6/site-packages/numpy/core/numeric.py in asanyarray(a, dtype, order)
    542 
    543     """
--> 544     return array(a, dtype, copy=False, order=order, subok=True)
    545 
    546 

TypeError: float() argument must be a string or a number, not 'builtin_function_or_method'

解决方法

当然,这些报文可以先不管,看一个大概可以知道,应该是使用plot_decision_boundary函数的时候出错了,而在使用这个函数的时候,该函数调用了plt.scatter()函数
于是找到plt.scatter函数,发现在init_utils.py这个文件中关于此函数有两处使用,而本处的代码使用的是plot_decision_boundary()中的:

    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

参考官方文档后可以发现,其实出错的原因就是因为plt,scatter做了微小的改动,在给参数c赋值的时候,必须保证其是一个行向量,而且要和前面的X[0,:],X[1,:]的维度保持一致
因此可以对此进行更改:

    plt.scatter(X[0, :], X[1, :], c=y.reshape(X[0, :].shape), cmap=plt.cm.Spectral)

当然,也可以在调用plot_decision_boundary的时候对输入变量train_Y进行reshape如下:

plot_decision_boundary(lambda x: predict_dec(parameters, x.T), train_X, train_Y.reshape(train_X[0,:].shape))

总而言之,必须对plt.scatter函数的参数c赋值的时候达到两个要求
1.与前面输入的两个参数维度一样
2. 行向量
对于init_utils.py种另一处使用plt.scatter函数的地方仍然要求如此,但是由于此时输入的train_Y符合要求,就不再进行修改了
当然,关于其后续作业中类似的问题,处理方式也是一样的

By the way

当然,其他的方法也可以参考一下:
https://blog.csdn.net/czp_374/article/details/84331029
包括博客的评论区所提到的方法也是可以的
关于week1种另外几个作业出现的问题,参见博客:
https://blog.csdn.net/skylark0924/article/details/80322165
https://blog.csdn.net/weixin_43748786/article/details/90110180

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值