利用鸢尾花数据集直观感受神经网络表征能力(pytorch)

 

import torch

import torch

import torchvision

from torch.autograd import Variable

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

import cv2

from torch import nn

from numpy import *

import torch.nn.functional as F

import matplotlib.pyplot as plt

import numpy as np

from sklearn import datasets

 

from sklearn.decomposition import PCA

 

class NN_cliassifer(torch.nn.Module):

    def __init__(self) :

        super(NN_cliassifer, self).__init__()

 

        self.Line1 = torch.nn.Linear(4, 8)

        self.r1 = torch.nn.ReLU()

        self.Line2 = torch.nn.Linear(8, 6)

        self.r2 = torch.nn.ReLU()

        self.Line3 = torch.nn.Linear(6, 3)

 

    def forward(self, x):

 

        x = self.Line1(x)

        x1 = self.r1(x)

        x = self.Line2(x)

        x2 = self.r2(x)

        x = self.Line3(x)

        return x, x1, x2


 

net = NN_cliassifer()

cost = torch.nn.CrossEntropyLoss()



 

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

 

iris = datasets.load_iris()

 

X = torch.tensor(iris.data).float()

y = torch.tensor(iris.target).long()


 

for i in range(1000):

        optimizer.zero_grad()

        outputs,_,_ = net(X)

 

        loss = cost(outputs, y)

        loss.backward()

        optimizer.step()

        print(loss)


 

x, x1, x2 = net(X)

 

pca = PCA(n_components = 2)

reduced_X = pca.fit_transform(x1.detach().numpy())



 

red_x, red_y = [], []              # 第一类数据点

blue_x, blue_y = [], []            # 第二类数据点

green_x, green_y = [], []          # 第三类数据点

 

for i in range(len(reduced_X)):    # 按照鸢尾花的类别将降维后的数据点保存在不同的列表中。

    if y[i] == 0:

        red_x.append(reduced_X[i][0])

        red_y.append(reduced_X[i][1])

    elif y[i] == 1:

        blue_x.append(reduced_X[i][0])

        blue_y.append(reduced_X[i][1])

    else:

        green_x.append(reduced_X[i][0])

        green_y.append(reduced_X[i][1])

 

plt.scatter(red_x, red_y, c='r', marker='x')

plt.scatter(blue_x, blue_y, c='b', marker='D')

plt.scatter(green_x, green_y, c='g', marker='.')

plt.show()





 

pca = PCA(n_components = 2)

reduced_X = pca.fit_transform(x2.detach().numpy())


 

red_x, red_y = [], []              # 第一类数据点

blue_x, blue_y = [], []            # 第二类数据点

green_x, green_y = [], []          # 第三类数据点

 

for i in range(len(reduced_X)):    # 按照鸢尾花的类别将降维后的数据点保存在不同的列表中。

    if y[i] == 0:

        red_x.append(reduced_X[i][0])

        red_y.append(reduced_X[i][1])

    elif y[i] == 1:

        blue_x.append(reduced_X[i][0])

        blue_y.append(reduced_X[i][1])

    else:

        green_x.append(reduced_X[i][0])

        green_y.append(reduced_X[i][1])

 

plt.scatter(red_x, red_y, c='r', marker='x')

plt.scatter(blue_x, blue_y, c='b', marker='D')

plt.scatter(green_x, green_y, c='g', marker='.')

plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值