上面的例子使用numpy.where()获取二维数组中满足条件的元素的索引,
这些位置以tuple的形式返回:
tuple中的第i个元素(一个array)是所有满足条件的元素在第i个axis上的坐标,
numpy中的坐标是先行后列的,即第一个array为行index,第二个array为列index
推广到更高维数组的情况:
可以先看我另一篇文章讲高维数组读法的↓↓↓
python numpy高维数组(三维数组) reshape操作+order详解+numpy高维数组的读法详解_プロノCodeSteel-CSDN博客
以shape为(10,9,8,7)的四维数组为例:
使用numpy.where()按条件搜索时,会返回长度为4的tuple(每个axis对应一个坐标array)
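如果想要定位的是高维数组中的低维子数组:
比如一张BGR格式的图片,
shape为(410,820,3),设变量为 image,
需要定位每一个值为[255,255,255]的像素(即最后一个axis上长度为3的子数组),
则可以使用一个shape为(410,820)、与image前两个axis对齐的数组 loc,由 loc==0 这样的条件得到布尔掩码,再用该掩码索引image,即可取出所有被选中位置上的低维数组(结果shape为(N,3))。例如要选出所有白色像素,可以先用 mask = np.all(image == 255, axis=2) 得到掩码,再用 image[mask]
代码:(这种布尔索引写法等价于 image[np.nonzero(loc==0)],也就是隐式地使用了单参数 numpy.where(loc==0) 返回的索引tuple)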
needed = image[loc==0]
上面的例子不够直观,下面再给一个传统计算机视觉(Computer Vision)中的例子:使用连通域分析(CCA, Connected-Component Analysis)识别物体,并用不同颜色标记各个目标
其中用到这种索引写法的是这一行代码:
labeled_img[label_hue == 0] = 0
例子:
import sys
import cv2
from matplotlib import pyplot as plt
import numpy as np
import copy
from Threshold_Based_Segmentation1 import calc_hist
show_img = False
data_dir = './data/2/'
# Read the input as a single-channel (grayscale) image.
src = cv2.imread(str(data_dir + 'birds.jpg'), cv2.IMREAD_GRAYSCALE)
if src is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; fail here with a clear error rather than deep inside calc_hist.
    raise FileNotFoundError(data_dir + 'birds.jpg')
############################################## 1 get histogram #########################################################
calc_hist(data_dir, 'bird_', src)
print(src.shape)
################################################ 2 Threshold ###########################################################
# Otsu picks the threshold automatically. THRESH_BINARY_INV maps pixels at or
# below the threshold to 255, so dark regions become the (white) foreground.
ret, thresh = cv2.threshold(src, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)  # black birds
cv2.imwrite(str(data_dir + 'birds_otsu_thres.jpg'), thresh)
# print(ret)  # 171
# Use a self-defined threshold to segment. This overwrites the Otsu result, so
# everything below works on the fixed-threshold (75) image.
_, thresh = cv2.threshold(src, 75, 255, cv2.THRESH_BINARY_INV)  # white birds
cv2.imwrite(str(data_dir + 'birds_t75_thres.jpg'), thresh)
######################################### 3 Binary Morphology: Opening ###############################################
# OpenCV: build a 3x3 rectangular structuring element (SE).
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
# Erode the image.
# NOTE: erode/dilate act on the non-zero (white) pixels. If the target were
# black (value 0), the two operations would appear reversed.
eroded = cv2.erode(thresh, kernel)
if show_img:
    # Show the eroded image.
    cv2.imshow("Eroded Image", eroded)
cv2.imwrite(str(data_dir + 'birds_t75_thres_eroded.jpg'), eroded)
# Dilate the image.
dilated = cv2.dilate(thresh, kernel)
if show_img:
    # Show the dilated image.
    cv2.imshow("Dilated Image", dilated)
cv2.imwrite(str(data_dir + 'birds_t75_thres_dilated.jpg'), dilated)
# Opening = erode then dilate (2 iterations each); disconnects thinly joined regions.
eroded = cv2.erode(thresh, kernel, iterations=2)
dilated = cv2.dilate(eroded, kernel, iterations=2)
if show_img:
    cv2.imshow("opening Image", dilated)
cv2.imwrite(str(data_dir + 'birds_t75_thres_opening2times.jpg'), dilated)
# NOTE(review): the opened image (`dilated`) is only saved here. The labelling
# step that follows consumes `thresh`, not the opened image. Confirm this is intended.
if show_img:
    cv2.waitKey(0)
    cv2.destroyAllWindows()
########################################## 4 Connected-component labeling ###############################################
def cca(thresh, connectivity, background=0):
    """Label connected regions of a 2-D image (two-pass algorithm with union-find).

    Modelled on skimage.measure.label:
    https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.label

    Args:
        thresh: 2-D array (e.g. a binary threshold image). Every pixel whose
            value differs from ``background`` is foreground.
        connectivity: 4 or 8. Selects the neighbourhood used to connect pixels.
        background: pixel value treated as background (default 0).

    Returns:
        (labels, num):
        ``labels`` is an int32 array of the same shape as ``thresh``. Background
        is 0, and each connected region gets one value in 1..num, numbered in
        raster order of the region's first pixel.
        ``num`` is the number of regions, i.e. the maximum label. Background is
        not counted.

    Raises:
        ValueError: if ``connectivity`` is not 4 or 8, or ``thresh`` is not 2-D.
    """
    if connectivity not in (4, 8):
        raise ValueError("connectivity must be 4 or 8, got %r" % (connectivity,))
    foreground = np.asarray(thresh) != background
    if foreground.ndim != 2:
        raise ValueError("thresh must be a 2-D array")
    rows, cols = foreground.shape
    # int32 so that images with more than 255 provisional labels do not wrap
    # around (the input is typically uint8).
    labels = np.zeros((rows, cols), dtype=np.int32)

    # Union-find forest over provisional labels. Index 0 is a background
    # placeholder. Each root is the smallest label of its set.
    parent = [0]

    def find(label):
        # Path halving keeps the trees shallow.
        while parent[label] != label:
            parent[label] = parent[parent[label]]
            label = parent[label]
        return label

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a < root_b:
            parent[root_b] = root_a
        elif root_b < root_a:
            parent[root_a] = root_b

    # Only neighbours that were already visited in raster order are inspected.
    if connectivity == 8:
        offsets = ((-1, -1), (-1, 0), (-1, 1), (0, -1))
    else:
        offsets = ((-1, 0), (0, -1))

    # First pass: give each foreground pixel the smallest neighbouring label,
    # or a new one, and record equivalences between touching labels.
    for y in range(rows):
        for x in range(cols):
            if not foreground[y, x]:
                continue
            neighbours = []
            for dy, dx in offsets:
                ny, nx = y + dy, x + dx
                if ny >= 0 and 0 <= nx < cols:
                    neighbour = int(labels[ny, nx])
                    if neighbour:
                        neighbours.append(neighbour)
            if not neighbours:
                new_label = len(parent)
                parent.append(new_label)
                labels[y, x] = new_label
            else:
                smallest = min(neighbours)
                labels[y, x] = smallest
                for neighbour in neighbours:
                    union(smallest, neighbour)

    # Second pass: map each provisional label to a consecutive final label.
    # The whole image is rewritten in one vectorised lookup, so labels of
    # different groups can never collide.
    lookup = np.zeros(len(parent), dtype=np.int32)
    num = 0
    for provisional in range(1, len(parent)):
        root = find(provisional)
        if root == provisional:
            num += 1
            lookup[provisional] = num
        else:
            # root < provisional, so its final label is already assigned.
            lookup[provisional] = lookup[root]
    return lookup[labels], num
# You need to choose 4 or 8 for connectivity type
use_self_func = True
# use_self_func = False
connectivity = 8
if use_self_func:
    labels, num_labels = cca(thresh, connectivity)
    # cca() counts only the foreground components (background is excluded).
    print(num_labels)  #
else:
    # Perform the operation.
    # connectivity/ltype must be passed by keyword: the 2nd and 3rd positional
    # parameters of connectedComponentsWithStats are the output arrays
    # `labels` and `stats`. Passed positionally, `connectivity` was silently
    # ignored and the default (8) was always used.
    output = cv2.connectedComponentsWithStats(thresh, connectivity=connectivity, ltype=cv2.CV_32S)
    # Get the results.
    # The first cell is the number of labels. OpenCV counts the background
    # label 0 too, so this is one more than cca() reports.
    num_labels = output[0]
    print(num_labels)  # 49
    # Label matrix: the same spatial dimensions as our input thresh.
    labels = output[1]
    # The third cell is the stat matrix.
    stats = output[2]
    # The fourth cell is the centroid matrix.
    centroids = output[3]
print(labels.shape)  # (492, 800)
print(type(labels))
################################################ 5 coloring ############################################################
# Read the original picture again, this time as 3-channel BGR.
src_color = cv2.imread(str(data_dir + 'birds.jpg'), cv2.IMREAD_COLOR)
if src_color is None:
    # cv2.imread returns None (it does not raise) when the file cannot be read.
    raise FileNotFoundError(data_dir + 'birds.jpg')
print(src_color[0, 0, 1])
# Create the colour series: one entry per value in [0, num_labels).
# NOTE(review): label_arr is not used in the visible code; confirm before removing.
label_arr = np.arange(0, num_labels, 1)
def imshow_components(labels):
    """Render a label image as a BGR picture with one hue per component.

    Args:
        labels: 2-D array of non-negative integer labels. 0 is background;
            positive values are components (as returned by ``cca`` or by
            ``cv2.connectedComponentsWithStats``).

    Returns:
        uint8 BGR image with the same height/width as ``labels``. Background
        pixels are black; each component gets a hue from OpenCV's 0..179 range.
    """
    max_label = np.max(labels)
    if max_label > 0:
        # Map component labels linearly onto the hue range.
        label_hue = np.uint8(179.0 * labels / max_label)
        # With more than 179 labels, truncation maps small labels to hue 0. They
        # would then be blacked out as background below, so lift every
        # foreground hue to at least 1.
        label_hue[(labels > 0) & (label_hue == 0)] = 1
    else:
        # No components at all: avoid 0/0 (NaN); everything is background.
        label_hue = np.zeros(np.shape(labels), dtype=np.uint8)
    blank_ch = 255 * np.ones_like(label_hue)
    labeled_img = cv2.merge([label_hue, blank_ch, blank_ch])
    # Convert HSV to BGR for display.
    labeled_img = cv2.cvtColor(labeled_img, cv2.COLOR_HSV2BGR)
    # Set the background label to black. Hue 0 now occurs only for label 0.
    labeled_img[label_hue == 0] = 0
    # cv2.imshow('labeled.png', labeled_img)
    # cv2.waitKey()
    return labeled_img
labeled_img = imshow_components(labels)
cv2.imwrite(str(data_dir + 'birds_t75_thres_colored_blackBG.jpg'), labeled_img)
# Colour the birds: copy every non-black pixel of the label rendering onto the
# colour photo in one boolean-mask assignment instead of a per-pixel Python loop.
# A pixel counts as coloured when ANY of its B/G/R channels is non-zero. A
# single channel is not enough, e.g. pure blue has G == R == 0.
colored_mask = labeled_img.any(axis=2)
src_color[colored_mask] = labeled_img[colored_mask]
cv2.imwrite(str(data_dir + 'birds_t75_thres_colored.jpg'), src_color)
cv2.imshow('colored.png', src_color)
cv2.waitKey()
cv2.destroyAllWindows()