图像主题色提取算法

中位切分法

在RGB彩色空间中R,G,B三基色对应于空间的3个坐标轴,并将每一个坐标轴都量化到0-255。0对应全黑,255对应全白。这样就形成了一个边长256的彩色立方体,所有可能的颜色都对应于立方体上的一个点。
在这里插入图片描述

算法步骤

  1. 将图片转为rgb直方图,空间中可以想象成一个色块,分别为R轴、G轴、B轴
  2. 找到最长的那条轴,使用最长轴进行排序。
  3. 将该色块按照排序后的结果一分为二。
  4. 将第3步骤得到的2个色块继续进行234步骤,直到色块数量达到提取色的个数k。
  5. 得到k个色块,统计每个色块的像素RGB均值,即最终提取色结果。

实现

# k为提取色数
def medianSegmentation(imgPath, k = 256):
	# 读取图片	
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	img = np.array(img)
	# 3维转2维
	img = img.reshape((-1, img.shape[2]))
	# 队列
	que = Queue()
	# 入队
	que.put(img)
	# 切分数量到k之前不停止
	while que.qsize() < k:
		# 取色块
		img = que.get()
		# 找rgb中的最大轴
		selectColor = 0
		selectDifference = 0
		for i in range(3):
			selectMinVal = np.min(img[:, i])
			selectMaxVal = np.max(img[:, i])
			if selectMaxVal - selectMinVal > selectDifference:
				selectColor = i
				selectDifference = selectMaxVal - selectMinVal
		# 使用最大轴进行排序
		sortedIndexs = np.lexsort((img[:, selectColor], ))
		img = img[sortedIndexs, :]
		# 切分
		leftImg = img[:img.shape[0]//2, :]
		rightImg = img[img.shape[0]//2:, :]
		# 入队列
		que.put(leftImg)
		que.put(rightImg)
	# 提取颜色
	colors = []
	while not que.empty():
		# 取色块
		img = que.get()
		# 取像素均值
		colors.append(np.mean(img, axis=0))
	# 显示色块
	showColors = np.zeros(shape=(len(colors) * 20, 200, 3), dtype=np.uint8)
	for i in range(k):
		showColors[i*20: i*20+20, :, :] = colors[i]
	cv2.imshow("", showColors)
	cv2.waitKey(0)

优化

  1. 由于每次切分色块后都进行了一次排序算法,这样性能是非常差的。因此可以对RGB分别进行一次排序后,使用一些方式进行一定的优化。
  2. 内存方面,重复开辟了空间;可以结合第一点一起整合优化。
  3. 有一些文章说:有可能存在某些条件下VBox体积很大但只包含少量像素。解决的方法是,每次进行切分时,并不是对上一次切分得到的所有VBox进行切分,而是通过一个优先级队列进行排序,刚开始时这一队列以VBox仅以VBox所包含的像素数作为优先级考量,当切分次数变多之后,将体积*包含像素数作为优先级。其中VBox是这篇文章中描述的色块。

八叉树

其时间复杂度和空间复杂度都有很大的优势,并且保真度也是非常的高。 此段摘抄图片主题色提取算法,实验得到时间复杂度较大

算法原理

算法描述较为复杂,其原理是将颜色的RGB转为二进制,二进制形式为xxxx xxxx,将RGB的8位二进制形式的每一列进行黏合。例如:

R: 0100 1010
G: 0101 0100
B: 0010 0101

经过黏合后的列表为:

level0: 000		对应index为0, 即在-1层节点的children[0]
level1: 110		对应index为6, 即在0层节点的children[0]->children[6]
level2: 001		对应index为1, 即在0层节点的children[0]->children[6]->children[1]
level3: 010		对应index为2, 即在0层节点的children[0]->children[6]->children[1]->children[2]

level4: 100		对应index为4, 即在0层节点的children[0]->children[6]->children[1]->children[2]->children[4]
level5: 011		对应index为3, 即在0层节点的children[0]->children[6]->children[1]->children[2]->children[4]->children[3]
level6: 100		对应index为4, 即在0层节点的children[0]->children[6]->children[1]->children[2]->children[4]->children[3]->children[4]
level7: 001		对应index为1, 即在0层节点的children[0]->children[6]->children[1]->children[2]->children[4]->children[3]->children[4]->children[1]

具体可以参考后面两篇文章:图片主题色提取算法Octree color quantization

实现

# 常量类
class Const(object):
	MAX_LEVEL = 8


# 颜色类
class Color(object):

	# Color构造函数
	def __init__(self, r, g, b):
		self.r = r
		self.g = g
		self.b = b

	# 颜色相加
	def add(self, color):
		self.r += color.r
		self.g += color.g
		self.b += color.b

	# 颜色除法
	def div(self, k):
		if k == 0:
			raise Exception("error color div zero.")
		return Color(self.r // k, self.g // k, self.b // k)

	# 根据八叉树原理得到是哪个children
	def getIndex(self, level):
		r = "{0:08b}".format(self.r)[level]
		g = "{0:08b}".format(self.g)[level]
		b = "{0:08b}".format(self.b)[level]
		return int(''.join([r, g, b]), 2)

	def __str__(self):
		return "Color({0}, {1}, {2})".format(self.r, self.g, self.b)

	def __repr__(self):
		return str(self)


# 八叉树节点类
class Node(object):

	# 节点构造函数
	def __init__(self, level, parent):
		self.color = Color(0, 0, 0)						# 节点颜色
		self.level = level  							# 节点所属level
		self.children = [None for i in range(8)]		# 节点拥有的children
		self.pixedCount = 0  							# 相同颜色的个数
		if level < Const.MAX_LEVEL - 1:					# 节点level为7, 则不进octree levels链表
			parent.addLevelNode(level + 1, self)		# 由于root节点level为-1, 因此进行+1操作

	# 递归创建一个Color路径到叶节点
	def addColor(self, color, level, parent):
		if level < Const.MAX_LEVEL:										# level小于8
			index = color.getIndex(level)								# level从0到7, 也就是获取每一列的rgb编码进而得到的索引
			if self.children[index] is None:							# 对应index孩子还没创建
				self.children[index] = Node(level, parent)				# 创建孩子节点, 放在index位置
			self.children[index].addColor(color, level + 1, parent)		# 递归level到下一层
		else:															# level到达第8层, 为叶节点
			self.color.add(color)										# 第8层不需要创建Node, 直接累加第7层的color
			self.pixedCount += 1  										# 累加第7层的color数量

	# 获取包括自身及孩子的叶节点
	def leafNodes(self):
		leafNodes = []											# 记录叶节点
		if self.isLeaf():
			leafNodes.append(self)								# append叶节点
		else:
			for node in self.children:							# 遍历子树
				if not node is None:
					leafNodes = leafNodes + node.leafNodes()	# 拿到子树的叶节点
		return leafNodes										# 返回叶节点

	# 是否为叶节点
	def isLeaf(self):
		return self.pixedCount > 0 		# 若Node的PixedCound大于0, 说明是叶节点

	# 合并孩子节点(该操作在外部调用需要从level=7开始, 也就是从叶子节点一直遍历到root节点; 在这个过程中, 节点A的叶子(孩子)节点被合并, 节点A变为了叶节点)
	def reduce(self):
		reduceCount = 0  								# 合并叶节点数量
		for node in self.children:						# 遍历孩子
			if not node is None:
				self.color.add(node.color)				# 将叶节点颜色值累加到父节点上
				self.pixedCount += node.pixedCount		# 将叶节点像素个数累加到父节点上
				reduceCount += 1 						# 合并计数
		self.children = [None for i in range(8)]  		# 将孩子抛弃掉
		return reduceCount - 1  						# 由于自身Node变为叶节点, 因此+1

	# 均值Node上的Color值
	def normalize(self):
		return self.color.div(self.pixedCount)			# 均值Color


# 八叉树算法
class Octree(object):

	# 八叉树构造函数
	def __init__(self):
		self.levels = [[] for i in range(Const.MAX_LEVEL)] 		# 构建levels链表, 用于后续提取color主题色
		self.root = Node(-1, self)								# root节点

	# 添加Node到levels链表中
	def addLevelNode(self, level, node):
		self.levels[level].append(node)

	# 添加颜色到八叉树中
	def addColor(self, color):
		self.root.addColor(color, 0, self)

	# 提取颜色
	def extractColor(self, k = 256):
		leafCount = len(self.root.leafNodes())		# 获取八叉树叶节点数量
		for i in range(Const.MAX_LEVEL, 0, -1):		# 从level为7开始遍历, 合并叶节点
			level = i - 1   						# 由于i从8开始, 因此-1
			if leafCount <= k:						# 如果叶节点已经小于提取个数k, 结束
				break
			if not self.levels[level] is None:		# 链表不为空, 遍历叶节点的父亲层, 因为在这里是要统计父节点的颜色
				for node in self.levels[level]:		# 遍历链表中的node节点
					leafCount -= node.reduce()		# 统计node的颜色
					if leafCount <= k:				# 如果叶节点已经小于提取个数k, 结束
						break
			self.levels[level] = []					# level置空
		# 提取色
		colors = []
		# 获取合并后八叉树的所有叶节点
		for leafNode in self.root.leafNodes():
			if leafNode.isLeaf() and len(colors) <= k:
				# 获取Node颜色的均值
				colors.append(leafNode.normalize())
		# 返回提取色
		return colors


def octreeColor(imgPath):
	# 读取图片
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	img = np.array(img)
	width, height, channel = img.shape
	# 创建八叉树
	octree = Octree()
	# 加入颜色
	for i in range(width):
		for j in range(height):
			octree.addColor(Color(img[i, j, 0], img[i, j, 1], img[i, j, 2]))
	# 提取颜色
	k = 16
	colors = octree.extractColor(k)
	# 显示色块
	showColors = np.zeros(shape=(len(colors) * 20, 200, 3), dtype=np.uint8)
	for i in range(len(colors)):
		c = [colors[i].r, colors[i].g, colors[i].b]
		showColors[i*20: i*20+20, :, :] = np.array(c)
	cv2.imshow("", showColors)
	cv2.waitKey(0)

性能

经过测试:自己的PC平台上为20~30秒左右。

K-means

K-Means算法是无监督的聚类算法。一般会用来做分类

算法步骤

K-Means聚类算法原理,本文只实现了其中的传统K-Means算法。其原理看文章就能懂。

实现

class KMeans0(object):

	def __init__(self, k = 256, it = 8):
		self.colors = None
		self.epcho = 0
		self.k = k
		self.iter = it

	def distance(self, colors, color):
		return np.sum(np.power(colors - color, 2), axis=1)

	def fit(self, img):
		# 长宽高
		width, height, channel = img.shape
		# 3维转2维
		img = img.reshape((width * height, channel))
		# 随机取点
		colorIndexs = np.random.randint(0, high = width * height, size = self.k, dtype = 'l')
		self.colors = np.array(img[colorIndexs], dtype=np.float32)
		# 迭代次数
		self.epcho = 0
		while True:
			# 每类总颜色值
			C = np.array(self.colors, dtype=np.float32)
			# 每类个数
			CCount = np.ones(shape=(self.k, 1), dtype=np.float32)
			# 为每个像素分类
			for i in range(width * height):
				# 像素对应的类
				index = np.argmin(self.distance(self.colors, img[i]))
				# 对应类中加入这个颜色
				C[index] += img[i]
				# 对应类的颜色数量+1
				CCount[index] += 1
			# 记录上次的类聚中心
			oldColors = np.array(self.colors, dtype=np.uint8)
			# 计算新类聚中心
			self.colors = np.array(C // CCount, dtype=np.float32)
			# 迭代次数+1
			self.epcho += 1
			# 中心点没变更, 拟合完成
			if np.sum(oldColors == np.array(self.colors, dtype=np.uint8)) == self.k * 3:
				break
			# epcho上限退出迭代
			if self.epcho > self.iter:
				break

	def extractColor(self):
		return np.array(self.colors, dtype=np.uint8)

def KmeansColor(imgPath):
	startTime = time.time()
	# 读取图片
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	# 转为numpy数组
	img = np.array(img)
	# 创建KMeans
	kmeans = KMeans0(16)
	kmeans.fit(img)
	print("time: {0}\n color: {1}".format(time.time() - startTime, kmeans.extractColor()))

性能

sklearn开源机器学习库进行对比,得到的结果与之相差较小,比sklearn快一点。

项目链接

主题色提取项目

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值