python构建cnn图片分类_Pytorch 使用CNN图像分类的实现

本文介绍了如何使用Pytorch构建CNN进行图像分类。首先,通过numpy和PIL构造了一个4*4图像数据集,并调整数据集以避免偏斜。接着,定义了CNN模型,包括两种不同的网络结构,然后进行了训练并展示了训练结果。最后,测试了模型的预测性能。
摘要由CSDN通过智能技术生成

需求

在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类

如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1类

想法

通过numpy、PIL构造4*4的图像数据集

构造自己的数据集类

读取数据集对数据集选取减少偏斜

cnn设计因为特征少,直接1*1卷积层

或者在4*4外围添加padding成6*6,设计2*2的卷积核得出3*3再接上全连接层

代码

import torch

import torchvision

import torchvision.transforms as transforms

import numpy as np

from PIL import Image

构造数据集

import csv

import collections

import os

import shutil

def buildDataset(root,dataType,dataSize):

"""构造数据集

构造的图片存到root/{dataType}Data

图片地址和标签的csv文件存到 root/{dataType}DataInfo.csv

Args:

root:str

项目目录

dataType:str

'train'或者‘test'

dataNum:int

数据大小

Returns:

"""

dataInfo = []

dataPath = f'{root}/{dataType}Data'

if not os.path.exists(dataPath):

os.makedirs(dataPath)

else:

shutil.rmtree(dataPath)

os.mkdir(dataPath)

for i in range(dataSize):

# 创建0,1 数组

imageArray=np.random.randint(0,2,(4,4))

# 计算0,1数量得到标签

allBlackNum = collections.Counter(imageArray.flatten())[0]

innerBlackNum = collections.Counter(imageArray[1:3,1:3].flatten())[0]

label = 0 if (allBlackNum-innerBlackNum)>innerBlackNum else 1

# 将图片保存

path = f'{dataPath}/{i}.jpg'

dataInfo.append([path,label])

im = Image.fromarray(np.uint8(imageArray*255))

im = im.convert('1')

im.save(path)

# 将图片地址和标签存入csv文件

filePath = f'{root}/{dataType}DataInfo.csv'

with open(filePath, 'w') as f:

writer = csv.writer(f)

writer.writerows(dataInfo)

root=r'/Users/null/Documents/PythonProject/Classifier'

构造训练数据集

buildDataset(root,'train',20000)

构造测试数据集

buildDataset(root,'test',10000)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值