【PaddlePaddle (飞桨)】【Image Classification】【Basic OpenCV Operations】【Building a Simple Neural Network】【Study Notes】


Preface

As a white-hat hacker, this is my first time crossing fields to study AI. That reflects not only the pull of AI's capabilities, but also that AI is one of today's most cutting-edge and promising technologies. Under the guidance of the PaddlePaddle pilot class, this is also my first contact with CV (computer vision), and I hope that through hard work I can broaden my technical skills in many directions.


I. Concepts and Basic Operations of Image Processing

1. Grayscale images

The code is as follows (example):

# Import dependencies
%matplotlib inline
import numpy as np
import cv2
import matplotlib.pyplot as plt
import paddle
from PIL import Image

In [2]

# Load a grayscale image of a handwritten digit
# from the MNIST dataset built into Paddle 2.0 (discussed further in Section 3)
from paddle.vision.datasets import MNIST
# Use the test split
mnist = MNIST(mode='test')
# Take the first image of the test set
sample = mnist[0]
# Print the first image's size and label
print(sample[0].size, sample[1])
# Show the first digit of the test set
plt.imshow(mnist[0][0])
print('The handwritten digit is:', mnist[0][1])

Summary

Resolution = number of pixels horizontally * number of pixels vertically

Screen resolution

For example, a screen resolution of 1024×768 means the device's screen has 1024 pixels horizontally and 768 pixels vertically. A pixel has no fixed physical size; the size of a unit pixel differs between devices.

For example, two screens of the same physical size can have different resolutions. The higher-resolution screen has more pixels (color blocks), each covering a smaller area, so it can display a more detailed picture; the lower-resolution screen has fewer, larger pixels and shows less detail.

Image resolution

For example, an image resolution of 500×200 means that, displayed at 1:1, the image spans 500 pixels (color blocks) horizontally and 200 vertically.

On the same device, the higher an image's resolution, the larger its area at 1:1 zoom; the lower the resolution, the smaller the area (think of the image's pixels mapping one-to-one onto the screen's pixels).

But why do an image's pixel blocks appear bigger when you zoom past 100%? The device interpolates extra pixels algorithmically: each of the big square blocks you see when zoomed way in reads as one image pixel, but is actually rendered with many screen pixels. Likewise, shrinking an image below 100% algorithmically discards image pixels.
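The resolution arithmetic above can be sanity-checked in a few lines of plain Python (the numbers are taken from the examples in this section):

```python
# Resolution = horizontal pixels * vertical pixels
screen_w, screen_h = 1024, 768   # screen resolution from the example above
img_w, img_h = 500, 200          # image resolution from the example above

screen_pixels = screen_w * screen_h
img_pixels = img_w * img_h

print(screen_pixels)  # 786432
print(img_pixels)     # 100000
# At 1:1 zoom the image covers this fraction of the screen's pixels
print(round(img_pixels / screen_pixels, 4))  # 0.1272
```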

II. Exercises from the First Image Processing Lesson

In [1]

import cv2
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

In [2]

filename = '1.jpg'
## [Load an image from a file]
img = cv2.imread(filename)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)

<matplotlib.image.AxesImage at 0x7f684ccc5590>

<Figure size 432x288 with 1 Axes>

In [3]

print(img.shape)

(350, 350, 3)

1. Image scaling
In [4]

class Resize:
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # Resize the image to the target size
        return cv2.resize(img, self.size)


resize = Resize((600, 600))
img2 = resize(img)
plt.imshow(img2)

<matplotlib.image.AxesImage at 0x7f684cbfd690>

<Figure size 432x288 with 1 Axes>

2. Image flipping
In [5]

import random

class Flip:
    def __init__(self, mode):
        assert mode in [-1, 0, 1], "mode should be a value in [-1, 0, 1]"
        self.mode = mode

    def __call__(self, img):
        # Flip with 50% probability; mode 0: vertical, 1: horizontal, -1: both
        if random.randint(0, 1) == 1:
            return cv2.flip(img, self.mode)
        return img

filename = '1.jpg'
## [Load an image from a file]
img = cv2.imread(filename)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

flip=Flip(mode=0)
img2=flip(img)
plt.imshow(img2)

<matplotlib.image.AxesImage at 0x7f684cbeb850>

<Figure size 432x288 with 1 Axes>

3. Image rotation
In [6]

class Rotate:
    def __init__(self, degree, size):
        self.degree = degree
        self.size = size  # scale factor

    def __call__(self, img):
        height, width, _ = img.shape
        # getRotationMatrix2D expects the center as (x, y) = (width/2, height/2)
        mat_rotate = cv2.getRotationMatrix2D((width * 0.5, height * 0.5), self.degree, self.size)
        # warpAffine expects the output size as (width, height)
        return cv2.warpAffine(img, mat_rotate, (width, height))


rotate = Rotate(45, 0.7)
img2=rotate(img)
plt.imshow(img2)

<matplotlib.image.AxesImage at 0x7f684cb555d0>

<Figure size 432x288 with 1 Axes>

4. Image brightness adjustment
In [7]

class Brightness:
    def __init__(self, brightness_factor):
        self.brightness_factor = brightness_factor

    def __call__(self, img):
        # Blend the image with an all-black image of the same shape:
        # dst = img*factor + black*(1-factor) + 3
        black = np.zeros(img.shape, img.dtype)
        return cv2.addWeighted(img, self.brightness_factor, black, 1 - self.brightness_factor, 3)




brightness=Brightness(0.6)
img2=brightness(img)
plt.imshow(img2)

<matplotlib.image.AxesImage at 0x7f684cabf3d0>

<Figure size 432x288 with 1 Axes>

5. Random erasing (random occlusion)
In [9]

import random
import math

class RandomErasing(object):
    def __init__(self, EPSILON=0.5, sl=0.02, sh=0.4, r1=0.3,
                 mean=[0., 0., 0.]):
        self.EPSILON = EPSILON  # probability of applying the erasing
        self.mean = mean        # fill value for the erased patch
        self.sl = sl            # lower bound of the erased-area ratio
        self.sh = sh            # upper bound of the erased-area ratio
        self.r1 = r1            # lower bound of the aspect ratio

    def __call__(self, img):
        if random.uniform(0, 1) > self.EPSILON:
            return img

        for attempt in range(100):
            area = img.shape[0] * img.shape[1]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            # h is the patch height (rows), w its width (columns)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if h < img.shape[0] and w < img.shape[1]:
                x1 = random.randint(0, img.shape[0] - h)  # top row
                y1 = random.randint(0, img.shape[1] - w)  # left column
                if img.shape[2] == 3:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
                else:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                return img

        return img


erase = RandomErasing()
img2=erase(img)
plt.imshow(img2)    

<matplotlib.image.AxesImage at 0x7f684ca23fd0>

<Figure size 432x288 with 1 Axes>

III. More Advanced Operations with the OpenCV Library

In [2]

import math
import random
import numpy as np
%matplotlib inline
import cv2
import matplotlib.pyplot as plt

In [3]

# Read an image
img = cv2.imread('cat.png')
# Convert the color channels
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

Basic image operations

Learn basic operations such as the ROI (region of interest) and channel splitting/merging.

ROI

ROI: Region of Interest.

Extracting an ROI is very simple: just index the desired range of the image.
In [4]

# Crop the cat-face ROI
face = img[0:740, 400:1000]
plt.imshow(face)

<matplotlib.image.AxesImage at 0x7fb21cc61cd0>

<Figure size 432x288 with 1 Axes>

Channel splitting and merging

The B, G, and R channels of a color image can be accessed separately, and three separate channels can be merged back into one image, using cv2.split() and cv2.merge() respectively:
In [5]

# Read an image
img = cv2.imread('lena.jpg')

In [6]

# Split the channels
b, g, r = cv2.split(img)

In [7]

# Merge the channels
img = cv2.merge((b, g, r))

In [8]

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)

<matplotlib.image.AxesImage at 0x7fb1ea902750>

<Figure size 432x288 with 1 Axes>

In [9]

RGB_Image=cv2.merge([b,g,r])
RGB_Image = cv2.cvtColor(RGB_Image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(12,12))
# Show each channel
plt.subplot(141)
plt.imshow(RGB_Image,'gray')
plt.title('RGB_Image')
plt.subplot(142)
plt.imshow(r,'gray')
plt.title('R_Channel')
plt.subplot(143)
plt.imshow(g,'gray')
plt.title('G_Channel')
plt.subplot(144)
plt.imshow(b,'gray')
plt.title('B_Channel')

Text(0.5,1,'B_Channel')

<Figure size 864x864 with 4 Axes>

Color space conversion

The most common color space conversions are:

    RGB or BGR to grayscale (COLOR_RGB2GRAY, COLOR_BGR2GRAY)
    RGB or BGR to YCrCb (a.k.a. YCC) (COLOR_RGB2YCrCb, COLOR_BGR2YCrCb)
    RGB or BGR to HSV (COLOR_RGB2HSV, COLOR_BGR2HSV)
    RGB or BGR to Luv (COLOR_RGB2Luv, COLOR_BGR2Luv)
    Grayscale to RGB or BGR (COLOR_GRAY2RGB, COLOR_GRAY2BGR)

    Tip: color conversion is just arithmetic; the most common grayscale formula is gray = R*0.299 + G*0.587 + B*0.114.

    Reference: color spaces in OpenCV

In [10]

img = cv2.imread('lena.jpg')
# Convert to grayscale
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# Save the grayscale image
cv2.imwrite('img_gray.jpg', img_gray)

True

Tracking objects of a specific color

HSV is a model commonly used for color recognition; colors are easier to separate in HSV than in BGR. The conversion mode is COLOR_BGR2HSV.

    Tip: in OpenCV the hue H ranges over [0, 179], saturation S over [0, 255], and value V over [0, 255]. H is theoretically 0° to 360°, but the maximum value of an 8-bit pixel is 255, so OpenCV halves it. Other software may use different scales, so remember to normalize when mixing tools.

Related reading:

    RGB, HSV, and HSL color spaces

https://pic4.zhimg.com/v2-e9f9c843e7d60e8f7aa7de1cd61d1818_1440w.jpg?source=172ae18b

Now let's implement an example that uses HSV to show only the blue objects in a video; the steps are:

    Capture one frame of the video
    Convert from BGR to HSV
    Extract the objects in the blue range
    Show only the blue objects

In [11]

# Load an image containing sky
sky = cv2.imread('sky.jpg')

In [12]

# The blue range; it varies with lighting, so adjust as needed
lower_blue = np.array([15, 60, 60])
upper_blue = np.array([130, 255, 255])

In [13]

# Convert from BGR to HSV
hsv = cv2.cvtColor(sky, cv2.COLOR_BGR2HSV)
# inRange(): pixels between lower/upper become white, the rest black
# (the mask must be computed on the HSV image, not the BGR one)
mask = cv2.inRange(hsv, lower_blue, upper_blue)
# Keep only the blue parts of the original image
res = cv2.bitwise_and(sky, sky, mask=mask)

In [14]

# Save the color-segmentation result
cv2.imwrite('res.jpg', res)

True

In [15]

res = cv2.imread('res.jpg')
res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
plt.imshow(res)

<matplotlib.image.AxesImage at 0x7fb1e9ef1150>

<Figure size 432x288 with 1 Axes>

The bitwise_and() function can be ignored for now; it is covered later. So where did the lower and upper bounds for blue in HSV come from? It's simple: convert the BGR value of pure blue with cvtColor():
In [16]

blue = np.uint8([[[255, 0, 0]]])
hsv_blue = cv2.cvtColor(blue, cv2.COLOR_BGR2HSV)
print(hsv_blue)

[[[120 255 255]]]

The result is [120, 255, 255], which is why the blue range was set the way it is in the code above.

    Tip: the Lab color space is also often used for color recognition; it is worth looking into.

Thresholding

    "Binarize" an image with a fixed threshold, an adaptive threshold, or Otsu's method
    OpenCV functions: cv2.threshold(), cv2.adaptiveThreshold()

Fixed-threshold segmentation

Fixed thresholding is straightforward: pixels above the threshold become one value, pixels below it another.

cv2.threshold() performs the segmentation; ret is short for "return value" and holds the threshold that was used. The function takes 4 parameters:

    Parameter 1: the source image, usually grayscale
    Parameter 2: the threshold
    Parameter 3: the maximum value, usually 255
    Parameter 4: the thresholding type; there are 5 main ones, see ThresholdTypes
        0: THRESH_BINARY: above the threshold, take maxval (parameter 3); otherwise 0
        1: THRESH_BINARY_INV: above the threshold, take 0; otherwise maxval
        2: THRESH_TRUNC: above the threshold, take the threshold; otherwise unchanged
        3: THRESH_TOZERO: above the threshold, unchanged; otherwise 0
        4: THRESH_TOZERO_INV: above the threshold, take 0; otherwise unchanged

    Reference: fixed and adaptive thresholding with OpenCV

In [17]

import cv2

# Read the image as grayscale
img = cv2.imread('lena.jpg', 0)
# Convert the single-channel grayscale image to RGB for display
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
# Threshold segmentation
ret, th = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

plt.imshow(th)

<matplotlib.image.AxesImage at 0x7fb1e9ed1a50>

<Figure size 432x288 with 1 Axes>

In [18]

th[100]

array([[  0,   0,   0],
       [  0,   0,   0],
       [  0,   0,   0],
       ...,
       [255, 255, 255],
       [255, 255, 255],
       [255, 255, 255]], dtype=uint8)

In [19]

# Apply the 5 different thresholding types
# THRESH_BINARY: above the threshold take maxval, otherwise 0
ret, th1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
# THRESH_BINARY_INV: above the threshold take 0, otherwise maxval
ret, th2 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
# THRESH_TRUNC: above the threshold take the threshold, otherwise unchanged
ret, th3 = cv2.threshold(img, 127, 255, cv2.THRESH_TRUNC)
# THRESH_TOZERO: above the threshold unchanged, otherwise 0
ret, th4 = cv2.threshold(img, 127, 255, cv2.THRESH_TOZERO)
# THRESH_TOZERO_INV: above the threshold take 0, otherwise unchanged
ret, th5 = cv2.threshold(img, 127, 255, cv2.THRESH_TOZERO_INV)

titles = ['Original', 'BINARY', 'BINARY_INV', 'TRUNC', 'TOZERO', 'TOZERO_INV']
images = [img, th1, th2, th3, th4, th5]

In [20]

plt.figure(figsize=(12,12))
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.imshow(images[i], 'gray')
    plt.title(titles[i], fontsize=8)
    plt.xticks([]), plt.yticks([])

<Figure size 864x864 with 6 Axes>

    Tip: many people assume thresholding equals binarization. The figure above shows they are not the same: thresholding yields two classes of values, not necessarily two values.

Adaptive thresholding

A fixed threshold applies one value to the whole image, which does not suit images with uneven lighting. cv2.adaptiveThreshold() computes a threshold on a small neighborhood at a time, so different regions of the image get different thresholds. It takes six parameters; they are easy to understand once you see the effect:

    Parameter 1: the source image
    Parameter 2: the maximum value, usually 255
    Parameter 3: how the local threshold is computed
        ADAPTIVE_THRESH_MEAN_C: mean of the neighborhood
        ADAPTIVE_THRESH_GAUSSIAN_C: Gaussian-weighted sum of the neighborhood
    Parameter 4: the thresholding type (same 5 as above)
    Parameter 5: the neighborhood size, e.g. 11 means an 11×11 block
    Parameter 6: a constant subtracted from the computed local threshold

Try adjusting the parameters and compare the results.
In [21]

# Adaptive thresholding vs. a fixed threshold
img = cv2.imread('lena.jpg', 0)

# Fixed threshold
ret, th1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
# Adaptive threshold, ADAPTIVE_THRESH_MEAN_C: mean of the neighborhood
th2 = cv2.adaptiveThreshold(
    img, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 4)
# Adaptive threshold, ADAPTIVE_THRESH_GAUSSIAN_C: Gaussian-weighted sum of the neighborhood
th3 = cv2.adaptiveThreshold(
    img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 17, 6)

titles = ['Original', 'Global(v = 127)', 'Adaptive Mean', 'Adaptive Gaussian']
images = [img, th1, th2, th3]
plt.figure(figsize=(12,12))
for i in range(4):
    plt.subplot(2, 2, i + 1), plt.imshow(images[i], 'gray')
    plt.title(titles[i], fontsize=8)
    plt.xticks([]), plt.yticks([])

<Figure size 864x864 with 4 Axes>

Otsu's threshold

In the fixed-threshold examples we picked an arbitrary threshold such as 127. How do we know whether that choice is any good? By trial and error, which is why this approach is often called an empirical threshold in the literature. Otsu's method instead provides an automatic, efficient way to binarize.
Summary

    cv2.threshold() performs fixed-threshold segmentation. A fixed threshold does not suit unevenly lit images, so use cv2.adaptiveThreshold() for adaptive segmentation.
    Binarization and thresholding are not the same thing. Different images call for different thresholding methods.

Geometric transformations

    Rotate, translate, and scale images
    OpenCV functions: cv2.resize(), cv2.flip(), cv2.warpAffine()

Scaling images

Scaling means resizing an image, done with cv2.resize(). You can scale by a ratio or to an explicit size. You can also choose the scaling method via interpolation (the interpolation method), which defaults to INTER_LINEAR; see InterpolationFlags for the full list.

Five interpolation methods are available when scaling:

    cv2.INTER_NEAREST: nearest-neighbor interpolation
    cv2.INTER_LINEAR: bilinear interpolation
    cv2.INTER_AREA: resampling using the pixel-area relation
    cv2.INTER_CUBIC: bicubic interpolation over a 4x4 neighborhood
    cv2.INTER_LANCZOS4: Lanczos interpolation over an 8x8 neighborhood

In [22]

img = cv2.imread('cat.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Scale the image to a specified width and height
res = cv2.resize(img, (400, 500))
# Scale by a ratio, e.g. double along both x and y
res2 = cv2.resize(img, None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR)
plt.imshow(res)

<matplotlib.image.AxesImage at 0x7fb1e9c6eb10>

<Figure size 432x288 with 1 Axes>

In [23]

plt.imshow(res2)

<matplotlib.image.AxesImage at 0x7fb1e9df5dd0>

<Figure size 432x288 with 1 Axes>

Flipping images

Mirror-flip an image with cv2.flip(). Parameter 2 = 0: vertical flip (around the x axis); parameter 2 > 0: horizontal flip (around the y axis); parameter 2 < 0: flip both ways.
In [24]

dst = cv2.flip(img, 1)

In [25]

plt.imshow(dst)

<matplotlib.image.AxesImage at 0x7fb1e9d0c390>

<Figure size 432x288 with 1 Axes>

Translating images

To translate an image we define the following matrix, where t_x and t_y are the shifts along x and y:

M = \left[ \begin{matrix} 1 & 0 & t_x \newline 0 & 1 & t_y \end{matrix} \right]

Translation is implemented with the affine-transform function cv2.warpAffine():
In [26]

# Translate the image
import numpy as np
# Get the image height and width
rows, cols = img.shape[:2]

In [27]

# Define the translation matrix; it must be a numpy float32 array
# shift 100 along x and 500 along y
M = np.float32([[1, 0, 100], [0, 1, 500]])
# Apply the translation with an affine warp
dst = cv2.warpAffine(img, M, (cols, rows))

In [28]

plt.imshow(dst)

<matplotlib.image.AxesImage at 0x7fb1e9cabbd0>

<Figure size 432x288 with 1 Axes>

Drawing

    Draw various geometric shapes and add text
    OpenCV functions: cv2.line(), cv2.circle(), cv2.rectangle(), cv2.ellipse(), cv2.putText()

The drawing functions share some parameters, explained once here:

    img: the image to draw on
    color: the drawing color
        for a color image, pass a BGR tuple, e.g. blue is (255, 0, 0)
        for a grayscale image, pass a single gray value
    thickness: line width, default 1; for closed shapes such as rectangles and circles, -1 fills the shape
    lineType: line type. The default is 8-connected; cv2.LINE_AA is an anti-aliased line suited to curves.

Drawing lines

To draw a line, just give the start and end coordinates:
In [29]

img = cv2.imread('lena.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [30]

# Draw a red line of width 5; parameter 2: start point, parameter 3: end point
cv2.line(img, (0, 0), (800, 512), (255, 0, 0), 5)
plt.imshow(img)

<matplotlib.image.AxesImage at 0x7fb1e9d81310>

<Figure size 432x288 with 1 Axes>

Drawing rectangles

To draw a rectangle you need the coordinates of the top-left and bottom-right corners:
In [31]

# Draw a rectangle from (40, 40) to (80, 80) with a green outline
img = cv2.rectangle(img, (40, 40), (80, 80), (0, 255, 0), 2)

In [32]

plt.imshow(img)

<matplotlib.image.AxesImage at 0x7fb1e9e07210>

<Figure size 432x288 with 1 Axes>

In [33]

# Draw the same green rectangle from (40, 40) to (80, 80), filled this time
img = cv2.rectangle(img, (40, 40), (80, 80), (0, 255, 0), -1)
plt.imshow(img)

<matplotlib.image.AxesImage at 0x7fb1e9ba3810>

<Figure size 432x288 with 1 Axes>

Adding text

Add text with cv2.putText(). It also has quite a few parameters; match them against the code below:

    Parameter 2: the text to add
    Parameter 3: the starting coordinate of the text (bottom-left corner)
    Parameter 4: the font
    Parameter 5: the text size (scale factor)

In [34]

# Load a font for the text
font = cv2.FONT_HERSHEY_SIMPLEX
# Draw the text "hello"
cv2.putText(img, 'hello', (10, 200), font,
            4, (255, 255, 255), 2, lineType=cv2.LINE_AA)

array([[[255,   0,   0],
        [255,   0,   0],
        [255,   0,   0],
        ...,
        [237, 148, 130],
        [232, 143, 125],
        [196, 107,  89]],

       [[255,   0,   0],
        [255,   0,   0],
        [255,   0,   0],
        ...,
        [233, 145, 125],
        [231, 144, 124],
        [195, 108,  88]],

       [[255,   0,   0],
        [255,   0,   0],
        [255,   0,   0],
        ...,
        [239, 151, 129],
        [234, 150, 126],
        [195, 113,  89]],

       ...,

       [[ 87,  26,  60],
        [ 92,  28,  62],
        [ 91,  24,  57],
        ...,
        [160,  68,  83],
        [165,  67,  82],
        [163,  63,  75]],

       [[ 80,  18,  55],
        [ 90,  26,  61],
        [ 93,  26,  59],
        ...,
        [167,  71,  85],
        [174,  72,  86],
        [173,  69,  80]],

       [[ 79,  17,  54],
        [ 91,  27,  62],
        [ 96,  29,  64],
        ...,
        [169,  73,  87],
        [178,  73,  87],
        [180,  74,  86]]], dtype=uint8)

In [35]

plt.imshow(img)

<matplotlib.image.AxesImage at 0x7fb1e9b8e410>

<Figure size 432x288 with 1 Axes>

In [36]

# Reference: https://blog.csdn.net/qq_41895190/article/details/90301459
# Import the required PIL packages
from PIL import Image, ImageFont, ImageDraw
from numpy import unicode

def paint_chinese_opencv(im, chinese, pos, color):
    img_PIL = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    # Load a font that supports Chinese characters
    font = ImageFont.truetype('NotoSansCJKsc-Medium.otf', 25)
    # Text color
    fillColor = color
    # Top-left coordinate of the text
    position = pos
    # Decode to unicode if the string is not already
    if not isinstance(chinese, unicode):
        chinese = chinese.decode('utf-8')
    # Draw the text with PIL
    draw = ImageDraw.Draw(img_PIL)
    draw.text(position, chinese, font=font, fill=fillColor)
    # Convert back to OpenCV's BGR format
    img = cv2.cvtColor(np.asarray(img_PIL), cv2.COLOR_RGB2BGR)
    return img

In [37]

plt.imshow(paint_chinese_opencv(img,'中文',(100,100),(255,255,0)))

<matplotlib.image.AxesImage at 0x7fb1e9a92dd0>

<Figure size 432x288 with 1 Axes>

Summary

    cv2.line() draws lines, cv2.circle() circles, cv2.rectangle() rectangles, cv2.ellipse() ellipses, cv2.polylines() polygons, and cv2.putText() adds text.
    When drawing many lines, cv2.polylines() is much more efficient than repeated cv2.line() calls.
    To render Chinese text on an image, combine PIL with OpenCV.

Arithmetic between images

    Operations between images, such as addition and bitwise operations
    OpenCV functions: cv2.add(), cv2.addWeighted(), cv2.bitwise_and()

Adding images

To add two images use cv2.add(); the two images must have the same shape (height/width/channels). In numpy you can simply write res = img + img1, but the two results differ:
In [38]

x = np.uint8([250])
y = np.uint8([10])
print(cv2.add(x, y))  # 250+10 = 260 => 255
print(x + y)  # 250+10 = 260 % 256 = 4

[[255]]
[4]

If the images are binarized (containing only the two values 0 and 255), the two approaches give the same result (and the numpy form is more convenient).

Image blending

Blending, cv2.addWeighted(), is also a form of image addition, but the two images carry different weights, and γ acts as a correction term:

dst = \alpha \times img1 + \beta \times img2 + \gamma
In [39]

img1 = cv2.imread('lena.jpg')
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img2 = cv2.imread('cat.png')
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = cv2.resize(img2, (350, 350))
# Blend the two images
res = cv2.addWeighted(img1, 0.6, img2, 0.4, 0)

In [40]

plt.imshow(res)

<matplotlib.image.AxesImage at 0x7fb1e9a83650>

<Figure size 432x288 with 1 Axes>

Bitwise operations

Bitwise operations include AND, OR, NOT, and XOR. What are they for?

Adding two images directly changes the colors, and blending changes the transparency, so we need bitwise operations instead. First, the concept of a mask: a mask is a binary image used to occlude another image locally.
In [41]

img1 = cv2.imread('lena.jpg')
img2 = cv2.imread('logo.jpg')
img2 = cv2.resize(img2, (350, 350))
# Put the logo in the top-left corner, so only that region matters
rows, cols = img2.shape[:2]
roi = img1[:rows, :cols]

# Create the mask
img2gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
ret, mask = cv2.threshold(img2gray, 10, 255, cv2.THRESH_BINARY)
mask_inv = cv2.bitwise_not(mask)

# Keep the background everywhere except where the logo goes
img1_bg = cv2.bitwise_and(roi, roi, mask=mask_inv)
dst = cv2.add(img1_bg, img2)  # Blend the logo in
img1[:rows, :cols] = dst  # Put the blended patch back on the original

In [42]

plt.imshow(dst)

<matplotlib.image.AxesImage at 0x7fb1e99e3e50>

<Figure size 432x288 with 1 Axes>

Summary

    cv2.add() adds two images; cv2.addWeighted() also adds two images, but with different weights.
    cv2.bitwise_and(), cv2.bitwise_not(), cv2.bitwise_or(), cv2.bitwise_xor() perform bitwise AND, NOT, OR, and XOR. A mask is used to occlude an image globally or locally.

Smoothing images

    Blur/smooth images to remove noise
    OpenCV functions: cv2.blur(), cv2.GaussianBlur(), cv2.medianBlur(), cv2.bilateralFilter()

Filtering and blurring

About filtering and blurring:

    Both are convolutions; different (linear) filters differ only in their kernels
    A low-pass filter blurs; a high-pass filter sharpens

A low-pass filter lets low-frequency signal through. In an image, edges and noise are the high-frequency parts, so low-pass filters are used to remove noise and to smooth or blur the image. High-pass filters do the opposite: they enhance edges and sharpen.

    Common noise types are salt-and-pepper noise (isolated black or white dots scattered over the image) and Gaussian noise (e.g. caused by lighting conditions when the photo was taken).

Mean filtering

Mean filtering is the simplest filter: it takes the mean of the elements under the kernel, implemented with cv2.blur(). For a 3×3 kernel:

kernel = \frac{1}{9}\left[ \begin{matrix} 1 & 1 & 1 \newline 1 & 1 & 1 \newline 1 & 1 & 1 \end{matrix} \right]


In [43]

img = cv2.imread('lena.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
blur = cv2.blur(img, (9, 9))  # Mean blur
plt.imshow(blur)

<matplotlib.image.AxesImage at 0x7fb2184a66d0>

<Figure size 432x288 with 1 Axes>

Box filtering

Box filtering is very similar to mean filtering. A 3×3 box kernel looks like:

k = a\left[ \begin{matrix} 1 & 1 & 1 \newline 1 & 1 & 1 \newline 1 & 1 & 1 \end{matrix} \right]

It is implemented with cv2.boxFilter(). When the optional normalize parameter is True, box filtering is exactly mean filtering and a = 1/9 above; when normalize is False, a = 1, which amounts to summing the pixels in the region.
In [44]

# The mean filter above can also be implemented as a box filter: normalize=True
blur = cv2.boxFilter(img, -1, (9, 9), normalize=True)
plt.imshow(blur)

<matplotlib.image.AxesImage at 0x7fb1e81162d0>

<Figure size 432x288 with 1 Axes>

Gaussian filtering

In the two filters above every kernel element is equal, i.e. every pixel in the window carries the same weight. A Gaussian kernel's weights differ: the center pixel has the highest weight, and the weight falls off with distance from the center.

This weighting is clearly more reasonable. Images are 2-D, so we need a 2-D Gaussian; for example, OpenCV's default 3×3 Gaussian kernel is:

k = \left[ \begin{matrix} 0.0625 & 0.125 & 0.0625 \newline 0.125 & 0.25 & 0.125 \newline 0.0625 & 0.125 & 0.0625 \end{matrix} \right]

The corresponding OpenCV function is cv2.GaussianBlur(src, ksize, sigmaX); the larger σx (parameter 3) is, the stronger the blur. Gaussian filtering is slower than mean filtering, but it effectively removes Gaussian noise and preserves more image detail, so it is often called the most useful filter. Comparing the two below, the mean filter loses more detail.
In [45]

# Mean filter vs. Gaussian filter
gaussian = cv2.GaussianBlur(img, (9, 9), 1)  # Gaussian blur
plt.imshow(gaussian)

<matplotlib.image.AxesImage at 0x7fb1e807d090>

<Figure size 432x288 with 1 Axes>

Median filtering

The median is the middle value after sorting. Median filtering replaces each pixel with the median of its neighborhood, so isolated specks such as 0 or 255 are easily removed; it suits salt-and-pepper and speck noise. The median is a nonlinear operation and is slower than the linear filters above.
In [46]

median = cv2.medianBlur(img, 9)  # Median blur
plt.imshow(median)

<matplotlib.image.AxesImage at 0x7fb1e805cdd0>

<Figure size 432x288 with 1 Axes>

Bilateral filtering

Blurring loses image detail, and with the linear filters introduced above, edge information in particular is hard to preserve. Edges are an important image feature, which is why bilateral filtering exists, implemented with cv2.bilateralFilter(). As you can see, it clearly preserves more edge information.
In [47]

blur = cv2.bilateralFilter(img, 9, 75, 75)  # Bilateral filter
plt.imshow(blur)

<matplotlib.image.AxesImage at 0x7fb1e003bcd0>

<Figure size 432x288 with 1 Axes>

Image sharpening
In [48]

kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32)  # Define a sharpening kernel
dst = cv2.filter2D(img, -1, kernel=kernel)
plt.imshow(dst)

<matplotlib.image.AxesImage at 0x7fb1bc7f7210>

<Figure size 432x288 with 1 Axes>

Edge detection

    Canny J. A Computational Approach To Edge Detection[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 1986, PAMI-8(6):679-698.

    The basic idea of Canny edge detection
    OpenCV function: cv2.Canny()

Canny edge detection is often praised as the optimal edge-detection method:

cv2.Canny() performs the detection; parameters 2 and 3 are the low and high thresholds. The principle is explained below.

    Tip: earlier we blurred images with low-pass filtering; conversely, to extract object edges we need high-pass filtering.

Canny edge detection

The steps of Canny edge extraction are:

    Remove noise with a 5×5 Gaussian filter:

Edge detection is itself a sharpening operation and thus sensitive to noise, so the image is smoothed first.

K=\frac{1}{256}\left[ \begin{matrix} 1 & 4 & 6 & 4 & 1 \newline 4 & 16 & 24 & 16 & 4 \newline 6 & 24 & 36 & 24 & 6 \newline 4 & 16 & 24 & 16 & 4 \newline 1 & 4 & 6 & 4 & 1 \end{matrix} \right]

    Compute the gradient direction of the image:

First the Sobel operator computes the gradients $G_x$ and $G_y$ in the two directions, then the gradient direction:

\theta=\arctan\left(\frac{G_y}{G_x}\right)

The direction is quantized to the four angles 0°/45°/90°/135°. What for? Read on.

    Non-maximum suppression:

The gradient already traces the contours, but to thin them further we keep only the local maxima along those four directions.

    Hysteresis thresholding:

After the first three steps only zeros and candidate edge gradients remain. To finalize them, set a high and a low threshold:

    A pixel above the high threshold is definitely an edge
    Likewise, a pixel below the low threshold is definitely not an edge
    A pixel in between counts as an edge only if it is connected to a pixel above the high threshold (in the usual illustration, point C counts while point B does not)

Canny recommends a high:low threshold ratio between 2:1 and 3:1.
In [49]

img = cv2.imread('lena.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
edges = cv2.Canny(img, 30, 70)  # Canny edge detection
plt.imshow(edges)

<matplotlib.image.AxesImage at 0x7fb1bc759d10>

<Figure size 432x288 with 1 Axes>

Thresholding before detection

In many cases, running edge detection after thresholding gives better results.
In [50]

_, thresh = cv2.threshold(img, 124, 255, cv2.THRESH_BINARY)
edges = cv2.Canny(thresh, 30, 70)
plt.imshow(edges)

<matplotlib.image.AxesImage at 0x7fb1bc6bfa90>

<Figure size 432x288 with 1 Axes>

Summary

    Canny is the most widely used edge-detection algorithm, implemented with cv2.Canny().

Erosion and dilation

    Understand the concept of morphological operations
    Learn dilation, erosion, opening, closing, and other morphological operations
    OpenCV functions: cv2.erode(), cv2.dilate(), cv2.morphologyEx()

What are morphological operations?

Morphological operations change an object's shape: erosion makes it "thinner", dilation makes it "fatter".

    Tip: morphological operations are usually applied to binarized images, to connect neighboring elements or split them into independent ones. Erosion and dilation act on the white parts of the image!

Erosion

Erosion "slims" the image by taking the local minimum over a small window. In a binary image with only the values 0 and 255, a single 0 in the window makes the output pixel 0.

The pixels along the edges of shapes in the original image thus become 0, which achieves the slimming.

OpenCV's cv2.erode() performs erosion; you only need to specify the kernel size:
In [51]

img = cv2.imread('lena.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
kernel = np.ones((5, 5), np.uint8)
erosion = cv2.erode(img, kernel)  # Erosion
plt.imshow(erosion)

<matplotlib.image.AxesImage at 0x7fb1bc6abcd0>

<Figure size 432x288 with 1 Axes>

    This kernel is also called a structuring element, since morphological operations are implemented as convolutions too. Structuring elements can be rectangles, ellipses, or crosses, generated with cv2.getStructuringElement(), for example:

kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))  # rectangular
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))  # elliptical
kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (5, 5))  # cross-shaped

Dilation

Dilation is the opposite of erosion: it takes the local maximum, "fattening" the image:
In [52]

dilation = cv2.dilate(img, kernel)  # Dilation
plt.imshow(dilation)

<matplotlib.image.AxesImage at 0x7fb1bc615590>

<Figure size 432x288 with 1 Axes>

Opening / closing

Erosion followed by dilation is called opening (easy to remember because eroding first separates objects). It is used to separate objects and remove small regions. These morphological operations use the cv2.morphologyEx() function:
In [53]

kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))  # Define the structuring element
opening = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)  # Opening
plt.imshow(opening)

<matplotlib.image.AxesImage at 0x7fb1bc57d710>

<Figure size 432x288 with 1 Axes>

Closing is the reverse: dilate first, then erode (dilating expands the white regions, closing the small black holes inside objects, hence the name closing):
In [54]

closing = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)  # Closing
plt.imshow(closing)

<matplotlib.image.AxesImage at 0x7fb1bc569310>

<Figure size 432x288 with 1 Axes>

    Tip: the purpose of opening and closing confuses many people, but it is simple: if there are many irrelevant small regions outside the target object, remove them with opening; if there are many small holes inside the object, fill them with closing.



Lesson 2: An Image Classification Model, the Perceptron

In machine learning, the perceptron is a linear model for binary classification and a supervised learning algorithm. Its input is an instance's feature vector and its output is the instance's class (+1 or -1).

The perceptron corresponds to a separating hyperplane that splits the instances in the input space into two classes. The algorithm seeks this hyperplane by introducing a loss function based on misclassification and minimizing it with gradient descent.

The perceptron learning algorithm is simple and easy to implement, and comes in a primal form and a dual form. Perceptron prediction applies the learned model to new instances, so it is a discriminative model.

The perceptron was proposed by Rosenblatt in 1957 and is the foundation of neural networks and support vector machines.

Definition

Suppose the input space (feature vectors) is X = [x1, x2, x3, …] and the output space is Y = {+1, -1}.

Input = X

a feature vector of an instance, a point in the input space;

Output = Y

the class of the instance.

The function from the input space to the output space,

f(x) = sign(w^T x + b)

is called a perceptron. The parameter w is the weight vector and b is the bias. w^T x denotes the dot product of w and x:

w^T x = w_1 x_1 + w_2 x_2 + w_3 x_3 + … + w_n x_n, where w = [w_1, w_2, …, w_n]^T and x = [x_1, x_2, …, x_n]^T

sign is the sign function:

sign(A) = \left\{ \begin{matrix} +1, & A \geq 0 \newline -1, & A < 0 \end{matrix} \right.

The perceptron algorithm looks for a hyperplane that splits our data into two parts.

A hyperplane is a subspace whose dimension is one less than the current space (the data dimension: however many columns x has). For example, if the current space is two-dimensional, the hyperplane is one-dimensional, i.e. a straight line.
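The definition above, together with the standard misclassification-driven update rule (for a misclassified point, w ← w + η·y·x and b ← b + η·y), can be sketched in plain NumPy; the toy data here is made up and separate from the dataset generated below:

```python
import numpy as np

# Toy linearly separable data: two clusters labeled +1 and -1
rng = np.random.RandomState(0)
X = np.vstack([rng.normal(2, 0.5, (20, 2)), rng.normal(-2, 0.5, (20, 2))])
y = np.hstack([np.ones(20), -np.ones(20)])

w, b, lr = np.zeros(2), 0.0, 0.1
for _ in range(100):
    errors = 0
    for xi, yi in zip(X, y):
        # Misclassified (or exactly on the hyperplane): update w and b
        if yi * (w @ xi + b) <= 0:
            w += lr * yi * xi
            b += lr * yi
            errors += 1
    if errors == 0:  # every point classified correctly: converged
        break

pred = np.sign(X @ w + b)
print((pred == y).mean())  # 1.0 once converged on separable data
```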

# Import the required packages
import paddle
print("This tutorial uses paddle version: " + paddle.__version__)
import numpy as np
import matplotlib.pyplot as plt

This tutorial uses paddle version: 2.0.0

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized

In [2]

np.random.seed(0)
num=100

# Generate a dataset: features x1, x2 and labels y (+1 / -1)
# Randomly generate 100 values of x1
x1 = np.random.normal(6, 1, size=(num))
# Randomly generate 100 values of x2
x2 = np.random.normal(3, 1, size=(num))
# Generate 100 labels y (all set to 1)
y = np.ones(num)
# Stack the generated points into the first class
class1 = np.array([x1, x2, y])
class1.shape
# Generate the second class of points the same way
x1 = np.random.normal(3, 1, size=(num))
x2 = np.random.normal(6, 1, size=(num))
y = np.ones(num) * (-1)
class2 = np.array([x1, x2, y])

In [3]

# Take a look at the shapes of the generated points
print(class1.shape)
print(class2.shape)

(3, 100)
(3, 100)

This shape is inconvenient to work with, so let's transpose it
In [4]

class1 = class1.T
class2 = class2.T
# Check the shapes of the generated points again
print(class1.shape)
print(class2.shape)

(100, 3)
(100, 3)

The raw numbers alone don't tell us much, so let's plot them
In [5]

plt.scatter(class1[:,0],class1[:,1])
plt.scatter(class2[:,0],class2[:,1],marker='*')

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data

<matplotlib.collections.PathCollection at 0x7fde9daca710>

<Figure size 432x288 with 1 Axes>

Much clearer. Next comes drawing the separating line
In [6]

# Put both classes into one variable
all_data = np.concatenate((class1, class2))
print(all_data)

[[ 7.76405235  4.8831507   1.        ]
 [ 6.40015721  1.65224094  1.        ]
 [ 6.97873798  1.729515    1.        ]
 [ 8.2408932   3.96939671  1.        ]
 [ 7.86755799  1.82687659  1.        ]
 [ 5.02272212  4.94362119  1.        ]
 [ 6.95008842  2.58638102  1.        ]
 [ 5.84864279  2.25254519  1.        ]
 [ 5.89678115  4.92294203  1.        ]
 [ 6.4105985   4.48051479  1.        ]
 [ 6.14404357  4.86755896  1.        ]
 [ 7.45427351  3.90604466  1.        ]
 [ 6.76103773  2.13877431  1.        ]
 [ 6.12167502  4.91006495  1.        ]
 [ 6.44386323  2.73199663  1.        ]
 [ 6.33367433  3.8024564   1.        ]
 [ 7.49407907  3.94725197  1.        ]
 [ 5.79484174  2.84498991  1.        ]
 [ 6.3130677   3.61407937  1.        ]
 [ 5.14590426  3.92220667  1.        ]
 [ 3.44701018  3.37642553  1.        ]
 [ 6.6536186   1.90059921  1.        ]
 [ 6.8644362   3.29823817  1.        ]
 [ 5.25783498  4.3263859   1.        ]
 [ 8.26975462  2.30543214  1.        ]
 [ 4.54563433  2.85036546  1.        ]
 [ 6.04575852  2.56484645  1.        ]
 ...
 [ 2.60055097  7.46657872 -1.        ]
 [ 3.37005589  6.85255194 -1.        ]]

This is far too tidy. That would make things much too easy for our perceptron, so let's make it a bit harder.
In [7]

# Shuffle the data
np.random.shuffle(all_data)
print(all_data)

[[ 6.76103773  2.13877431  1.        ]
 [ 2.49318365  6.00377089 -1.        ]
 [ 7.53277921  3.67229476  1.        ]
 [ 6.46566244  2.3563816   1.        ]
 ...
 [ 2.36256297  7.0685094  -1.        ]]

Much better.
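np.random.shuffle permutes along the first axis only, so each coordinate pair stays attached to its label. A quick sanity check on a toy array:

```python
import numpy as np

toy = np.array([[1.0, 2.0, 1.0],
                [3.0, 4.0, -1.0],
                [5.0, 6.0, 1.0]])
np.random.shuffle(toy)  # reorders whole rows in place

# Every shuffled row is still one of the original rows, label included
rows = {tuple(r) for r in toy}
print(rows == {(1.0, 2.0, 1.0), (3.0, 4.0, -1.0), (5.0, 6.0, 1.0)})  # True
```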
In [ ]

print(all_data.shape)
# Slice out the coordinate columns
train_data_x=all_data[:150,:2]
# Slice out the label column
train_data_y=all_data[:150,-1].reshape(150,1)

print(train_data_x.shape)

print(train_data_y.shape)

(200, 3)
(150, 2)
(150, 1)

Don't mix these up: here y is no longer the y-axis but a label, and x is not a single scalar input but a coordinate pair.
In [ ]

# Convert the data to tensors
x_data = paddle.to_tensor(train_data_x.astype('float32'))
y_data = paddle.to_tensor(train_data_y.astype('float32'))

The formula we want the model to learn is:

y = w1*x1 + w2*x2 + b

Note: w1, w2, and b are learned during training.
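The decision rule those learned parameters implement can be sketched in plain Python (the weight values below are hypothetical; the real ones come out of training):

```python
def perceptron_predict(x1, x2, w1, w2, b):
    """Score a point with w1*x1 + w2*x2 + b and threshold at 0."""
    score = w1 * x1 + w2 * x2 + b
    return 1.0 if score >= 0 else -1.0

# Hypothetical weights roughly separating the two point clouds
w1, w2, b = 0.3, -0.3, 0.0
print(perceptron_predict(6.0, 3.0, w1, w2, b))  # 1.0 (lower-right cloud)
print(perceptron_predict(3.0, 6.0, w1, w2, b))  # -1.0 (upper-left cloud)
```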
In [ ]

# Initialize a perceptron (a single linear layer)
# The formula behind it is in the official docs for paddle.nn.Linear

linear = paddle.nn.Linear(in_features=2, out_features=1)
# Initialize a loss function and an optimizer to train the perceptron
mse_loss = paddle.nn.MSELoss()
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters = linear.parameters())

Time to train!
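Under the hood, each sgd_optimizer.step() in the loop below applies plain gradient descent, w = w - lr * grad. A hand-rolled one-weight version of that update (toy numbers, not the actual model):

```python
# One manual SGD step on a single weight for a squared-error loss
lr = 0.001
w = 0.5
x, y_true = 2.0, 3.0

y_pred = w * x
grad = 2 * (y_pred - y_true) * x  # derivative of (y_pred - y_true)**2 w.r.t. w
w = w - lr * grad
print(w)  # 0.508: the weight moved toward the target
```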
In [ ]

# Define how many epochs to train; this model learns slowly, so let it run a while
total_epoch = 50000
# Build the training loop
for i in range(total_epoch):
    # Feed the data through the linear perceptron
    y_predict = linear(x_data)
    # Compute the loss
    loss = mse_loss(y_predict, y_data)
    # Backpropagate and update the parameters
    loss.backward()
    sgd_optimizer.step()
    sgd_optimizer.clear_grad()
    # Every 1000 epochs, read w1, w2 and b out of the layer and report progress
    if i%1000 == 0:
        w1_after_opt = linear.weight.numpy()[0].item()
        w2_after_opt = linear.weight.numpy()[1].item()
        b_after_opt = linear.bias.numpy().item()
        print("epoch {} loss {}".format(i, loss.numpy()))
        print("w1 after optimize: {}".format(w1_after_opt))
        print("w2 after optimize: {}".format(w2_after_opt))
        print("b after optimize: {}".format(b_after_opt))
print("finished training, loss {}".format(loss.numpy()))

epoch 0 loss [2.0655026]
w1 after optimize: -0.029253046959638596
w2 after optimize: 0.16224634647369385
b after optimize: -0.0012456965632736683
epoch 1000 loss [0.16172205]
w1 after optimize: 0.28587839007377625
w2 after optimize: -0.2802807688713074
b after optimize: -0.02707645483314991
...
epoch 15000 loss [0.161151]
w1 after optimize: 0.29501140117645264
w2 after optimize: -0.2705240845680237
b after optimize: -0.11361029744148254
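Since the model scores a point with w1*x1 + w2*x2 + b, the decision boundary is the line where the score is zero. Plugging in the final values from the log above (approximate, hand-copied constants) recovers that line:

```python
# Approximate learned parameters taken from the training log above
w1, w2, b = 0.295, -0.2705, -0.1136

def boundary_y(x):
    """Solve w1*x + w2*y + b = 0 for y."""
    return -(w1 * x + b) / w2

# Two points on the separating line; one cloud scores > 0, the other < 0
print(boundary_y(2.0), boundary_y(8.0))
```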

A Perceptron Application — Assignment 2

In [1]

# Import the required packages
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import paddle
print("This tutorial is based on Paddle version: "+paddle.__version__)

This tutorial is based on Paddle version: 2.0.0

Step 1: Prepare the data.

(1) About the dataset

The MNIST dataset contains 60,000 training samples and 10,000 test samples, each consisting of an image and a label: the image is a 28*28 pixel matrix, and the label is one of the 10 digits 0~9.

(2) The transform function defines the normalization to apply.

(3) train_dataset and test_dataset

mode='train' and mode='test' in paddle.vision.datasets.MNIST() fetch the MNIST training set and test set respectively.

The transform=transform argument applies that normalization.
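The Normalize transform maps each pixel value v to (v - mean) / std, so with mean = std = 127.5 the original 0~255 range lands exactly in [-1, 1]. A quick check of that arithmetic:

```python
def normalize(v, mean=127.5, std=127.5):
    """The same per-pixel arithmetic Normalize applies."""
    return (v - mean) / std

print(normalize(0))      # -1.0
print(normalize(255))    # 1.0
print(normalize(127.5))  # 0.0
```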
In [2]

# Compose chains a list of dataset-preprocessing transforms together.
# Normalize performs image normalization. It supports two modes: (1) one mean/std pair applied to every channel, or (2) a separate mean/std specified per channel.
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# Use transform to normalize the dataset
print('Downloading and loading the training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('Done loading')

Downloading and loading the training data
Done loading

In [3]

# Let's take a look at an image from the dataset
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('Label of train_data0: ' + str(train_label_0))

AxesImage(18,18;111.6x108.72)
Label of train_data0: [5]


<Figure size 144x144 with 1 Axes>

In [4]

# Now let's look at the raw data values
print(train_data0)

[[-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
  ...
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]]

Step 2: Network configuration

The code below defines a simple multilayer perceptron. The assignment calls for three layers: two hidden layers of size 100 and an output layer of size 10 (the completed code below uses a single hidden layer of 128 units instead). MNIST consists of grayscale images of the handwritten digits 0~9, so there are 10 classes and the final output size is 10. With a Softmax activation on the output, the last layer acts as a classifier. Counting the input layer, the structure of the multilayer perceptron is: input layer -->> hidden layer -->> hidden layer -->> output layer.

Please complete the network code.
In [14]

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor

# Define the multilayer perceptron (dynamic graph)
class multilayer_perceptron(paddle.nn.Layer):
    def __init__(self):
        super(multilayer_perceptron,self).__init__()
        # Network definition (fill-in for the assignment)
        self.flatten=paddle.nn.Flatten()
        self.hidden=paddle.nn.Linear(in_features=784,out_features=128)
        self.output=paddle.nn.Linear(in_features=128,out_features=10)
    def forward(self, x):
        # Forward pass (fill-in for the assignment)
        x=self.flatten(x)
        x=self.hidden(x) # hidden layer
        x=F.relu(x) # activation
        x=self.output(x)
        # No explicit softmax here: CrossEntropyLoss applies it internally
        return x

In [ ]

# Define the convolutional network here
# Note: once the convolutional network is defined, the code below must be adapted to it!

In [15]

from paddle.metric import Accuracy

# Wrap the network with the high-level Model API
model=paddle.Model(multilayer_perceptron())


# Define the optimizer
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# Configure the model
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())

# Train, evaluate, and checkpoint the model
model.fit(train_dataset,test_dataset,epochs=2,batch_size=64,save_dir='multilayer_perceptron',verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
step 938/938 [==============================] - loss: 0.2972 - acc: 0.8950 - 8ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0983 - acc: 0.9455 - 6ms/step         
Eval samples: 10000
Epoch 2/2
step 938/938 [==============================] - loss: 0.0791 - acc: 0.9489 - 8ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0125 - acc: 0.9541 - 7ms/step         
Eval samples: 10000
save checkpoint at /home/aistudio/multilayer_perceptron/final

In [17]

# Train, evaluate, and checkpoint again (this continues from the weights of the previous run)
model.fit(train_dataset,test_dataset,epochs=2,batch_size=64,save_dir='multilayer_perceptron',verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
step 938/938 [==============================] - loss: 0.1219 - acc: 0.9807 - 8ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0038 - acc: 0.9692 - 6ms/step         
Eval samples: 10000
Epoch 2/2
step 938/938 [==============================] - loss: 0.0096 - acc: 0.9813 - 8ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 7.7252e-04 - acc: 0.9772 - 6ms/step     
Eval samples: 10000
save checkpoint at /home/aistudio/multilayer_perceptron/final

In [18]

# Get the first image from the test set
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]
test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
# Show the first image in the test set
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('Label of test_data0: ' + str(test_label_0))
# Run model prediction
result = model.predict(test_dataset, batch_size=1)
# Print the predicted result (index of the largest logit)
print('Predicted digit for test_data0: %d' % np.argsort(result[0][0])[0][-1])

AxesImage(18,18;111.6x108.72)
Label of test_data0: [7]
Predict begin...
step 10000/10000 [==============================] - 1ms/step        
Predict samples: 10000
Predicted digit for test_data0: 7

<Figure size 144x144 with 1 Axes>
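The expression np.argsort(result[0][0])[0][-1] in the cell above picks the index of the largest logit: argsort sorts ascending, so the last index belongs to the highest score, which is equivalent to np.argmax. A tiny illustration with made-up logits:

```python
import numpy as np

logits = np.array([[0.1, 0.3, 2.5, -0.7]])  # hypothetical scores for 4 classes
predicted = np.argsort(logits)[0][-1]       # index of the largest score
print(predicted)  # 2
assert predicted == np.argmax(logits)
```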

In [20]

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize

# Preprocessing: normalize pixel values from 0-255 down to the range -1 to 1
transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
#导入MNIST数据
train_dataset=paddle.vision.datasets.MNIST(mode="train", transform=transform)
val_dataset=paddle.vision.datasets.MNIST(mode="test", transform=transform)
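A quick arithmetic check (plain Python, independent of Paddle) that `Normalize(mean=[127.5], std=[127.5])` really maps the 0-255 pixel range onto [-1, 1]:

```python
def normalize(pixel, mean=127.5, std=127.5):
    """Apply the same affine transform as paddle.vision.transforms.Normalize."""
    return (pixel - mean) / std

print(normalize(0))      # darkest pixel  -> -1.0
print(normalize(127.5))  # mid-gray       ->  0.0
print(normalize(255))    # brightest pixel -> 1.0
```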

In [22]

# Define the LeNet model
class LeNetModel(paddle.nn.Layer):
    def __init__(self):
        super(LeNetModel, self).__init__()
        # Convolution/pooling blocks: each conv layer is followed by a 2x2 max pool
        # Convolution layer L1
        self.conv1 = paddle.nn.Conv2D(in_channels=1,
                                      out_channels=6,
                                      kernel_size=5,
                                      stride=1)
        # Pooling layer L2
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2,
                                         stride=2)
        # Convolution layer L3
        self.conv2 = paddle.nn.Conv2D(in_channels=6,
                                      out_channels=16,
                                      kernel_size=5,
                                      stride=1)
        # Pooling layer L4
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2,
                                         stride=2)
        # Fully connected layer L5
        self.fc1 = paddle.nn.Linear(256, 120)
        # Fully connected layer L6
        self.fc2 = paddle.nn.Linear(120, 84)
        # Fully connected layer L7 (output layer, 10 classes)
        self.fc3 = paddle.nn.Linear(84, 10)

    # Forward pass: ReLU after each convolution, then pooling
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        out = self.fc3(x)
        return out

model = paddle.Model(LeNetModel())
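The `Linear(256, 120)` input size can be verified by tracing the feature-map shapes through the network. A small sketch (plain Python, using the standard output-size formula for valid, no-padding convolutions, matching the layer parameters above):

```python
def conv_out(size, kernel, stride=1):
    # Output size of a valid (no-padding) convolution or pooling window
    return (size - kernel) // stride + 1

size = 28                    # MNIST input: 1 x 28 x 28
size = conv_out(size, 5)     # conv1 (5x5, stride 1): 24
size = conv_out(size, 2, 2)  # pool1 (2x2, stride 2): 12
size = conv_out(size, 5)     # conv2 (5x5, stride 1): 8
size = conv_out(size, 2, 2)  # pool2 (2x2, stride 2): 4

features = 16 * size * size  # 16 channels x 4 x 4 = 256
print(features)              # matches fc1's input dimension of 256
```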

In [16]

from paddle.metric import Accuracy

# Configure the optimizer, loss function, and metric
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

model.prepare(optim,
              paddle.nn.CrossEntropyLoss(),
              Accuracy())

model.fit(train_dataset,
          val_dataset,
          epochs=5,
          batch_size=64,
          verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
step 938/938 [==============================] - loss: 0.1565 - acc: 0.9587 - 7ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0184 - acc: 0.9657 - 6ms/step         
Eval samples: 10000
Epoch 2/5
step 938/938 [==============================] - loss: 0.0206 - acc: 0.9680 - 8ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0143 - acc: 0.9679 - 6ms/step         
Eval samples: 10000
Epoch 3/5
step 938/938 [==============================] - loss: 0.0480 - acc: 0.9722 - 8ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0046 - acc: 0.9694 - 6ms/step         
Eval samples: 10000
Epoch 4/5
step 938/938 [==============================] - loss: 0.0072 - acc: 0.9751 - 8ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0045 - acc: 0.9676 - 6ms/step         
Eval samples: 10000
Epoch 5/5
step 938/938 [==============================] - loss: 0.1231 - acc: 0.9778 - 8ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 7.3149e-04 - acc: 0.9725 - 6ms/step     
Eval samples: 10000

In [19]

# Get the first image from the test set (val_dataset is this section's MNIST test split)
test_data0, test_label_0 = val_dataset[0][0], val_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
# Display the first image of the test set
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('Label of test_data0: ' + str(test_label_0))
# Run model prediction
result = model.predict(val_dataset, batch_size=1)
# Print the predicted class (index of the highest score)
print('Predicted value for test_data0: %d' % np.argsort(result[0][0])[0][-1])

AxesImage(18,18;111.6x108.72)
Label of test_data0: [7]
Predict begin...
step 10000/10000 [==============================] - 1ms/step        
Predict samples: 10000
Predicted value for test_data0: 7

<Figure size 144x144 with 1 Axes>

Summary

In this course I moved from the linear models of classical machine learning to neural network models, and came away with a much deeper understanding of both. I hope the next session will further deepen my knowledge of the CV field.
Course link: https://aistudio.baidu.com/aistudio/course/introduce/11939?directly=1&shared=1
