【机器学习基础】k近邻算法

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ Python机器学习 ⌋ ⌋ 机器学习是一门人工智能的分支学科,通过算法和模型让计算机从数据中学习,进行模型训练和优化,做出预测、分类和决策支持。Python成为机器学习的首选语言,依赖于强大的开源库如Scikit-learn、TensorFlow和PyTorch。本专栏介绍机器学习的相关算法以及基于Python的算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/Python_machine_learning


  本文就将介绍一个最基本的分类和回归算法:k近邻(k-nearest neighbor, KNN)算法。KNN是最简单也是最重要的机器学习算法之一,它的思想可以用一句话来概括:“相似的数据往往拥有相同的类别”,这也对应于中国的一句谚语:“物以类聚,人以群分”。具体来说,我们在生活中常常可以观察到,同一种类的数据之间特征更为相似,而不同种类的数据之间特征差别更大。例如,在常见的花中,十字花科的植物大多数有4片花瓣,而夹竹桃科的植物花瓣大多数是5的倍数。虽然存在例外,但如果我们按花瓣数量对植物做分类,那么花瓣数量相同或成倍数关系的植物,相对更可能属于同一种类。

  下面,本文将详细讲解并动手实现KNN算法,再将其应用到不同的任务中去。

一、KNN算法的原理

  在分类任务中,我们的目标是判断样本 x \boldsymbol{x} x的类别 y y y。KNN首先会观察与该样本点距离最近的 K K K个样本,统计这些样本所属的类别。然后,将当前样本归到出现次数最多的类中。我们用KNN算法的一张经典示意图来更清晰地说明其思想。如图1所示,假设共有两个类别的数据点:蓝色圆形和橙色正方形,而中心位置的绿色样本当前尚未被分类。根据统计近邻的思路:

  • K = 3 K=3 K=3 时,绿色样本的3个近邻中有两个橙色正方形样本,一个蓝色圆形样本,因此应该将绿色样本点归类为橙色正方形;

  • K = 5 K=5 K=5 时,绿色样本的5个近邻中有两个橙色正方形样本,三个蓝色圆形样本,因此应该将绿色样本点归类为蓝色圆形。

在这里插入图片描述

图1 KNN算法示意图

  从这个例子中可以看出,KNN的基本思路是让当前样本的分类服从邻居中的多数分类。但是,当 K K K的大小变化时,由于邻居的数量变化,其多数类别也可能会变化,从而改变对当前样本的分类判断。因此,决定 K K K的大小是KNN中最重要的部分之一。直观上来说,当 K K K的取值太小时,分类结果很容易受到待分类样本周围的个别噪声数据影响;当 K K K的取值太大时,又可能将远处一些不相关的样本包含进来。因此,我们应该根据数据集动态地调整 K K K的大小,以得到最理想的结果。

  下面,我们用数学语言来描述KNN算法。设已分类样本的集合为 X 0 \mathcal{X_0} X0。对于一个待分类的样本 x \boldsymbol{x} x,定义其邻居 N K ( x ) \mathcal{N}_K(\boldsymbol{x}) NK(x) X 0 \mathcal{X_0} X0中与 x \boldsymbol{x} x距离最近的 K K K个样本 x 1 , x 2 , ⋯   , x K \boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_K x1,x2,,xK 组成的集合,这些样本对应的类别分别是 y 1 , y 2 , ⋯   , y K y_1,y_2,\cdots,y_K y1,y2,,yK。我们统计集合 N K ( x ) \mathcal{N}_K(\boldsymbol{x}) NK(x)中类别为 j j j的样本的数量,记为 G j ( x ) G_j(\boldsymbol{x}) Gj(x) G j ( x ) = ∑ x i ∈ N K ( x ) I ( y i = j ) G_j(\boldsymbol{x})=\sum_{\boldsymbol{x}_i\in \mathcal{N}_K(\boldsymbol{x})}\mathbb{I}(y_i=j) Gj(x)=xiNK(x)I(yi=j) 其中, I ( p ) \mathbb{I}(p) I(p)是示性函数,其自变量 p p p是一个命题。当 p p p为真时, I ( p ) = 1 \mathbb{I}(p)=1 I(p)=1,反之,当 p p p为假时, I ( p ) = 0 \mathbb{I}(p)=0 I(p)=0。最后,我们将 x \boldsymbol{x} x的类别 y ^ ( x ) \hat{y}(\boldsymbol{x}) y^(x)判断为使最大的类别: y ^ ( x ) = arg ⁡ max ⁡ j G j ( x ) \hat{y}(\boldsymbol{x})=\arg\max_j G_j(\boldsymbol{x}) y^(x)=argjmaxGj(x)

  与分类任务类似,我们还可以将KNN应用于回归任务。对于样本 x \boldsymbol{x} x,我们需要预测其对应的实数值 y y y。同样的,KNN考虑 K K K个相邻的样本点 x i ∈ N K ( x ) \boldsymbol{x}_i\in \mathcal{N}_K(\boldsymbol{x}) xiNK(x),将这些样本点对应的实数值 y i y_i yi进行加权平均,就得到样本 x \boldsymbol{x} x的预测结果 y ^ ( x ) \hat{y}(\boldsymbol{x}) y^(x) y ^ ( x ) = ∑ x i ∈ N K ( x ) w i y i , 其中 ∑ i = 1 K w i = 1 \hat{y}(\boldsymbol{x})=\sum_{\boldsymbol{x}_i\in \mathcal{N}_K(\boldsymbol{x})}w_iy_i,其中\sum_{i=1}^Kw_i=1 y^(x)=xiNK(x)wiyi,其中i=1Kwi=1

  在这里,权重 w i w_i wi代表不同邻居对当前样本的重要程度,权重越大,该邻居的值 y i y_i yi对最后的预测影响也越大。我们既可以预先定义好权重,例如简单地认为所有邻居的重要程度相同,令所有 w i = 1 K w_i=\frac{1}{K} wi=K1;也可以根据数据集的特性设置权重与距离的关系,例如让权重与距离成反比;还可以将权重作为模型的参数,通过学习得到。

二、用KNN算法完成分类任务

  该任务将在MNIST数据集上应用KNN算法,完成分类任务。MNIST是手写数字数据集,其中包含了很多手写数字0~9的黑白图像,每张图像都由2828个像素点组成。可以在MNIST的官方网站上得到更多数据集的信息。读入后,每个像素点用1或0表示,1代表黑色像素,属于图像背景;0代表白色像素,属于手写数字。我们的任务是用KNN对不同的手写数字进行分类。为了更清晰地展示数据集的内容,下面先将前两个数据点转成黑白图像显示出来。此外,把每个数据集都按8:2的比例随机划分成训练集(training set)和测试集(test set)。我们先在训练集上应用KNN算法,再在测试集上测试算法的表现。

  我们会用到NumPy和Matplotlib两个Python库。NumPy是科学计算库,包含了大量常用的计算工具,如数组工具、数据统计、线性代数等,我们用NumPy中的数组来存储数据,并且会用到其中的许多函数。Matplotlib是可视化库,包含了各种绘图工具,我们用Matplotlib进行数据可视化,以及绘制各种训练结果。

import matplotlib.pyplot as plt
import numpy as np
import os

# 读入mnist数据集
m_x = np.loadtxt('mnist_x', delimiter=' ')
m_y = np.loadtxt('mnist_y')

# 数据集可视化
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
plt.figure()
plt.imshow(data, cmap='gray')

# 将数据集分为训练集和测试集
ratio = 0.8
split = int(len(m_x) * ratio)
# 打乱数据
np.random.seed(0)
idx = np.random.permutation(np.arange(len(m_x)))
m_x = m_x[idx]
m_y = m_y[idx]
x_train, x_test = m_x[:split], m_x[split:]
y_train, y_test = m_y[:split], m_y[split:]
plt.show()

在这里插入图片描述

  下面是KNN算法的具体实现。首先,我们定义样本之间的距离。简单起见,我们采用最常用的欧氏距离(Euclidean distance),也就是我们最生活中最常用、最直观的空间距离。对于 n n n维空间中的两个点 x = ( x 1 , x 2 , ⋯   , x n ) \boldsymbol{x}=(x_1,x_2,\cdots,x_n) x=(x1,x2,,xn) y = ( y 1 , y 2 , ⋯   , y n ) \boldsymbol{y}=(y_1,y_2,\cdots,y_n) y=(y1,y2,,yn),其欧氏距离为: d Euc ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d_\text{Euc}(\boldsymbol{x},\boldsymbol{y})=\sqrt{\sum_{i=1}^n(x_i-y_i)^2} dEuc(x,y)=i=1n(xiyi)2

# 欧氏距离
def distance(a, b):
    return np.sqrt(np.sum(np.square(a - b)))

  为了方便,我们将KNN算法定义成类,其初始化参数是 K K K和类别的数量。每一部分的含义在代码中有详细注释。

class KNN:
    def __init__(self, k, label_num):
        self.k = k
        self.label_num = label_num # 类别的数量

    def fit(self, x_train, y_train):
        # 在类中保存训练数据
        self.x_train = x_train
        self.y_train = y_train

    def get_knn_indices(self, x):
        # 获取距离目标样本点最近的K个样本点的标签
        # 计算已知样本的距离
        dis = list(map(lambda a: distance(a, x), self.x_train)) 
        # 按距离从小到大排序,并得到对应的下标
        knn_indices = np.argsort(dis) 
        # 取最近的K个
        knn_indices = knn_indices[:self.k] 
        return knn_indices

    def get_label(self, x):
        # 对KNN方法的具体实现,观察K个近邻并使用np.argmax获取其中数量最多的类别
        knn_indices = self.get_knn_indices(x)
        # 类别计数
        label_statistic = np.zeros(shape=[self.label_num]) 
        for index in knn_indices:
            label = int(self.y_train[index])
            label_statistic[label] += 1
        # 返回数量最多的类别
        return np.argmax(label_statistic) 

    def predict(self, x_test): 
        # 预测样本 test_x 的类别
        predicted_test_labels = np.zeros(shape=[len(x_test)], dtype=int)
        for i, x in enumerate(x_test):
            predicted_test_labels[i] = self.get_label(x)
        return predicted_test_labels

  最后,我们在测试集上观察算法的效果,并对不同的 K K K的取值进行测试。

for k in range(1, 10):
    knn = KNN(k, label_num=10)
    knn.fit(x_train, y_train)
    predicted_labels = knn.predict(x_test)

    accuracy = np.mean(predicted_labels == y_test)
    print(f'K的取值为 {k}, 预测准确率为 {accuracy * 100:.1f}%')

在这里插入图片描述

思考:若将距离distance(a, b)改为曼哈顿距离,观察对分类效果的影响。

# 曼哈顿距离
def distance(a, b):
    return np.sum(np.abs(a - b))

在这里插入图片描述

  • 稀疏数据:在数据稀疏的情况下,尤其是在特征空间中存在大量零值时(例如,用户对电影的评分矩阵),曼哈顿距离可能更合适,因为它不会因为零值特征的欧氏距离为零而忽略非零特征之间的差异。
  • 异常值敏感性:欧氏距离对异常值比较敏感,因为异常值会在高维空间中拉大与其他点的距离。曼哈顿距离由于只考虑各个维度上的绝对差异,因此对异常值的敏感性较低。

三、使用scikit-learn实现KNN算法

  Python作为机器学习的常用工具,有许多Python库已经封装好了机器学习常用的各种算法。这些库通常经过了很多优化,其运行效率比上面我们自己实现的要高。所以,能够熟练掌握各种机器学习库的用法,也是机器学习的学习目标之一。其中,scikit-learn(简称sklearn)是一个常用的机器学习算法库,包含了数据处理工具和许多简单的机器学习算法。

KNN常用的参数及其说明:

class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=1, **kwargs)

参数名称说明
n_neighbors接收int。表示近邻点的个数,即K值。默认为5。
weights接收str或callable,可选参数有“uniform”和“distance”。表示近邻点的权重,“uniform”表示所有的邻近点权重相等;“distance”表示距离近的点比距离远的点的权重大。默认为“uniform”。
algorithm接收str,可选参数有“auto”“ball_tree”“kd_tree”和“brute”。表示搜索近邻点的算法。默认为“auto”,即自动选择。
leaf_size接收int。表示kd树和ball树算法的叶尺寸,它会影响树构建的速度和搜索速度,以及存储树所需的内存大小。默认为30。
p接收int。表示距离度量公式,1是曼哈顿距离公式;2是欧式距离。默认为2。
metric接收str或callable。表示树算法的距离矩阵。默认为“minkowski”。
metric_params接收dict。表示metric参数中接收的自定义函数的参数。默认为None。
n_jobs接收int。表示并行运算的(CPU)数量,默认为1,-1则是使用全部的CPU。

以sklearn库为例,来讲解如何使用封装好的KNN算法,并在高斯数据集gauss.csv上观察分类效果。该数据集包含一些平面上的点,分别由两个独立的二维高斯分布随机生成,每一行包含三个数,依次是点的 x x x y y y坐标和类别。首先,我们导入数据集并进行可视化。

from sklearn.neighbors import KNeighborsClassifier # sklearn中的KNN分类器
from matplotlib.colors import ListedColormap

# 读入高斯数据集
data = np.loadtxt('gauss.csv', delimiter=',')
x_train = data[:, :2]
y_train = data[:, 2]
print('数据集大小:', len(x_train))

# 可视化
plt.figure(figsize=(8, 6))
plt.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], c='blue', marker='o')
plt.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], c='red', marker='x')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.show()

在这里插入图片描述

  在高斯数据集中,我们将整个数据集作为训练集,将平面上的其他点作为测试集,观察KNN在不同的 K K K值下的分类效果。因此,我们不对数据集进行划分,而是在平面上以0.02为间距构造网格作为测试集。由于平面上的点是连续的,我们无法依次对它们测试,只能像这样从中采样。在没有特殊要求的情况下,我们一般采用最简单的均匀网格采样。这里,我们选用网格间距0.02是为了平衡测试点的个数和测试点的代表性,也可以调整该数值观察结果的变化。

# 设置步长
step = 0.02 
# 设置网格边界
x_min, x_max = np.min(x_train[:, 0]) - 1, np.max(x_train[:, 0]) + 1 
y_min, y_max = np.min(x_train[:, 1]) - 1, np.max(x_train[:, 1]) + 1
# 构造网格
xx, yy = np.meshgrid(np.arange(x_min, x_max, step), np.arange(y_min, y_max, step))
grid_data = np.concatenate([xx.reshape(-1, 1), yy.reshape(-1, 1)], axis=1)

  在sklearn中,KNN分类器由KNeighborsClassifier定义,通过参数n_neighbors指定 K K K的大小。我们分别设置 K = 1 K=1 K=1 K = 3 K=3 K=3 K = 10 K=10 K=10 观察分类效果。数据集中的点用深色表示,平面上被分到某一类的点用与其相对应的浅色表示。可以看出,随着 K K K的增大,分类的边界变得更平滑,但错分的概率也在变大。

fig = plt.figure(figsize=(16,4.5))
# K值,读者可以自行调整,观察分类结果的变化
ks = [1, 3, 10] 
cmap_light = ListedColormap(['royalblue', 'lightcoral'])

for i, k in enumerate(ks):
    # 定义KNN分类器
    knn = KNeighborsClassifier(n_neighbors=k) 
    knn.fit(x_train, y_train)
    z = knn.predict(grid_data)

    # 画出分类结果
    ax = fig.add_subplot(1, 3, i + 1)
    ax.pcolormesh(xx, yy, z.reshape(xx.shape), cmap=cmap_light, alpha=0.7)
    ax.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], c='blue', marker='o')
    ax.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], c='red', marker='x')

    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_title(f'K = {k}')
plt.show()

在这里插入图片描述

四、用KNN算法完成回归任务——色彩风格迁移

  上面我们展示了KNN在分类任务上的效果,我们将KNN算法应用到回归任务——色彩风格迁移上。在该任务中,我们的目标是给一张黑白照片上色,同时要求上色的风格要接近另一张彩色照片。如图2所示,内容图像A是一张上海外滩的黑白风景照片,风格图像B是梵高著名的画作《星空》。通过色彩风格迁移,我们可以达到图像C中的上色效果。梵高作为著名的荷兰后印象派画家,其画作色彩比较夸张奔放,常常采用一些高明度、高纯度的色彩。得益于其富有特色的色彩,我们可以从风格迁移中明显观察到图像风格的转变。因此,在后续任务中,我们都采用这张外滩风景作为内容图像,而用梵高的不同作品当做风格图像。

在这里插入图片描述

图2 任务图像展示

  首先,我们安装导入必要的库。会用到 scikit-image(简称 skimage)这一图像处理库,以及sklearn中的KNN回归器KNeighborsRegressorKNeighborsRegressor参数及说明如下。

KNeighborsRegressor(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=4, p=2, weights='distance')

参数名称参数说明
algorithm接收str,可选参数,默认=‘auto’。算法用于计算最近邻居,默认为’auto’。可以选择的值包括:‘auto’: 根据训练数据自动选择合适的算法,在大多数情况下,会选择’ball_tree’、'kd_tree’或’brute’中的最佳算法。‘ball_tree’: 使用BallTree数据结构来寻找最近邻居。‘kd_tree’: 使用KDTree数据结构来寻找最近邻居。‘brute’: 使用暴力搜索算法,计算所有可能的邻居并选择最接近的。
leaf_size接收int,可选参数,默认=30。用于BallTree或KDTree的叶子大小。影响树的构建和查询速度,具体取决于数据的特征数量。
metric接收str or callable,可选参数,默认=‘minkowski’。用于计算距离的度量方法。如果是字符串,默认为’minkowski’。可以是预定义的距离度量字符串,如:‘euclidean’: 欧氏距离;‘manhattan’: 曼哈顿距离;‘chebyshev’: 切比雪夫距离;‘minkowski’: 通用的Minkowski距离。当p=2时等同于欧氏距离。还可以传入自定义的距离度量函数,函数应该接受两个参数(每个数据点的特征向量),并返回它们之间的距离。
metric_params接收dict,可选参数,默认=None。如果指定了metric参数并且使用Minkowski距离,则可以通过此参数传递额外的关键字参数给距离度量函数。
n_jobs接收int,可选参数,默认=1。并行运行的任务数。如果设置为 -1,则使用所有可用的CPU核心。
n_neighbors接收int,可选参数,默认=4。搜索的最近邻居的数量。
p接收int,可选参数,默认=2。Minkowski距离的参数。当p=1时,为曼哈顿距离,p=2时为欧氏距离。
weights接收str or callable,可选参数,默认=‘distance’。用于预测的权重函数。可能的值包括:‘uniform’: 所有邻居的权重都相等。‘distance’: 权重与距离的倒数成正比。即,更近的邻居对预测的贡献更大。
!pip install scikit-image
from skimage import io # 图像输入输出
from skimage.color import rgb2lab, lab2rgb # 图像通道转换
from sklearn.neighbors import KNeighborsRegressor # KNN 回归器
import os

path = 'style_transfer'

在这里插入图片描述

  在讲解KNN的用法之前,我们必须要了解如何表示图像的色彩。我们先将部分用到的梵高画作展示出来,有较为清晰的感受。数据集中,每幅画作都由 256 × \times × 256 个像素来表示。

data_dir = os.path.join(path, 'vangogh')
fig = plt.figure(figsize=(16, 5))
for i, file in enumerate(np.sort(os.listdir(data_dir))[:3]):
    img = io.imread(os.path.join(data_dir, file))
    ax = fig.add_subplot(1, 3, i + 1)
    ax.imshow(img)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_title(file)
plt.show()

在这里插入图片描述

(一)RGB空间与LAB空间

  我们知道,所有颜色都可以由三原色红、绿、蓝混合得到。因此,在计算机中,为了表示图像中每个像素的颜色,我们常用RGB表示法。其中 R(red)、G(green)和 B(blue)分别代表红、绿、蓝在颜色中所占的比例,均为 0~255 间的整数。将整张图像上每个像素的 RGB 值分别合在一起,就得到了如图3所示的图像的RGB矩阵。如果图像的高是 H H H,宽是 W W W,这一 H × W × 3 H\times W\times 3 H×W×3 的矩阵就包含了图像的色彩信息。

在这里插入图片描述

图3 RGB空间示意

  然而,RGB表示法中对数字大小的限制使得RGB并不能表示出所有颜色。除了RGB之外,计算机中还常用LAB法来表示颜色。其中,L(light)代表亮度,A 表示红、绿方向的分量,B 表示黄、蓝方向的分量。虽然LAB理论上也能表示所有颜色,但由于实际应用中的限制,一般规定 L 的范围是 0~100,0 代表黑色,100 代表白色;A 为-128~127,负数代表绿色,正数代表红色;B 也为-128~127,负数代表蓝色,正数代表黄色。图4展示了LAB空间的色彩变化。相比于RGB,LAB将亮度信息提取出来,与色彩信息独立,使我们可以在不改变黑白图像亮度的情况下对其上色,完成色彩风格迁移。

在这里插入图片描述

图4 LAB空间示意

(二)算法设计

  在确定了图像色彩的表示方式后,上色的过程就是确立从黑白图像到彩色图像的颜色映射的过程。然而,黑白图像中只有亮度信息,我们无法直接还原出其对应的颜色。因此,需要为其补充额外的信息。我们可以采用KNN算法来完成这一映射。首先,我们将风格图像也变成黑白的,提取出其灰度信息。接下来,最简单的思路是,将内容图像中的像素到黑白风格图像中进行匹配,用最接近的 K K K个像素的原始颜色取平均,作为该像素上色后的颜色。

  然而,这一想法所利用到的信息太少,最后上色的效果也不佳。在内容图像中,同样的灰度像素既可能出现在黄色的土地上,也可能出现在蓝色的天上。如果将这些差异很大的颜色取平均进行上色,自然得不到我们期望的效果。就像在一个人组成的方阵中,只靠身高去找人一样。同样身高的人可能有很多,我们很难准确定位要找的人。但是,如果我们又知道了目标周围相邻的人的身高,就可以大大提高精确度。因此,我们将匹配的范围扩大,对于内容图像中的任意一个像素点,我们取其周围相邻的8个像素,组成 3 × \times × 3 的窗口,再向黑白风格图像中寻找与其最相似的 K K K个 3 × \times × 3 的像素窗口。最后,把这些窗口的中心像素的颜色取平均,作为该像素的颜色。图5描述了上述使用KNN算法的思路。

在这里插入图片描述

图5 用KNN解决色彩迁移问题的算法

  下面,我们就来实现这一算法。首先记录风格图像中每个窗口对应的原始颜色,供最后上色使用。

# block_size表示向外扩展的层数,扩展1层即3*3
block_size = 1 

def read_style_image(file_name, size=block_size):
    # 读入风格图像, 得到映射 X->Y
    # 其中X储存3*3像素格的灰度值,Y储存中心像素格的色彩值
    # 读取图像文件,设图像宽为W,高为H,得到W*H*3的RGB矩阵
    img = io.imread(file_name) 
    fig = plt.figure()
    plt.imshow(img)
    plt.xlabel('X axis')
    plt.ylabel('Y axis')
    plt.show()

    # 将RGB矩阵转换成LAB表示法的矩阵,大小仍然是W*H*3,三维分别是L、A、B
    img = rgb2lab(img) 
    # 取出图像的宽度和高度
    w, h = img.shape[:2] 
    
    X = []
    Y = []
    # 枚举全部可能的中心点
    for x in range(size, w - size): 
        for y in range(size, h - size):
            # 保存所有窗口
            X.append(img[x - size: x + size + 1, \
                y - size: y + size + 1, 0].flatten())
            # 保存窗口对应的色彩值a和b
            Y.append(img[x, y, 1:])
    return X, Y

  接下来,读取梵高的《星空》作为风格图像,并用sklearn中的KNN回归器建立模型。

X, Y = read_style_image(os.path.join(path, 'style.jpg')) # 建立映射

# weights='distance'表示邻居的权重与其到样本的距离成反比
knn = KNeighborsRegressor(n_neighbors=4, weights='distance')
knn.fit(X, Y)

在这里插入图片描述

  我们将内容图像分割成同样大小的窗口,并用KNN模型上色。

def rebuild(img, size=block_size):
    # 打印内容图像
    fig = plt.figure()
    plt.imshow(img)
    plt.xlabel('X axis')
    plt.ylabel('Y axis')
    plt.show()
    
    # 将内容图像转为LAB表示
    img = rgb2lab(img) 
    w, h = img.shape[:2]
    
    # 初始化输出图像对应的矩阵
    photo = np.zeros([w, h, 3])
    # 枚举内容图像的中心点,保存所有窗口
    print('Constructing window...')
    X = []
    for x in range(size, w - size):
        for y in range(size, h - size):
            # 得到中心点对应的窗口
            window = img[x - size: x + size + 1, \
                y - size: y + size + 1, 0].flatten()
            X.append(window)
    X = np.array(X)

    # 用KNN回归器预测颜色
    print('Predicting...')
    pred_ab = knn.predict(X).reshape(w - 2 * size, h - 2 * size, -1)
    # 设置输出图像
    photo[:, :, 0] = img[:, :, 0]
    photo[size: w - size, size: h - size, 1:] = pred_ab
    
    # 由于最外面size层无法构造窗口,简单起见,我们直接把这些像素裁剪掉
    photo = photo[size: w - size, size: h - size, :]
    return photo

  最后,我们设置相关参数,并展示风格迁移后的图像。

content = io.imread(os.path.join(path, 'input.jpg'))
new_photo = rebuild(content)
# 为了展示图像,我们将其再转换为RGB表示
new_photo = lab2rgb(new_photo)

fig = plt.figure()
plt.imshow(new_photo)
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.show()

在这里插入图片描述

:以上文中的数据集及相关资源下载地址:
链接:https://pan.quark.cn/s/3e67a53f0d14
提取码:xHBZ

  • 18
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Francek Chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值