Python常用模块的使用技巧
目录
(6)代码分析工具 Pylint安装+pycharm下的配置
(1)matplotlib.image、PIL.Image、cv2图像读取模块
(3)python中PIL.Image和OpenCV图像格式相互转换
1.Python配置说明
(1)Python注释说明
在pyCharm中File->Setting->Editor->File and Code Templates->Python Script:
# -*-coding: utf-8 -*-
"""
@Project: ${PROJECT_NAME}
@File : ${NAME}.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : ${YEAR}-${MONTH}-${DAY} ${HOUR}:${MINUTE}:${SECOND}
"""
# -*-coding: utf-8 -*-
"""
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : ${YEAR}-${MONTH}-${DAY} ${HOUR}:${MINUTE}:${SECOND}
@Brief :
"""
(2)函数说明
def my_fun(para1,para2):
'''
函数功能实现简介
:param para1: 输入参数说明,类型
:param para2: 输入参数说明,类型
:return: 返回内容,类型
'''
(3)ipynb文件转.py文件
jupyter nbconvert --to script demo.ipynb
(4)Python计算运行时间
import datetime
def RUN_TIME(deta_time):
'''
返回毫秒,deta_time.seconds获得秒数=1000ms,deta_time.microseconds获得微妙数=1/1000ms
:param deta_time: ms
:return:
'''
time_=deta_time.seconds * 1000 + deta_time.microseconds / 1000.0
return time_
T0 = datetime.datetime.now()
# do something
T1 = datetime.datetime.now()
print("rum time:{}".format(RUN_TIME(T1-T0)))
(5)镜像加速方法
TUNA 还提供了 Anaconda 仓库的镜像,运行以下命令:
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --set show_channel_urls yes
设置上述镜像后,瞬间提速,但该镜像仅限该命令窗口有效
windows 下在用户目录下面创建pip,然后创建pip.ini文件,把阿里的源复制进去:
[global]
trusted-host=mirrors.aliyun.com
index-url = http://mirrors.aliyun.com/pypi/simple/
Linux下,修改 ~/.pip/pip.conf (没有就创建一个文件夹及文件。文件夹要加“.”,表示是隐藏文件夹)
内容如下:
[global] index-url = https://pypi.tuna.tsinghua.edu.cn/simple [install] trusted-host=mirrors.aliyun.com
windows下,直接在user目录中创建一个pip目录,如:C:\Users\xx\pip,新建文件pip.ini。内容同上。
临时的方法:pip时加上"-i https://mirrors.aliyun.com/pypi/simple/":,如
pip install opencv-python -i https://mirrors.aliyun.com/pypi/simple/
(6)代码分析工具 Pylint安装+pycharm下的配置
代码分析工具 Pylint安装+pycharm下的配置 - oohy - 博客园
(7)Python添加环境路径和搜索路径的方法
添加环境路径:
# 添加graphviz环境路径
import os
os.environ["PATH"] += os.pathsep + 'D:/ProgramData/Anaconda3/envs/pytorch-py36/Library/bin/graphviz/'
搜索路径:
import sys
import os
# 打印当前python搜索模块的路径集
print(sys.path)
# 打印当前文件所在路径
print("os.path.dirname(__file__):", os.path.dirname(__file__))
print("os.getcwd(): ", os.getcwd()) # get current work directory:cwd:获得当前工作目录
'''添加相关的路径
sys.path.append(‘你的模块的名称’)。
sys.path.insert(0,’模块的名称’)
'''
# 先添加image_processing所在目录路径
sys.path.append("F:/project/python-learning-notes/utils")
# sys.path.append(os.getcwd())
# 再倒入该包名
import image_processing
#
os.environ["PATH"] += os.pathsep + 'D:/ProgramData/Anaconda3/envs/pytorch-py36/Library/bin/graphviz/'
image_path = "F:/project/python-learning-notes/dataset/test_image/1.jpg"
image = image_processing.read_image(image_path)
image_processing.cv_show_image("image", image)
(8)conda常用命令
- 列举当前所有环境:conda info --envs 或者conda env list
- 生成一个
environment.yml
文件:conda env export > environment.yml- 根据
environment.yml
文件安装该环境:conda env create -f environment.yml- 列举当前活跃环境下的所有包:conda list
- 参数某个环境:conda remove --name your_env_name --all
2.常用的模块
2.1 numpy模块:
(1)矩阵的拼接和分割,奇偶项分割数据
# 产生5*2的矩阵数据
data1=np.arange(0,10)
data1=data1.reshape([5,2])
# 矩阵拼接
y = np.concatenate([data1, data2], 0)
# 矩阵拼接
def cat_labels_indexMat(labels,indexMat):
indexMat_labels = np.concatenate([labels,indexMat], axis=1)
return indexMat_labels
# 矩阵分割
def split_labels_indexMat(indexMat_labels,label_index=0):
labels = indexMat_labels[:, 0:label_index+1] # 第一列是labels
indexMat = indexMat_labels[:, label_index+1:] # 其余是indexMat
return labels, indexMat
def split_data(data):
'''
按照奇偶项分割数据
:param data:
:return:
'''
data1 = data[0::2]
data2 = data[1::2]
return data1,data2
if __name__=='__main__':
data = np.arange(0, 20)
data = data.reshape([10, 2])
data1,data2=split_data(data)
print("embeddings:{}".format(data))
print("embeddings1:{}".format(data1))
print("embeddings2:{}".format(data2))
(2)按照列进行排序
pair_issame = pair_issame[np.lexsort(pair_issame.T)]#按最后一列进行排序
(3)提取符合条件的某行某列
假设有数据:pair_issame:
如果想提取第三列的为"1"的数据,可以这样:
pair_issame_1 = pair_issame[pair_issame[:, -1] == "1", :] # 筛选数组
(4)查找符合条件的向量
import numpy as np
def matching_data_vecror(data, vector):
'''
从data中匹配vector向量,查找出现vector的index,如:
data = [[1., 0., 0.],[0., 0., 0.],[2., 0., 0.],
[0., 0., 0.],[0., 3., 0.],[0., 0., 4.]]
# 查找data中出现[0, 0, 0]的index
data = np.asarray(data)
vector=[0, 0, 0]
index =find_index(data,vector)
print(index)
>>[False True False True False False]
# 实现去除data数组中元素为[0, 0, 0]的行向量
pair_issame_1 = data[~index, :] # 筛选数组
:param data:
:param vector:
:return:
'''
# index = (data[:, 0] == 0) & (data[:, 1] == 0) & (data[:, 2] == 0)
row_nums = len(data)
clo_nums = len(vector)
index = np.asarray([True] * row_nums)
for i in range(clo_nums):
index = index & (data[:, i] == vector[i])
return index
def set_mat_vecror(data, index, vector):
'''
实现将data指定index位置的数据设置为vector
# 实现将大于阈值分数的point,设置为vector = [10, 10]
point = [[0., 0.], [1., 1.], [2., 2.],
[3., 3.], [4., 4.], [5., 5.]]
point = np.asarray(point) # 每个数据点
score = np.array([0.7, 0.2, 0.3, 0.4, 0.5, 0.6])# 每个数据点的分数
score_th=0.5
index = np.where(score > score_th) # 获得大于阈值分数的所有下标
vector = [10, 10] # 将大于阈值的数据设置为vector
out = set_mat_vecror(point, index, vector)
:param data:
:param index:
:param vector:
:return:
'''
data[index, :] = vector
return data
(5)打乱顺序
python numpy array random 随机排列(打乱训练数据)_Song_Lynn的博客-CSDN博客_numpy 随机排列
per = np.random.permutation(pair_issame_1.shape[0]) # 打乱后的行号
pair_issame_1 = pair_issame_0[per, :] # 获取打乱后的数据
2.2 pickle模块
pickle可以存储什么类型的数据呢?
- 所有python支持的原生类型:布尔值,整数,浮点数,复数,字符串,字节,None。
- 由任何原生类型组成的列表,元组,字典和集合。
- 函数,类,类的实例
import pickle
import numpy as np
def save_data(data, file):
with open(file, 'wb') as f:
pickle.dump(data, f)
def load_data(file):
with open(file, 'rb') as f:
data = pickle.load(f)
return data
if __name__ == "__main__":
data1 = ['aa', 'bb', 'cc'] # list
data1=np.asarray(data1) # ndarray
data_path = "data.pk"
save_data(data1, data_path)
data2 = load_data(data_path)
print(data1)
print(data2)
2.3 random.shuffle产生固定种子
files_list=...
labels_list=...
shuffle=True
if shuffle:
# seeds = random.randint(0,len(files_list)) #产生一个随机数种子
seeds = 100 # 固定种子,只要seed的值一样,后续生成的随机数都一样
random.seed(seeds)
random.shuffle(files_list)
random.seed(seeds)
random.shuffle(labels_list)
2.4 zip()与zip(*) 函数:
zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。
zip 方法在 Python 2 和 Python 3 中的不同:在 Python 3.x 中为了减少内存,zip() 返回的是一个对象。如需展示列表,需手动 list() 转换。
a = [1,2,3]
b = [4,5,6]
c = [4,5,6,7,8]
zipped = zip(a,b) # 打包为元组的列表
# 结果:[(1, 4), (2, 5), (3, 6)]
zip(a,c) # 元素个数与最短的列表一致
# 结果:[(1, 4), (2, 5), (3, 6)]
zip(*zipped) # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
# 结果:[(1, 2, 3), (4, 5, 6)]
2.5 map、for快速遍历方法:
# 假设files_list为:
files_list=['../training_data/test\\0.txt', '../training_data/test\\1.txt', '../training_data/test\\2.txt', '../training_data/test\\3.txt', '../training_data/test\\4.txt', '../training_data/test\\5.txt', '../training_data/test\\6.txt']
# 下面的三个方法都是现实获得files_list的文件名
files_nemes1=list(map(lambda s: os.path.basename(s),files_list))
files_nemes2=list(os.path.basename(i)for i in files_list)
files_nemes3=[os.path.basename(i)for i in files_list]
2.6 glob模块
glob模块是最简单的模块之一,内容非常少。用它可以查找符合特定规则的文件路径名。跟使用windows下的文件搜索差不多。查找文件只用到三个匹配符:"*", "?", "[]"。"*"匹配0个或多个字符;"?"匹配单个字符;"[]"匹配指定范围内的字符,如:[0-9]匹配数字。
import glob
#获取指定目录下的所有图片
print glob.glob(r"E:\Picture\*\*.jpg")
#获取上级目录的所有.py文件
print glob.glob(r'../*.py') #相对路径
对于遍历指定目录的jpg图片,可以这样:
# -*- coding:utf-8 -*-
import glob
#遍历指定目录下的jpg图片
image_path="/home/ubuntu/TFProject/view-finding-network/test_images/*.jpg"
for per_path in glob.glob(image_path):
print(per_path)
若想遍历多个格式的文件,可以这样:
# 遍历'jpg','png','jpeg'的图片
image_format=['jpg','png','jpeg']#图片格式
image_dir='./test_image' #图片目录
image_list=[]
for format in image_format:
path=image_dir+'/*.'+format
image_list.extend(glob.glob(path))
print(image_list)
2.7 os模块
import os
os.getcwd()#获得当前工作目录
os.path.abspath('.')#获得当前工作目录
os.path.abspath('..')#获得当前工作目录的父目录
os.path.abspath(os.curdir)#获得当前工作目录
os.path.join(os.getcwd(),'filename')#获取当前目录,并组合成新目录
os.path.exists(path)#判断文件是否存在
os.path.isfile(path)#如果path是一个存在的文件,返回True。否则返回False。
os.path.basename('path/to/test.jpg')#获得路径下的文件名:test.jpg
os.path.getsize(path) #返回文件大小,如果文件不存在就返回错误
path=os.path.dirname('path/to/test.jpg')#获得路径:path/to
os.sep#当前操作系统的路径分隔符,Linux/UNIX是‘/’,Windows是‘\\’
dirname='path/to/test.jpg'.split(os.sep)[-1]#获得当前文件夹的名称“test.jpg”
dirname='path/to/test.jpg'.split(os.sep)[-2]#获得当前文件夹的名称“to”
# 删除该目录下的所有文件
def delete_dir_file(dir_path):
ls = os.listdir(dir_path)
for i in ls:
c_path = os.path.join(dir_path, i)
if os.path.isdir(c_path):
delete_dir_file(c_path)
else:
os.remove(c_path)
# 若目录不存在,则创建新的目录(只能创建一级目录)
if not os.path.exists(out_dir):
os.mkdir(out_dir)
# 创建多级目录
if not os.path.exists(segment_out_name):
os.makedirs(segment_out_dir)
# 删除该目录下的所有文件
delete_dir_file(out_dir)
# 或者:
shutil.rmtree(out_dir) # delete output folder
下面是实现:【1】getFilePathList:获取file_dir目录下,所有文本路径,包括子目录文件,【2】get_files_list:获得file_dir目录下,后缀名为postfix所有文件列表,包括子目录, 【3】gen_files_labels: 获取files_dir路径下所有文件路径,以及labels,其中labels用子级文件名表示
# coding: utf-8
import os
import os.path
import pandas as pd
def getFilePathList(file_dir):
'''
获取file_dir目录下,所有文本路径,包括子目录文件
:param rootDir:
:return:
'''
filePath_list = []
for walk in os.walk(file_dir):
part_filePath_list = [os.path.join(walk[0], file) for file in walk[2]]
filePath_list.extend(part_filePath_list)
return filePath_list
def get_files_list(file_dir,postfix='ALL'):
'''
获得file_dir目录下,后缀名为postfix所有文件列表,包括子目录
:param file_dir:
:param postfix:
:return:
'''
postfix=postfix.split('.')[-1]
file_list=[]
filePath_list = getFilePathList(file_dir)
if postfix=='ALL':
file_list=filePath_list
else:
for file in filePath_list:
basename=os.path.basename(file) # 获得路径下的文件名
postfix_name=basename.split('.')[-1]
if postfix_name==postfix:
file_list.append(file)
file_list.sort()
return file_list
def gen_files_labels(files_dir):
'''
获取files_dir路径下所有文件路径,以及labels,其中labels用子级文件名表示
files_dir目录下,同一类别的文件放一个文件夹,其labels即为文件的名
:param files_dir:
:return:filePath_list所有文件的路径,label_list对应的labels
'''
filePath_list = getFilePathList(files_dir)
print("files nums:{}".format(len(filePath_list)))
# 获取所有样本标签
label_list = []
for filePath in filePath_list:
label = filePath.split(os.sep)[-2]
label_list.append(label)
labels_set=list(set(label_list))
print("labels:{}".format(labels_set))
# 标签统计计数
print(pd.value_counts(label_list))
return filePath_list,label_list
if __name__=='__main__':
file_dir='JPEGImages'
file_list=get_files_list(file_dir)
for file in file_list:
print(file)
实现遍历dir目录下,所有文件(包含子文件夹的文件)
# coding: utf-8
import os
import os.path
def get_files_list(dir):
'''
实现遍历dir目录下,所有文件(包含子文件夹的文件)
:param dir:指定文件夹目录
:return:包含所有文件的列表->list
'''
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
files_list=[]
for parent, dirnames, filenames in os.walk(dir):
for filename in filenames:
# print("parent is: " + parent)
# print("filename is: " + filename)
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
files_list.append([os.path.join(parent, filename)])
return files_list
if __name__=='__main__':
dir = 'images'
files_list=get_files_list(dir)
print(files_list)
下面是一个封装好的get_input_list()函数,path是文件夹,则遍历所有png,jpg,jpeg等图像文件, path是txt文件路径,则读取txt中保存的文件列表(不要出现多余一个的空行),path是单个图片文件:path/to/1.png。
# -*-coding: utf-8 -*-
"""
@Project: hdrnet
@File : my_test.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-28 14:30:51
"""
import os
import logging
import re
logging.basicConfig(format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
log = logging.getLogger("train")
log.setLevel(logging.INFO)
def get_input_list(path):
'''
返回所有图片的路径
:param path:单张图片的路径,或文件夹,或者txt文件
:return:
'''
regex = re.compile(".*.(png|jpeg|jpg|tif|tiff)")
# path是文件夹,则遍历所有png,jpg,jpeg等图像文件
# path/to
if os.path.isdir(path):
inputs = os.listdir(path)
inputs = [os.path.join(path, f) for f in inputs if regex.match(f)]
log.info("Directory input {}, with {} images".format(path, len(inputs)))
# path是txt文件路径,则读取txt中保存的文件列表(不要出现多余一个的空行)
# path/to/filelist.txt
elif os.path.splitext(path)[-1] == ".txt":
dirname = os.path.dirname(path)
with open(path, 'r') as fid:
inputs = [l.strip() for l in fid.readlines()]
inputs = [os.path.join(dirname, im) for im in inputs]
log.info("Filelist input {}, with {} images".format(path, len(inputs)))
# path是单个图片文件:path/to/1.png
elif regex.match(path):
inputs = [path]
log.info("Single input {}".format(path))
return inputs
if __name__ == '__main__':
path='dataset/filelist.txt';
result=get_input_list(path);
print(result);
2.8 判断图像文件为空和文件不存,文件过小
def isValidImage(images_list,sizeTh=1000,isRemove=False):
''' 去除不存的文件和文件过小的文件列表
:param images_list:
:param sizeTh: 文件大小阈值,单位:字节B,默认1000B
:param isRemove: 是否在硬盘上删除被损坏的原文件
:return:
'''
i=0
while i<len(images_list):
path=images_list[i]
# 判断文件是否存在
if not (os.path.exists(path)):
print(" non-existent file:{}".format(path))
images_list.pop(i)
continue
# 判断文件是否为空
if os.path.getsize(path)<sizeTh:
print(" empty file:{}".format(path))
if isRemove:
os.remove(path)
print(" info:----------------remove image:{}".format(path))
images_list.pop(i)
continue
# 判断图像文件是否损坏
try:
Image.open(path).verify()
except :
print(" damaged image:{}".format(path))
if isRemove:
os.remove(path)
print(" info:----------------remove image:{}".format(path))
images_list.pop(i)
continue
i += 1
return images_list
2.9 保存多维array数组的方法
由于np.savetxt()不能直接保存三维以上的数组,因此需要转为向量的形式来保存
import numpy as np
arr1 = np.zeros((3,4,5), dtype='int16') # 创建3*4*5全0三维数组
print("维度:",np.shape(arr1))
arr1[0,:,:]=0
arr1[1,:,:]=1
arr1[2,:,:]=2
print("arr1=",arr1)
# 由于savetxt不能保存三维以上的数组,因此需要转为向量来保存
vector=arr1.reshape((-1,1))
np.savetxt("data.txt", vector)
data= np.loadtxt("data.txt")
print("data=",data)
arr2=data.reshape(arr1.shape)
print("arr2=",arr2)
2.10读取txt数据的方法
这是封装好的txt读写模块,这里输入和输出的数据都是list列表:
# -*-coding: utf-8 -*-
"""
@Project: TxtStorage
@File : TxtStorage.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-07-12 17:32:47
"""
from numpy import *
class TxtStorage:
# def __init__(self):
def write_txt(self, content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line=""
for col,data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line=str_line+str(data)+" "
else:
# 每行最后一个数据用换行符“\n”
str_line=str_line+str(data)+"\n"
f.write(str_line)
def read_txt(self, fileName):
"""读取txt数据函数
:param filename:文件名
:return: txt的数据列表
:rtype: list
Python中有三个去除头尾字符、空白符的函数,它们依次为:
strip: 用来去除头尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
lstrip:用来去除开头字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
注意:这些函数都只会删除头和尾的字符,中间的不会删除。
"""
txtData=[]
with open(fileName, 'r') as f:
lines = f.readlines()
for line in lines:
lineData = line.rstrip().split(" ")
data=[]
for l in lineData:
if self.is_int(l): # isdigit() 方法检测字符串是否只由数字组成,只能判断整数
data.append(int(l))
elif self.is_float(l):#判断是否为小数
data.append(float(l))
else:
data.append(l)
txtData.append(data)
return txtData
def is_int(self,str):
# 判断是否为整数
try:
x = int(str)
return isinstance(x, int)
except ValueError:
return False
def is_float(self,str):
# 判断是否为整数和小数
try:
x = float(str)
return isinstance(x, float)
except ValueError:
return False
if __name__ == '__main__':
txt_filename = 'test.txt'
w_data = [['1.jpg', 'dog', 200, 300,1.0], ['2.jpg', 'dog', 20, 30,-2]]
print("w_data=",w_data)
txt_str = TxtStorage()
txt_str.write_txt(w_data, txt_filename, mode='w')
r_data = txt_str.read_txt(txt_filename)
print('r_data=',r_data)
一个读取TXT文本数据的常用操作:
# -*-coding: utf-8 -*-
"""
@Project: TxtStorage
@File : TxtStorage.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-07-12 17:32:47
"""
from numpy import *
def write_txt(content, filename, mode='w'):
"""保存txt数据
:param content:需要保存的数据,type->list
:param filename:文件名
:param mode:读写模式:'w' or 'a'
:return: void
"""
with open(filename, mode) as f:
for line in content:
str_line = ""
for col, data in enumerate(line):
if not col == len(line) - 1:
# 以空格作为分隔符
str_line = str_line + str(data) + " "
else:
# 每行最后一个数据用换行符“\n”
str_line = str_line + str(data) + "\n"
f.write(str_line)
def read_txt(fileName):
"""读取txt数据函数
:param filename:文件名
:return: txt的数据列表
:rtype: list
Python中有三个去除头尾字符、空白符的函数,它们依次为:
strip: 用来去除头尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
lstrip:用来去除开头字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
注意:这些函数都只会删除头和尾的字符,中间的不会删除。
"""
txtData = []
with open(fileName, 'r') as f:
lines = f.readlines()
for line in lines:
lineData = line.rstrip().split(" ")
data = []
for l in lineData:
if is_int(l): # isdigit() 方法检测字符串是否只由数字组成,只能判断整数
data.append(int(l))
elif is_float(l): # 判断是否为小数
data.append(float(l))
else:
data.append(l)
txtData.append(data)
return txtData
def is_int(str):
# 判断是否为整数
try:
x = int(str)
return isinstance(x, int)
except ValueError:
return False
def is_float(str):
# 判断是否为整数和小数
try:
x = float(str)
return isinstance(x, float)
except ValueError:
return False
def merge_list(data1,data2):
'''
将两个list进行合并
:param data1:
:param data2:
:return:返回合并后的list
'''
if not len(data1)==len(data2):
return
all_data=[]
for d1,d2 in zip(data1,data2):
all_data.append(d1+d2)
return all_data
def split_list(data,split_index=1):
'''
将data切分成两部分
:param data: list
:param split_index: 切分的位置
:return:
'''
data1=[]
data2=[]
for d in data:
d1=d[0:split_index]
d2=d[split_index:]
data1.append(d1)
data2.append(d2)
return data1,data2
if __name__ == '__main__':
txt_filename = 'test.txt'
w_data = [['1.jpg', 'dog', 200, 300, 1.0], ['2.jpg', 'dog', 20, 30, -2]]
print("w_data=", w_data)
write_txt(w_data, txt_filename, mode='w')
r_data = read_txt(txt_filename)
print('r_data=', r_data)
data1,data2=split_list(w_data)
mer_data=merge_list(data1,data2)
print('mer_data=', mer_data)
读取以下txt文件,可使用以下方法:
test_image/dog/1.jpg 0 11
test_image/dog/2.jpg 0 12
test_image/dog/3.jpg 0 13
test_image/dog/4.jpg 0 14
test_image/cat/1.jpg 1 15
test_image/cat/2.jpg 1 16
test_image/cat/3.jpg 1 17
test_image/cat/4.jpg 1 18
def load_image_labels(test_files):
'''
载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
:param test_files:
:return:
'''
images_list=[]
labels_list=[]
with open(test_files) as f:
lines = f.readlines()
for line in lines:
#rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
content=line.rstrip().split(' ')
name=content[0]
labels=[]
for value in content[1:]:
labels.append(float(value))
images_list.append(name)
labels_list.append(labels)
return images_list,labels_list
2.11 pandas模块
(1)文件数据拼接
假设有'data1.txt', 'data2.txt', 'data3.txt'数据:
#'data1.txt'
1.jpg 11
2.jpg 12
3.jpg 13
#'data2.txt'
1.jpg 110
2.jpg 120
3.jpg 130
#'data3.txt'
1.jpg 1100
2.jpg 1200
3.jpg 1300
需要拼接成:
1.jpg 11 110 1100
2.jpg 12 120 1200
3.jpg 13 130 1300
实现代码:
# coding: utf-8
import pandas as pd
def concat_data(page,save_path):
pd_data=[]
for i in range(len(page)):
content=pd.read_csv(page[i], dtype=str, delim_whitespace=True, header=None)
if i==0:
pd_data=pd.concat([content], axis=1)
else:# 每一列数据拼接
pd_data=pd.concat([pd_data,content.iloc[:,1]], axis=1)
pd_data.to_csv(save_path, index=False, sep=' ', header=None)
if __name__=='__main__':
txt_path = ['data1.txt', 'data2.txt', 'data3.txt']
out_path = 'all_data.txt'
concat_data(txt_path,out_path)
(2)DataFrame
import pandas as pd
import numpy as np
def print_info(class_name,labels):
# index =range(len(class_name))+1
index=np.arange(0,len(class_name))+1
columns = ['class_name', 'labels']
content = np.array([class_name, labels]).T
df = pd.DataFrame(content, index=index, columns=columns) # 生成6行4列位置
print(df) # 输出6行4列的表格
class_name=['C1','C2','C3']
labels=[100,200,300]
print_info(class_name,labels)
Pandas DataFrame数据的增、删、改、查
Pandas DataFrame数据的增、删、改、查_夏雨淋河的博客-CSDN博客_dataframe修改数据
import pandas as pd
import numpy as np
df = pd.DataFrame(data = [['tom1','f',22],['tom2','f',22],['tom3','m',21]],index = [1,2,3],columns = ['name','sex','age'])#测试数据。
name | sex | age | |
---|---|---|---|
1 | tom1 | f | 22 |
2 | tom2 | f | 22 |
3 | tom3 | m | 21 |
citys = ['shenzhen1','shenzhen2','shenzhen3']
df.insert(2,'city',citys) #在第2列,加上column名称为city,值为citys的数值。
jobs = ['student','teacher','teacher']
df['job'] = jobs #默认在df最后一列加上column名称为job,值为jobs的数据。
df.loc[:,'salary'] = ['1k','2k','2k'] #在df最后一列加上column名称为salary,值为等号右边数据。
df
name | sex | city | age | job | salary | |
---|---|---|---|---|---|---|
1 | tom1 | f | shenzhen1 | 22 | student | 1k |
2 | tom2 | f | shenzhen2 | 22 | teacher | 2k |
3 | tom3 | m | shenzhen3 | 21 | teacher | 2k |
#若df中没有index为“4”的这一行的话,该行代码作用是往df中加一行index为“4”,值为等号右边值的数据。
#若df中已经有index为“4”的这一行,则该行代码作用是把df中index为“4”的这一行修改为等号右边数据。
df.loc[4] = ['tom4','m','shenzhen4',24,"engineer",'3k']
df
name | sex | city | age | job | salary | |
---|---|---|---|---|---|---|
1 | tom1 | f | shenzhen1 | 22 | student | 1k |
2 | tom2 | f | shenzhen2 | 22 | teacher | 2k |
3 | tom3 | m | shenzhen3 | 21 | teacher | 2k |
4 | tom4 | m | shenzhen4 | 24 | engineer | 3k |
# 按照age的值进行排序
df=df.sort_values(by=["age"],ascending=False)
df
name | sex | city | age | job | salary | |
---|---|---|---|---|---|---|
4 | tom4 | m | shenzhen4 | 24 | engineer | 3k |
1 | tom1 | f | shenzhen1 | 22 | student | 1k |
2 | tom2 | f | shenzhen2 | 22 | teacher | 2k |
3 | tom3 | m | shenzhen3 | 21 | teacher | 2k |
2.12 csv模块
使用csv模块读取csv文件的数据
# -*- coding:utf-8 -*-
import csv
csv_path='test.csv'
with open(csv_path,'r') as csvfile:
reader = csv.DictReader(csvfile)
for item in reader:#遍历全部元素
print(item)
with open(csv_path, 'r') as csvfile:
reader = csv.DictReader(csvfile)
for item in reader: # 遍历全部元素
print(item['filename'],item['class'],item.get('height'),item.get('width'))
运行结果:
{'filename': 'test01.jpg', 'height': '638', 'class': 'dog', 'width': '486'}
{'filename': 'test02.jpg', 'height': '954', 'class': 'person', 'width': '726'}
test01.jpg dog 638 486
test02.jpg person 954 726
读写过程:
import csv
csv_path = 'test.csv'
#写csv
data=["1.jpg",200,300,'dog']
with open(csv_path, 'w+',newline='') as csv_file:
# headers = [k for k in dictionaries[0]]
headers=['filename','width','height', 'class']
print(headers)
writer = csv.DictWriter(csv_file, fieldnames=headers)
writer.writeheader()
dictionary={'filename': data[0],
'width': data[1],
'height': data[2],
'class': data[3],
}
writer.writerow(dictionary)
print(dictionary)
#读csv
with open(csv_path, 'r') as csvfile:
reader = csv.DictReader(csvfile)
for item in reader: # 遍历全部元素
print(item)
with open(csv_path, 'r') as csvfile:
reader = csv.DictReader(csvfile)
for item in reader: # 遍历全部元素
print(item['filename'], item['class'], item.get('height'), item.get('width'))
2.13 logging模块
import logging
# level级别:debug、info、warning、error以及critical
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG,format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.debug("----1----")
logger.info("----2----")
logger.warning("----3----")
logger.error("----4----")
3. 数据预处理
3.1 数据(图像)分块处理
import numpy as np
def split_cell(mat,cell=(3,3),stepsize=(1,1)):
'''
:param mat:输入单通道的图像数据(可能有误,需要验证)
:param cell:块大小
:param stepsize: 步长stepsize<cell
:return:
'''
rows,cols=np.shape(mat)
Rx=cell[0]//2
Ry=cell[1]//2
stepX=stepsize[0]
stepY=stepsize[1]
dest=np.zeros(shape=(int((rows+stepX-1)/stepX),int((cols+stepY-1)/stepY)),dtype=np.float32)
for i in range(0,rows,stepX):
for j in range(0,cols,stepY):
x1=i-Rx
x2=i+Rx
y1=j-Ry//坐标有误
y2=j+Ry//
x1=np.clip(x1,0,rows-1)
x2=np.clip(x2,0,rows-1)
y1=np.clip(y1,0,cols-1)
y2=np.clip(y2,0,cols-1)
#计算block的平均值
block=mat[y1:(y2+1),x1:(x2+1)]
m=np.mean(block)
indexX=int((i+stepX-1)/stepX)#向上取整
indexY=int((j+stepY-1)/stepY)
dest[indexX,indexY]=m/255
# dest=dest.reshape()
return dest
def split_block(mat,grid=(7,7)):
rows,cols=grid
block_image=[]
height,width = np.shape(mat)
step_width = int(width / cols)
step_height = int( height/ rows)
for i in range(0,rows):
for j in range(0,cols):
x1 = j * step_width
x2=(j + 1) * step_width
y1 = i * step_height
y2=(i + 1) * step_height
block=mat[y1:y2,x1:x2]#注意顺序:mat[row,col]
# fea=block_feature(block, feature_type="LBP")
block_image.append(block)
return block_image
if __name__=="__main__":
data=np.arange(0,100)
image=data.reshape((20,5))
dest=split_block(image,cell=(3,3),stepsize=(1,1))
3.2 读取图片和显示
Python中读取图片和显示图片的方式很多,绝大部分图像处理模块读取图片的通道是RGB格式,只有opencv-python模块读取的图片的BGR格式,如果采用其他模块显示opencv读取的图片,需要转换通道顺序,方法也比较简单,即:
import cv2
import matplotlib.pyplot as plt
temp_img=cv2.imread(image_path) #默认:BGR(不是RGB),uint8,[0,255],ndarry()
cv2.imshow("opencv-python",temp_img5)
cv2.waitKey(0)
# b, g, r = cv2.split(temp_img5)# 将BGR转为RGB格式
# img = cv2.merge([r, g, b])
# 推荐使用cv2.COLOR_BGR2RGB->将BGR转为RGB格式
img = cv2.cvtColor(temp_img5, cv2.COLOR_BGR2RGB)
plt.imshow(img) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()
(1)matplotlib.image、PIL.Image、cv2图像读取模块
# coding: utf-8
'''
在Caffe中,彩色图像的通道要求是BGR格式,输入数据是float32类型,范围[0,255],
对每一层shape=(batch_size, channel_dim, height, width)。
[1]caffe的训练/测试prototxt文件,一般在数据层设置:cale:0.00392156885937,即1/255.0,即将数据归一化到[0,1]
[2]当输入数据为RGB图像,float32,[0,1],则需要转换:
--transformer.set_raw_scale('data',255) # 缩放至0~255
--transformer.set_channel_swap('data',(2,1,0))# 将RGB变换到BGR
[3]当输入数据是RGB图像,int8类型,[0,255],则输入数据之前必须乘以*1.0转换为float32
--transformer.set_raw_scale('data',1.0) # 数据不用缩放了
--transformer.set_channel_swap('data',(2,1,0))#将RGB变换到BGR
--通道:img = img.transpose(2, 0, 1) #通道由[h,w,c]->[c,h,w]
[4]在Python所有读取图片的模块,其图像格式都是shape=[height, width, channels],
比较另类的是,opencv-python读取的图片的BGR(caffe通道要求是BGR格式),而其他模块是RGB格式
'''
import numpy as np
import matplotlib.pyplot as plt
image_path = 'test_image/C0.jpg'#C0.jpg是高h=400,宽w=200
# 1.caffe
import caffe
img1 = caffe.io.load_image(image_path) # 默认:RGB,float32,[0-1],ndarry,shape=[400,200,3]
# 2.skimage
import skimage.io
img2 = skimage.io.imread(image_path) # 默认:RGB,uint8,[0,255],ndarry,shape=[400,200,3]
# img2=img2/255.0
# 3.matplotlib
import matplotlib.image
img3 = matplotlib.image.imread(image_path) # 默认:RGB,uint8,[0,255],ndarry,shape=[400,200,3]
# 4.PIL
from PIL import Image
temp_img4 = Image.open(image_path) # 默认:RGB,uint8,[0,255],
# temp_img4.show() #会调用系统自定的图片查看器显示图片
img4 = np.array(temp_img4) # 转为ndarry类型,shape=[400,200,3]
# 5.opencv
import cv2
temp_img5 = cv2.imread(image_path) # 默认:BGR(不是RGB),uint8,[0,255],ndarry,shape=[400,200,3]
# cv2.imshow("opencv-python",temp_img5)
# cv2.waitKey(0)
# b, g, r = cv2.split(temp_img5)# 将BGR转为RGB格式
# img5 = cv2.merge([r, g, b])
# 推荐使用cv2.COLOR_BGR2RGB->将BGR转为RGB格式
img5 = cv2.cvtColor(temp_img5, cv2.COLOR_BGR2RGB)
img6 = img5.transpose(2, 0, 1) #通道由[h,w,c]->[c,h,w]
# 以上ndarry类型图像数据都可以用下面的方式直接显示
plt.imshow(img5) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()
封装好的图像读取和保存模块:
import matplotlib.pyplot as plt
import cv2
def show_image(title, image):
'''
显示图片
:param title: 图像标题
:param image: 图像的数据
:return:
'''
# plt.figure("show_image")
# print(image.dtype)
plt.imshow(image)
plt.axis('on') # 关掉坐标轴为 off
plt.title(title) # 图像题目
plt.show()
def show_image_rect(win_name, image, rect):
plt.figure()
plt.title(win_name)
plt.imshow(image)
rect =plt.Rectangle((rect[0], rect[1]), rect[2], rect[3], linewidth=2, edgecolor='r', facecolor='none')
plt.gca().add_patch(rect)
plt.show()
def read_image(filename, resize_height, resize_width,normalization=False):
'''
读取图片数据,默认返回的是uint8,[0,255]
:param filename:
:param resize_height:
:param resize_width:
:param normalization:是否归一化到[0.,1.0]
:return: 返回的图片数据
'''
bgr_image = cv2.imread(filename)
if len(bgr_image.shape)==2:#若是灰度图则转为三通道
print("Warning:gray image",filename)
bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#将BGR转为RGB
# show_image(filename,rgb_image)
# rgb_image=Image.open(filename)
if resize_height>0 and resize_width>0:
rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
rgb_image=np.asanyarray(rgb_image)
if normalization:
# 不能写成:rgb_image=rgb_image/255
rgb_image=rgb_image/255.0
# show_image("src resize image",image)
return rgb_image
def save_image(image_path,image):
plt.imsave(image_path,image)
(2)将 numpy 数组转换为 PIL 图片:
这里采用 matplotlib.image 读入图片数组,注意这里读入的数组是 float32 型的,范围是 0-1,而 PIL.Image 数据是 uinit8 型的,范围是0-255,所以要进行转换:
import matplotlib.image as mpimg
from PIL import Image
lena = mpimg.imread('lena.png') # 这里读入的数据是 float32 型的,范围是0-1
im = Image.fromarray(np.uinit8(lena*255))
im.show()
(3)python中PIL.Image和OpenCV图像格式相互转换
PIL.Image转换成OpenCV格式:
import cv2
from PIL import Image
import numpy
image = Image.open("plane.jpg")
image.show()
img = cv2.cvtColor(numpy.asarray(image),cv2.COLOR_RGB2BGR)
cv2.imshow("OpenCV",img)
cv2.waitKey()
OpenCV转换成PIL.Image格式:
import cv2
from PIL import Image
import numpy
img = cv2.imread("plane.jpg")
cv2.imshow("OpenCV",img)
image = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
image.show()
cv2.waitKey()
判断图像数据是否是OpenCV格式:
isinstance(img, np.ndarray)
(4)matplotlib显示阻塞问题
matplotlib.pyplot 中显示图像的两种模式(交互和阻塞)及其在Python画图中的应用_wonengguwozai的博客-CSDN博客_matplotlib 交互模式
下面这个例子讲的是如何像matlab一样同时打开多个窗口显示图片或线条进行比较,同时也是在脚本中开启交互模式后图像一闪而过的解决办法:
import matplotlib.pyplot as plt
plt.ion() # 打开交互模式
# 同时打开两个窗口显示图片
plt.figure()
plt.imshow(image1)
plt.figure()
plt.imshow(image2)
plt.ioff()# 显示前关掉交互模式,避免一闪而过
plt.show()
(5)matplotlib绘制矩形框
import matplotlib.pyplot as plt
def show_image(win_name, image, rect):
plt.figure()
plt.title(win_name)
plt.imshow(image)
rect =plt.Rectangle((rect[0], rect[1]), rect[2], rect[3], linewidth=2, edgecolor='r', facecolor='none')
plt.gca().add_patch(rect)
plt.show()
3.3 one-hot独热编码
import os
import numpy as np
from sklearn import preprocessing
def gen_data_labels(label_list,ont_hot=True):
'''
label_list:输入labels ->list
'''
# 将labels转为整数编码
# labels_set=list(set(label_list))
# labels=[]
# for label in label_list:
# for k in range(len(labels_set)):
# if label==labels_set[k]:
# labels+=[k]
# break
# labels = np.asarray(labels)
# 也可以用下面的方法:将labels转为整数编码
labelEncoder = preprocessing.LabelEncoder()
labels = labelEncoder.fit_transform(label_list)
labels_set = labelEncoder.classes_
for i in range(len(labels_set)):
print("labels:{}->{}".format(labels_set[i],i))
# 是否进行独热编码
if ont_hot:
labels_nums=len(labels_set)
labels = labels.reshape(len(labels), 1)
onehot_encoder = preprocessing.OneHotEncoder(sparse=False,categories=[range(labels_nums)])
onehot_encoder = preprocessing.OneHotEncoder(sparse=False,categories='auto')
labels = onehot_encoder.fit_transform(labels)
return labels
3.4 循环产生batch数据:
TXT文本:
1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18
# -*-coding: utf-8 -*-
"""
@Project: LSTM
@File : create_batch_data.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-10-27 18:20:15
"""
import math
import random
import os
import glob
import numpy as np
def get_list_batch(inputs, batch_size=None, shuffle=False):
'''
循环产生batch数据
:param inputs: list数据
:param batch_size: batch大小
:param shuffle: 是否打乱inputs数据
:return: 返回一个batch数据
'''
if shuffle:
random.shuffle(inputs)
while True:
batch_inouts = inputs[0:batch_size]
inputs=inputs[batch_size:] + inputs[:batch_size]# 循环移位,以便产生下一个batch
yield batch_inouts
def get_data_batch(inputs, batch_size=None, shuffle=False):
'''
循环产生batch数据
:param inputs: list数据
:param batch_size: batch大小
:param shuffle: 是否打乱inputs数据
:return: 返回一个batch数据
'''
# rows,cols=inputs.shape
rows=len(inputs)
indices =list(range(rows))
if shuffle:
random.shuffle(indices )
while True:
batch_indices = indices[0:batch_size]
indices= indices [batch_size:] + indices[:batch_size] # 循环移位,以便产生下一个batch
batch_data=find_list(batch_indices,inputs)
# batch_data=find_array(batch_indices,inputs)
yield batch_data
def find_list(indices,data):
out=[]
for i in indices:
out=out+[data[i]]
return out
def find_array(indices,data):
rows,cols=data.shape
out = np.zeros((len(indices), cols))
for i,index in enumerate(indices):
out[i]=data[index]
return out
def load_file_list(text_dir):
text_dir = os.path.join(text_dir, '*.txt')
text_list = glob.glob(text_dir)
return text_list
def get_next_batch(batch):
return batch.__next__()
def load_image_labels(test_files):
'''
载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签1,如:test_image/1.jpg 0 2
:param test_files:
:return:
'''
images_list=[]
labels_list=[]
with open(test_files) as f:
lines = f.readlines()
for line in lines:
#rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格)
content=line.rstrip().split(' ')
name=content[0]
labels=[]
for value in content[1:]:
labels.append(float(value))
images_list.append(name)
labels_list.append(labels)
return images_list,labels_list
if __name__ == '__main__':
filename='./training_data/train.txt'
images_list, labels_list=load_image_labels(filename)
# inputs = np.reshape(np.arange(8*3), (8,3))
iter = 10 # 迭代10次,每次输出5个
batch = get_data_batch(images_list, batch_size=3, shuffle=False)
for i in range(iter):
print('**************************')
# train_batch=batch.__next__()
batch_images=get_next_batch(batch)
print(batch_images)
3.5 统计元素个数和种类
label_list=['星座', '星座', '财经', '财经', '财经', '教育', '教育', '教育', ]
set1 = set(label_list) # set1 ={'财经', '教育', '星座'},set集合中不允许重复元素出现
set2 = np.unique(label_list)# set2=['教育' '星座' '财经']
# 若要输出对应元素的个数:
from collections import Counter
arr = [1, 2, 3, 3, 2, 1, 0, 2]
result = {}
for i in set(arr):
result[i] = arr.count(i)
print(result)
# 更加简单的方法:
import pandas as pd
print(pd.value_counts(label_list))
3.6 python 字典(dict)按键和值排序
python 字典(dict)的特点就是无序的,按照键(key)来提取相应值(value),如果我们需要字典按值排序的话,那可以用下面的方法来进行:
1 .下面的是按照value的值从大到小的顺序来排序
dic = {'a':31, 'bc':5, 'c':3, 'asd':4, 'aa':74, 'd':0}
dict= sorted(dic.items(), key=lambda d:d[1], reverse = True)
print dict
输出的结果:
[('aa', 74), ('a', 31), ('bc', 5), ('asd', 4), ('c', 3), ('d', 0)]
下面我们分解下代码
print dic.items() 得到[(键,值)]的列表。
然后用sorted方法,通过key这个参数,指定排序是按照value,也就是第一个元素d[1的值来排序。reverse = True表示是需要翻转的,默认是从小到大,翻转的话,那就是从大到小。
2 .对字典按键(key)排序:
dic = {'a':31, 'bc':5, 'c':3, 'asd':4, 'aa':74, 'd':0}
dict= sorted(dic.items(), key=lambda d:d[0]) d[0]表示字典的键
print dict
3.7 自定义排序sorted
下面my_sort函数,将根据labels的相同的个数进行排序,把labels相同的个数多的样本,排在前面
# -*-coding: utf-8 -*-
"""
@Project: IntelligentManufacture
@File : statistic_analysis.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2019-02-15 13:47:58
"""
import pandas as pd
import numpy as np
import functools
def print_cluster_info(title,labels_id, labels,columns = ['labels_id', 'labels']):
index= np.arange(0, len(labels_id)) + 1
content = np.array([labels_id, labels]).T
df = pd.DataFrame(content, index=index, columns=columns) # 生成6行4列位置
print('*************************************************')
print("{}{}".format(title,df))
def print_cluster_container(title,cluster_container,columns = ['labels_id', 'labels']):
'''
:param cluster_container:type:list[tupe()]
:param columns:
:return:
'''
labels_id, labels=zip(*cluster_container)
labels_id=list(labels_id)
labels=list(labels)
print_cluster_info(title,labels_id, labels, columns=columns)
def sort_cluster_container(cluster_container):
'''
自定义排序:将根据labels的相同的个数进行排序,把labels相同的个数多的样本,排在前面
:param labels_id:
:param labels:
:return:
'''
# labels_id=list(cluster_container.keys())
# labels=list(cluster_container.values())
labels_id, labels=zip(*cluster_container)
labels_id=list(labels_id)
labels=list(labels)
# 求每个labels的样本个数value_counts_dict
value_counts_dict = {}
labels_set = set(labels)
for i in labels_set:
value_counts_dict[i] = labels.count(i)
def cmp(a, b):
# 降序
a_key, a_value = a
b_key, b_value = b
a_count = value_counts_dict[a_value]
b_count = value_counts_dict[b_value]
if a_count > b_count: # 个数多的放在前面
return -1
elif (a_count == b_count) and (a_value > b_value): # 当个数相同时,则value大的放在前面
return -1
else:
return 1
out = sorted(cluster_container, key=functools.cmp_to_key(cmp))
return out
if __name__=='__main__':
labels_id=["image0",'image1',"image2","image3","image4","image5","image6"]
labels=[0.0,1.0,2.0,1.0,1.0,2.0,3.0]
# labels=['L0','L1','L2','L1','L1','L2',"L3"]
cluster_container=list(zip(labels_id, labels))
print("cluster_container:{}".format(cluster_container))
print_cluster_container("排序前:\n",cluster_container, columns=['labels_id', 'labels'])
out=sort_cluster_container(cluster_container)
print_cluster_container("排序后:\n",out, columns=['labels_id', 'labels'])
结果:
3.8 加载yml配置文件
假设config.yml的配置文件如下:
## Basic config
batch_size: 2
learning_rate: 0.001
epoch: 1000## reset image size
height: 128
width: 128
利用Python可以如下加载数据:
import yaml
class Dict2Obj:
'''
dict转类对象
'''
def __init__(self, bokeyuan):
self.__dict__.update(bokeyuan)
def load_config_file(file):
with open(file, 'r') as f:
data_dict = yaml.load(f,Loader=yaml.FullLoader)
data_dict = Dict2Obj(data_dict)
return data_dict
if __name__=="__main__":
config_file='../config/config.yml'
para=load_config_file(config_file)
print("batch_size:{}".format(para.batch_size))
print("learning_rate:{}".format(para.learning_rate))
print("epoch:{}".format(para.epoch))
运行输出结果:
batch_size:2
learning_rate:0.001
epoch:1000
3.9 移动、复制、重命名文件
# -*- coding: utf-8 -*-
#!/usr/bin/python
#test_copyfile.py
import os,shutil
def rename(image_list):
for name in image_list:
cut_len=len('_cropped.jpg')
newName = name[:-cut_len]+'.jpg'
print(name)
print(newName)
os.rename(name, newName)
def mymovefile(srcfile,dstfile):
if not os.path.isfile(srcfile):
print "%s not exist!"%(srcfile)
else:
fpath,fname=os.path.split(dstfile) #分离文件名和路径
if not os.path.exists(fpath):
os.makedirs(fpath) #创建路径
shutil.move(srcfile,dstfile) #移动文件
print "move %s -> %s"%( srcfile,dstfile)
def mycopyfile(srcfile,dstfile):
if not os.path.isfile(srcfile):
print "%s not exist!"%(srcfile)
else:
fpath,fname=os.path.split(dstfile) #分离文件名和路径
if not os.path.exists(fpath):
os.makedirs(fpath) #创建路径
shutil.copyfile(srcfile,dstfile) #复制文件
print "copy %s -> %s"%( srcfile,dstfile)
srcfile='/Users/xxx/git/project1/test.sh'
dstfile='/Users/xxx/tmp/tmp/1/test.sh'
mymovefile(srcfile,dstfile)
3.10 产生batch_size的数据
def get_batch(image_list, batch_size):
nums = len(image_list)
# batch_num = math.ceil(sample_num / batch_size)
batch_num = (nums + batch_size - 1) // batch_size
for i in range(batch_num):
start = i * batch_size
end = min((i + 1) * batch_size, nums)
batch_image = image_list[start:end]
print("batch_image:{}".format(batch_image))
if __name__ == "__main__":
nums = 20
batch_size = 25
image_list = []
for i in range(nums): image_list.append(str(i + 1) + ".jpg")
get_batch(image_list, batch_size)