numpy.where()在高维数组里的使用

上面的例子是使用numpy.where()得到二维数组中符合条件的数据的索引,

位置是以tuple的形式返回的:

tuple里每个元素(array)可以理解为对应axis上的坐标

这里numpy中是先行后列的坐标,行index,列index

推广到更高维数组的情况:

可以先看我另一篇文章讲高维数组读法的↓↓↓

python numpy高维数组(三维数组) reshape操作+order详解+numpy高维数组的读法详解_プロノCodeSteel-CSDN博客

以右侧shape的数组为例 (10,9,8,7)

使用numpy.where()按条件搜索单个值则会返回length为4的tuple

如果想要定位是高维数组里的低维数组:

比如是一张BGR格式的图片

shape: (410,820,3)    设变量为 image

需要定位每一个[255,255,255]的数组

则可以使用一个shape: (410,820)的有值的数组 设为 loc,对其使用numpy.where(),用返回的tuple选取需要的低维数组

代码:(这种写法应该是隐式调用了numpy.where())

needed = image[loc==0]

发现上面这个例子不够清楚,给个例子:属于传统ComputerVision的,一个使用CCA识别物体并使用不同颜色标记目标的代码

用法这一行代码是

labeled_img[label_hue == 0] = 0

例子:

import sys
import cv2
from matplotlib import pyplot as plt
import numpy as np
import copy
from Threshold_Based_Segmentation1 import calc_hist

show_img = False
data_dir = './data/2/'

src = cv2.imread(str(data_dir + 'birds.jpg'), cv2.IMREAD_GRAYSCALE)

############################################## 1 get histogram #########################################################
calc_hist(data_dir, 'bird_', src)
print(src.shape)

################################################ 2 Threshold ###########################################################
ret, thresh = cv2.threshold(src, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)  # black birds
cv2.imwrite(str(data_dir + 'birds_otsu_thres.jpg'), thresh)
# print(ret)  # 171
# use self defined threshold to segment
_, thresh = cv2.threshold(src, 75, 255, cv2.THRESH_BINARY_INV)  # white birds
cv2.imwrite(str(data_dir + 'birds_t75_thres.jpg'), thresh)

######################################### 3 Binary Morphology: Opening ###############################################
# OpenCV set a SE
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

# Do erosion to img
# !!!特别注意,如果目标是黑色的,那么erode 和 dilate操作将会相反,因为黑色恰好值为0
eroded = cv2.erode(thresh, kernel)
# 显示腐蚀后的图像
if show_img:
    cv2.imshow("Eroded Image", eroded)
cv2.imwrite(str(data_dir + 'birds_t75_thres_eroded.jpg'), eroded)

# 膨胀图像
dilated = cv2.dilate(thresh, kernel)
# 显示膨胀后的图像
if show_img:
    cv2.imshow("Dilated Image", dilated)
cv2.imwrite(str(data_dir + 'birds_t75_thres_dilated.jpg'), dilated)

# opening 操作,disconnected
eroded = cv2.erode(thresh, kernel,iterations=2)
dilated = cv2.dilate(eroded, kernel,iterations=2)
if show_img:
    cv2.imshow("opening Image", dilated)
cv2.imwrite(str(data_dir + 'birds_t75_thres_opening2times.jpg'), dilated)

if show_img:
    cv2.waitKey(0)
    cv2.destroyAllWindows()


########################################## 4 Connected-component labeling ###############################################


def cca(thresh, connectivity, background=0):
    """Label connected regions of an integer array.
    https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.label

    Args:
        thresh: The binary thresh image
        connectivity: 4 or 8, default 8
        background: int, optional (not implemented)
    Returns:
        return labels(Labeled array, where all connected regions are assigned the same integer value.),num(Number of labels, which equals the maximum label index)
    """
    labels = np.zeros_like(thresh)
    # padding the image for ease of scanning
    padded = np.pad(thresh, pad_width=((1, 1), (1, 1)), constant_values=0)
    shape = labels.shape
    num = 1
    equals = {}
    # first run: label new number or lowest in the neighbor
    for y in range(1, shape[0] + 1):
        for x in range(1, shape[1] + 1):
            if connectivity == 8:
                if padded[y + 1, x + 1] == 0:
                    continue
                else:
                    square = labels[y - 1:y + 2, x - 1:x + 2]
                    neighbors = list(set(square.flatten()) - set([0]))
                    if len(neighbors) == 0:
                        labels[y, x] = num
                        equals[num] = [num]
                        num += 1
                    else:
                        labels[y, x] = min(neighbors)
                        if equals.get(min(neighbors)) is None:
                            equals[labels[y, x]] = neighbors
                        else:
                            equals[labels[y, x]] += neighbors
                            # equals[labels[y, x]] = list(set(equals[labels[y, x]]))
            elif connectivity == 4:  # 全是bug按照上面的改
                if padded[y - 1, x - 1] == 0:
                    continue
                else:
                    cross = [labels[y - 2, x - 1], labels[y, x - 1], labels[y - 1, x - 2], labels[y - 1, x]]
                    neighbors = list(set(cross.flatten()) - set([0]))
                    if len(neighbors) == 0:
                        labels[y - 1, x - 1] = num
                        equals[num] = [num]
                        num += 1
                    else:
                        labels[y - 1, x - 1] = min(neighbors)
                        if equals.get(min(neighbors)) is None:
                            equals[labels[y - 1, x - 1]] = neighbors
                        else:
                            equals[labels[y - 1, x - 1]] += neighbors
                            # equals[labels[y - 1, x - 1]] = list(set(equals[labels[y - 1, x - 1]]))

    # second run: replace label with lower number in equations
    # 1 : [2,3,5] The right only records those that are higher than the left
    # 2 : [7]
    # 所有dict转化成list, key 加到 value list里面
    equal_list = []  # 储存所有等价关系
    for key in equals:
        equals[key].append(key)
        equals[key] = list(set(equals[key]))
        equal_list.append(equals[key])
    # 合并有相同项的list,并移除其中一个
    # logic
    # 1. 只要找到有交集的list就合并,并重新开始对于list的循环
    # 2. 对所有list循环完成并未找到有交集的list,停止合并
    # final_equal = copy.deepcopy(equal_list)
    # l = len(equal_list)
    done = False
    outer_break = False
    while True:
        if done:
            break
        outer_break = False
        for i in range(len(equal_list)):
            for j in range(len(equal_list)):
                if i == len(equal_list) - 1 and j == len(equal_list) - 1:
                    done = True
                if i == j:
                    continue
                else:
                    if len(list(set(equal_list[i]) & set(equal_list[j]))) != 0:
                        equal_list[i] += equal_list[j]
                        equal_list[i] = list(set(equal_list[i]))
                        equal_list.pop(j)
                        # 这时list已经发生改变,需要新的大循环
                        outer_break = True
                        break
            if outer_break:
                break
    # 3.对每一个list进行升序排序
    for i in range(len(equal_list)):
        equal_list[i] = sorted(equal_list[i])
    # 4.relabelling
    count = 1
    for item in equal_list:
        # item : [1,2,3,4]
        for i in item:
            labels[labels == i] = count
        count += 1
    return labels, len(equal_list)


# You need to choose 4 or 8 for connectivity type
use_self_func = True
# use_self_func = False
connectivity = 8
if use_self_func:
    labels, num_labels = cca(thresh, connectivity)
    print(num_labels)  #
else:
    # Perform the operation
    output = cv2.connectedComponentsWithStats(thresh, connectivity, cv2.CV_32S)
    # Get the results
    # The first cell is the number of labels
    num_labels = output[0]
    print(num_labels)  # 49
    # label matrix: the same spatial dimensions as our input thresh
    labels = output[1]
    # The third cell is the stat matrix
    stats = output[2]
    # The fourth cell is the centroid matrix
    centroids = output[3]

print(labels.shape)  # (492, 800)
print(type(labels))
################################################ 5 coloring ############################################################
# read in BGR
src_color = cv2.imread(str(data_dir + 'birds.jpg'), cv2.IMREAD_COLOR)
print(src_color[0, 0, 1])

# creat color series
label_arr = np.arange(0, num_labels, 1)


def imshow_components(labels):
    # Map component labels to hue val
    label_hue = np.uint8(179.0 * labels / np.max(labels))
    blank_ch = 255 * np.ones_like(label_hue)
    labeled_img = cv2.merge([label_hue, blank_ch, blank_ch])

    # cvt to BGR for display
    labeled_img = cv2.cvtColor(labeled_img, cv2.COLOR_HSV2BGR)

    # set bg label to black
    labeled_img[label_hue == 0] = 0

    # cv2.imshow('labeled.png', labeled_img)
    # cv2.waitKey()
    return labeled_img


labeled_img = imshow_components(labels)
cv2.imwrite(str(data_dir + 'birds_t75_thres_colored_blackBG.jpg'), labeled_img)

# color the birds
for x in range(labels.shape[0]):
    for y in range(labels.shape[1]):
        # cannot judge by single value of B/G/R
        if np.sum(labeled_img[x, y]) != 0:
            src_color[x, y, :] = labeled_img[x, y, :]

cv2.imwrite(str(data_dir + 'birds_t75_thres_colored.jpg'), src_color)
cv2.imshow('colored.png', src_color)
cv2.waitKey()
cv2.destroyAllWindows()

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值