March 22 Summary: Reading and Using CSV Files for Datasets


1. Processing the CSV file and sorting images into class folders

Example: datasets often ship with a CSV file, typically in {filename, label} form.
Reference code:

import csv
import os

from PIL import Image

train_csv_path = "/data/DataBase/mini-imagenet/train.csv"
val_csv_path = "/data/DataBase/mini-imagenet/val.csv"
test_csv_path = "/data/DataBase/mini-imagenet/test.csv"

# Record the {filename: label} pairs from the CSV in a dict.
train_label = {}
with open(train_csv_path) as csvfile:
    csv_reader = csv.reader(csvfile)
    header = next(csv_reader)  # skip the header row
    for row in csv_reader:
        train_label[row[0]] = row[1]

img_path = "/data/DataBase/mini-imagenet/mini-imagenet/images"
new_img_path = "/data/DataBase/mini-imagenet/mini-imagenet/ok"
for png in os.listdir(img_path):
    path = os.path.join(img_path, png)
    if png == '.DS_Store':  # skip the macOS metadata file
        continue
    im = Image.open(path)
    if png in train_label:
        label = train_label[png]
        # One sub-folder per class under new_img_path/train/
        temp_path = os.path.join(new_img_path, 'train', label)
        if not os.path.exists(temp_path):
            os.makedirs(temp_path)
        im.save(os.path.join(temp_path, png))
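The code above only sorts the training split, even though val_csv_path and test_csv_path are defined as well. As a minimal sketch (assuming every CSV has the same filename,label layout, and copying the files with shutil instead of re-saving them through PIL), the same idea generalized to all three splits could look like this:

import csv
import os
import shutil

def sort_split(csv_path, split_name, img_dir, out_root):
    """Copy the images listed in one CSV into out_root/<split_name>/<label>/."""
    with open(csv_path) as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for filename, label in reader:          # assumes rows are exactly (filename, label)
            src = os.path.join(img_dir, filename)
            dst_dir = os.path.join(out_root, split_name, label)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy(src, os.path.join(dst_dir, filename))

img_dir = "/data/DataBase/mini-imagenet/mini-imagenet/images"
out_root = "/data/DataBase/mini-imagenet/mini-imagenet/ok"
for split_name, csv_path in [("train", train_csv_path),
                             ("val", val_csv_path),
                             ("test", test_csv_path)]:
    sort_split(csv_path, split_name, img_dir, out_root)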
 

 
 
 

2. Data loading and processing

Pay attention to the data's:
1. size
2. type: (tensor, numpy array, pandas, list, tuple (str, int, float, bool))

A short sketch for checking both is given right below.
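As a quick illustration (a minimal sketch; the array here is a made-up placeholder), shapes and types can be inspected like this, and NumPy arrays and torch tensors converted back and forth:

import numpy as np
import torch

arr = np.random.rand(3, 224, 224).astype(np.float32)  # dummy image-like array
print(type(arr), arr.shape, arr.dtype)                 # <class 'numpy.ndarray'> (3, 224, 224) float32

t = torch.from_numpy(arr)          # numpy -> tensor (shares the underlying memory)
print(type(t), t.size(), t.dtype)  # torch.Tensor, torch.Size([3, 224, 224]), torch.float32

back = t.numpy()                   # tensor -> numpy
print(type(back), back.shape)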

2.1 Data loading and processing
Link: https://www.jianshu.com/p/6e22d21c84be

2.2 Two ways of reading / preprocessing the dataset

Two dataset layouts:

  1. All images in a single folder (subclass torch.utils.data.Dataset, as in section 1.3 below).
  2. Images of each class in their own sub-folder (use torchvision.datasets.ImageFolder('image_dir_root'); see the sketch after this list).
    Link: https://blog.csdn.net/Hungryof/article/details/76649006
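For the second layout (which the sorting script in section 1 produces under ok/train/<label>/), a minimal sketch using ImageFolder might look like the following; the resize size, batch size, and exact root path are illustrative assumptions:

import torch
from torchvision import datasets, transforms

# Each class lives in its own sub-folder, e.g. ok/train/<label>/<image>.jpg
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder('/data/DataBase/mini-imagenet/mini-imagenet/ok/train',
                                 transform=transform)
print(train_set.classes[:5])          # sub-folder names become the class labels
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)     # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32])

ImageFolder derives the label of each image from its parent folder name, which is exactly the structure produced by the CSV-sorting script above.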

2.3 Several ways to use imread
Link: various imread usage methods

The four main ways of reading an image (plus caffe.io.load_image, which wraps skimage):

| Library    | API                     | Return object                      | Channels |
| ---------- | ----------------------- | ---------------------------------- | -------- |
| PIL        | PIL.Image.open + numpy  | PIL Image (→ ndarray via np.array) | RGB      |
| matplotlib | matplotlib.image.imread | numpy array                        | RGB      |
| OpenCV     | cv2.imread              | numpy array                        | BGR      |
| skimage    | skimage.io.imread       | numpy array                        | RGB      |
| skimage    | caffe.io.load_image     | float array in [0, 1]              | RGB      |
# encoding=utf8
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt    # plt is used to display images
import matplotlib.image as mpimg   # mpimg is used to read images
from skimage import io

# 1 PIL
# Related: scipy.misc.imread, scipy.ndimage.imread
# misc.imread offers an optional `mode` argument but is essentially a wrapper around PIL;
# see the source code or documentation for the available modes:
# https://github.com/scipy/scipy/blob/v0.17.1/scipy/misc/pilutil.py
imagepath = 'test1.jpg'
im1 = Image.open(imagepath)
im1 = np.array(im1)   # convert to a numpy array, RGB
print(type(im1))
print(im1.shape)

# 2 OpenCV
im2 = cv2.imread(imagepath)
print(type(im2))   # numpy array, BGR
print(im2.shape)   # [height, width, 3]

# 3 matplotlib (MATLAB-like interface)
im3 = mpimg.imread(imagepath)
print(type(im3))   # np.ndarray
print(im3.shape)

# 4 skimage
# caffe.io.load_image() is also implemented on top of skimage; it returns float data in [0, 1]
im4 = io.imread(imagepath)
print(type(im4))   # np.ndarray
print(im4.shape)
# print(im4)

# cv2.imshow('test', im4)
# cv2.waitKey()
# Display everything with plt. Both plt and cv2.imshow only accept numpy arrays in Python,
# but cv2.imread returns BGR (cv2.imshow expects BGR, plt expects RGB), so swap the channels
# accordingly when displaying.
plt.subplot(221)
plt.title('PIL read')
plt.imshow(im1)

# plt.axis('off')  # hide the axes
plt.show()
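Because plt.imshow expects RGB while cv2.imread returns BGR, the OpenCV image needs a channel swap before plotting. For example, continuing from im2 above:

# im2 was read with cv2.imread and is therefore BGR; convert before showing with matplotlib.
im2_rgb = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)
plt.figure()
plt.title('OpenCV read (converted to RGB)')
plt.imshow(im2_rgb)
plt.show()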

1.1 Importing the required packages

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

1.2 Reading the CSV file with pandas

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()  # .as_matrix() was removed in newer pandas
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

Display the image and the landmarks extracted from the CSV file (filename, landmarks):

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
''' (face/img_name, landmarks) '''
show_landmarks(io.imread(os.path.join('faces/', img_name)),
               landmarks)
plt.show()

1.3 Subclassing the Dataset class (from torch.utils.data import Dataset)

This works mainly by overriding the __getitem__ method (__len__ is also overridden below).

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        # image is read as a numpy.ndarray
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()  # .as_matrix() removed in newer pandas
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

Reading the dataset with DataLoader

The DataLoader class is torch.utils.data.DataLoader (already imported above via from torch.utils.data import Dataset, DataLoader).
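The snippet below references transformed_dataset, which is not defined in this excerpt; in the original tutorial it is the FaceLandmarksDataset wrapped with transforms (Rescale, RandomCrop, ToTensor) so that every sample becomes a fixed-size tensor. As a minimal sketch under that assumption, a simplified ToTensor-only version could look like this:

import torch
from torchvision import transforms

class ToTensor(object):
    """Convert the ndarrays in a sample to tensors (simplified version of the tutorial's transform)."""
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # numpy image: H x W x C  ->  torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([ToTensor()]))

Note that with batch_size=4 all images in a batch must share one size, so in practice a resize/crop transform is also applied before ToTensor, as in the original tutorial.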

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
                    landmarks_batch[i, :, 1].numpy(),
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break