[PyTorch练手]使用CNN图像分类

[Pytorch练手]使用CNN图像分类

需求

在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类

在这里插入图片描述
如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1类

想法

  1. 通过numpy、PIL构造4*4的图像数据集
  2. 构造自己的数据集类
  3. 读取数据集对数据集选取减少偏斜
  4. cnn设计因为特征少,直接1*1卷积层
  5. 或者在4*4外围添加padding成6*6,设计2**2的卷积核得出3**3再接上全连接层

代码

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

构造数据集

import csv
import collections
import os
import shutil

def buildDataset(root,dataType,dataSize):
    """构造数据集
    构造的图片存到root/{dataType}Data
    图片地址和标签的csv文件存到 root/{dataType}DataInfo.csv
    Args:
        root:str
            项目目录
        dataType:str
            'train'或者‘test'
        dataNum:int
            数据大小
    Returns:
    """
    dataInfo = []
    dataPath = f'{root}/{dataType}Data'
    if not os.path.exists(dataPath):
        os.makedirs(dataPath)
    else:
        shutil.rmtree(dataPath)
        os.mkdir(dataPath)
        
    for i in range(dataSize):
        # 创建0,1 数组
        imageArray=np.random.randint(0,2,(4,4))
        # 计算0,1数量得到标签
        allBlackNum = collections.Counter(imageArray.flatten())[0]
        innerBlackNum = collections.Counter(imageArray[1:3,1:3].flatten())[0]
        label = 0 if (allBlackNum-innerBlackNum)>innerBlackNum else 1
        # 将图片保存
        path = f'{dataPath}/{i}.jpg'
        dataInfo.append([path,label])
        im = Image.fromarray(np.uint8(imageArray*255))
        im = im.convert('1') 
        im.save(path)
    # 将图片地址和标签存入csv文件
    filePath = f'{root}/{dataType}DataInfo.csv'
    with open(filePath, 'w') as f:
        writer = csv.writer(f)
        writer.writerows(dataInfo)

root=r'/Users/null/Documents/PythonProject/Classifier'
构造训练数据集
buildDataset(root,'train',20000)
构造测试数据集
buildDataset(root,'test',10000)

读取数据集

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, root, datacsv, transform=None):
        super(MyDataset, self).__init__()
        with open(f'{root}/{datacsv}', 'r') as f:
            imgs = []
            # 读取csv信息到imgs列表
            for path,label in map(lambda line:line.rstrip()
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值