(3-22日总结—数据集之csv文件的读取与应用)
一、处理csv文件,并分类好图片存入新的文件夹
示例:数据集往往会搭配csv文件,一般以 {filename, label} 的形式。
参考代码:
import csv
import os
from PIL import Image
# Sort the mini-imagenet images listed in train.csv into one sub-folder per
# class label: <new_img_path>/train/<label>/<filename>.
train_csv_path = "/data/DataBase/mini-imagenet/train.csv"
val_csv_path = "/data/DataBase/mini-imagenet/val.csv"
test_csv_path = "/data/DataBase/mini-imagenet/test.csv"

# Maps column 0 of the CSV (used below as the image file name) to column 1
# (used below as the class sub-folder name).
train_label = {}
# newline="" is how the csv module documentation says files should be opened.
with open(train_csv_path, newline="") as csvfile:
    csv_reader = csv.reader(csvfile)
    # The first row is the header; the default keeps an empty file from
    # raising StopIteration.
    birth_header = next(csv_reader, None)
    for row in csv_reader:
        # csv.reader yields [] for blank lines; skip rows without both columns.
        if len(row) < 2:
            continue
        train_label[row[0]] = row[1]

img_path = "/data/DataBase/mini-imagenet/mini-imagenet/images"
new_img_path = "/data/DataBase/mini-imagenet/mini-imagenet/ok"

for png in os.listdir(img_path):
    # Decide from the label table first, so entries that are not training
    # images (e.g. .DS_Store) are never handed to PIL at all.
    tmp = train_label.get(png)
    if tmp is None:
        continue

    path = os.path.join(img_path, png)
    temp_path = os.path.join(new_img_path, "train", tmp)
    # exist_ok avoids the separate existence check.
    os.makedirs(temp_path, exist_ok=True)

    t = os.path.join(temp_path, png)
    # The context manager closes the source file once the copy is written.
    # Image.save re-encodes the picture, as the original script did.
    with Image.open(path) as im:
        im.save(t)
二、数据加载与处理:
**注意数据的:**
1.大小
2.类型: (tensor, numpy, numpy_array, pandas, list, tuple(str,int, float, bool))
一、数据的加载和处理:
链接: https://www.jianshu.com/p/6e22d21c84be.
二、读取数据集的两种预处理:
两种数据集:
- 所有图片都在同一个文件夹内。(这个用 torch.utils.data.Dataset 类就行!)
- 不同类别的图片放在不同的文件夹。(用 torchvision.datasets.ImageFolder('image_dir_root') 即可)
链接: https://blog.csdn.net/Hungryof/article/details/76649006.
三、imread 的几种方法:
链接:各种imread使用方法
四大读取方式 | 使用方法 | 返回对象-通道 |
---|---|---|
PIL | PIL.Image.open + numpy | 普通对象–RGB |
matplotlib | matplotlib.image.imread | NumpyArray–RGB |
opencv | cv2.imread | NumpyArray–BGR |
skimage | skimage.io.imread | NumpyArray–RGB |
caffe(基于 skimage) | caffe.io.load_image | Float(0~1)–RGB |
#encoding=utf8
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
import skimage
import sys
from skimage import io
# Read the same image with four libraries and print the type and shape of
# each result, then display the PIL result with matplotlib.

# 1 PIL
# Related: scipy.misc.imread and scipy.ndimage.imread. misc.imread offers an
# optional `mode` argument but is essentially a call into PIL; see the source
# or the documentation for the available modes:
# https://github.com/scipy/scipy/blob/v0.17.1/scipy/misc/pilutil.py
imagepath = 'test1.jpg'
# The context manager closes the file handle once the pixels are copied out.
with Image.open(imagepath) as pil_image:
    im1 = np.array(pil_image)  # numpy array; RGB order for a colour image
print(type(im1))
print(im1.shape)

# 2 opencv
im2 = cv2.imread(imagepath)
# cv2.imread does not raise when the file is missing or cannot be decoded;
# it returns None, which would otherwise surface as an AttributeError below.
if im2 is None:
    raise IOError('cv2.imread could not read {!r}'.format(imagepath))
print(type(im2))  # numpy array, BGR channel order
print(im2.shape)  # (height, width, 3)

# 3 matplotlib, similar to the MATLAB way of reading images
im3 = mpimg.imread(imagepath)
print(type(im3))  # np.array
print(im3.shape)

# 4 skimage
# caffe.io.load_image() is also implemented on top of skimage, but returns
# float data in the 0-1 range.
im4 = io.imread(imagepath)
print(type(im4))  # np.array
print(im4.shape)

# Display through plt. Both plt.imshow and cv2.imshow only accept numpy
# arrays; because cv2 works in BGR while the other readers return RGB, the
# channels must be swapped when an image crosses between cv2 and the others.
plt.subplot(221)
plt.title('PIL read')
plt.imshow(im1)
# plt.axis('off')  # uncomment to hide the axes
plt.show()
1.1头文件导入
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings: the filter below has no category or module restriction,
# so it silences every warning raised after this point.
import warnings
warnings.filterwarnings("ignore")
plt.ion() # switch matplotlib to interactive mode
1.2 pandas读取csv文件
# Load the annotation table and pull out one row as an example.
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
# Column 0 holds the image file name; the remaining columns hold the
# landmark values for that image.
img_name = landmarks_frame.iloc[n, 0]
# .as_matrix() was removed in pandas 1.0; .to_numpy() is its replacement.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
# Regroup the flat row into one (column 0, column 1) pair per landmark.
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
显示从“csv文件”中提取出来的(filename,label)
def show_landmarks(image, landmarks):
    """Draw `image` and overlay `landmarks` as small red dots.

    Column 0 of `landmarks` is plotted on the horizontal axis and column 1
    on the vertical axis.
    """
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # A very short pause so the plot gets a chance to update.
    plt.pause(0.001)
# Open a fresh figure and draw the pair (faces/<img_name>, landmarks).
plt.figure()
sample_image = io.imread(os.path.join('faces/', img_name))
show_landmarks(sample_image, landmarks)
plt.show()
1.3 继承Dataset类(from torch.utils.data import Dataset)
主要通过重写 __len__ 和 __getitem__ 方法
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Each item is a dict ``{'image': ..., 'landmarks': ...}`` built from one
    row of the annotation CSV: column 0 is the image file name (relative to
    ``root_dir``) and the remaining columns are the landmark values.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated rows in the CSV."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the image and landmarks for row ``idx``.

        Returns:
            dict with keys 'image' (whatever ``io.imread`` returns for the
            file) and 'landmarks' (float array reshaped to two columns),
            passed through ``self.transform`` when one was given.
        """
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        # The image is read lazily, one file per requested item.
        image = io.imread(img_name)
        # .as_matrix() was removed in pandas 1.0; .to_numpy() replaces it.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
# Build the dataset without any transform and preview its first four samples
# side by side in a single-row figure.
face_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces/',
)

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    image_shape = sample['image'].shape
    landmarks_shape = sample['landmarks'].shape
    print(i, image_shape, landmarks_shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    # Keep going until the fourth sample (index 3), then show and stop.
    if i != 3:
        continue
    plt.show()
    break
DataLoader 读取数据集
import torch
# The loader class is torch.utils.data.DataLoader (capital "L"); the spelling
# "Dataloader" does not exist and raises AttributeError when evaluated.
# NOTE(review): `transformed_dataset` is not defined in this section --
# presumably a FaceLandmarksDataset built with a `transform` that returns
# tensors; confirm where it is created before running this.
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples.

    Args:
        sample_batched: mapping with an 'image' entry (a batch tensor that
            ``utils.make_grid`` can tile) and a 'landmarks' entry indexed as
            ``[sample, point, coordinate]``.

    The landmark offsets assume the whole batch fits on one grid row, i.e.
    the batch is no larger than make_grid's ``nrow``.
    """
    images_batch, landmarks_batch = \
        sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    # make_grid takes a (B, C, H, W) batch, so the horizontal step between
    # tiles is the image width (dim 3), not the height (dim 2).
    im_width = images_batch.size(3)
    # make_grid separates tiles with a border; pass it explicitly so the
    # landmark offsets below stay in sync with the grid layout.
    grid_border_size = 2

    grid = utils.make_grid(images_batch, padding=grid_border_size)
    # The grid is channel-first; imshow wants channel-last.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Tile i starts after i full images plus (i + 1) borders horizontally
        # and after one border vertically.
        x_offset = i * im_width + (i + 1) * grid_border_size
        plt.scatter(landmarks_batch[i, :, 0].numpy() + x_offset,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

    plt.title('Batch from dataloader')
# Walk through the batches, printing the size of each one, and display the
# 4th batch (index 3) before stopping.
for i_batch, sample_batched in enumerate(dataloader):
    image_size = sample_batched['image'].size()
    landmarks_size = sample_batched['landmarks'].size()
    print(i_batch, image_size, landmarks_size)

    # observe 4th batch and stop.
    if i_batch != 3:
        continue

    plt.figure()
    show_landmarks_batch(sample_batched)
    plt.axis('off')
    plt.ioff()
    plt.show()
    break