# coding=utf-8
# 螺旋迭代器
import numpy as np
class SpiralIterator:
    """Visit the cells of a 2-D grid in a spiral growing outward from a start cell.

    The walk starts at ``(x, y)`` (the grid centre by default) and moves one cell
    per step: first right (column + 1), then up (row - 1), left (column - 1),
    down (row + 1), widening by one ring each lap. Spiral positions that fall
    outside the grid are skipped, so every successful :meth:`get` returns an
    in-bounds cell until ``length`` cells have been produced.

    Usage::

        iterator = SpiralIterator(grid)
        while iterator.hasNext():
            x, y, data = iterator.get()  # row index, column index, cell value

    or equivalently ``for x, y, data in SpiralIterator(grid): ...``.
    """

    def __init__(self, source, x=None, y=None, length=None):
        """
        :param source: 2-D grid whose shape is read with ``np.shape`` and whose
            cells are read as ``source[i][j]`` (e.g. a numpy array or a list of
            equal-length lists).
        :param x: starting row index; defaults to ``rows // 2``.
        :param y: starting column index; defaults to ``cols // 2``.
        :param length: maximum number of cells to produce; defaults to, and is
            capped at, the total number of cells in ``source``.
        """
        self.source = source
        self.row = np.shape(self.source)[0]
        self.col = np.shape(self.source)[1]
        total = np.size(self.source)
        # Compare against None (not truthiness) so an explicit 0 is honoured.
        self.length = total if length is None else min(length, total)
        self.x = self.row // 2 if x is None else x
        self.y = self.col // 2 if y is None else y
        # Current spiral position (may lie outside the grid).
        self.i = self.x
        self.j = self.y
        # Number of in-bounds cells produced so far.
        self.iteSize = 0

    def hasNext(self):
        """Return True while fewer than ``length`` cells have been produced."""
        return self.iteSize < self.length

    def _advance(self):
        """Move (self.i, self.j) one step along the spiral around (self.x, self.y)."""
        rel_i = self.i - self.x  # position relative to the start cell
        rel_j = self.j - self.y
        if rel_j > 0 and abs(rel_i) < rel_j:
            self.i -= 1  # up
        elif rel_i < 0 and rel_j > rel_i:
            self.j -= 1  # left
        elif rel_j < 0 and abs(rel_j) > rel_i:
            self.i += 1  # down
        else:
            # The three cases above plus this one cover every integer point;
            # what remains is exactly rel_i >= 0 and rel_i >= rel_j.
            self.j += 1  # right

    def get(self):
        """Return the next in-bounds cell as ``(row, col, value)``.

        Spiral positions outside the grid are skipped rather than returned.
        Returns None once ``length`` cells have already been produced.
        """
        # Terminates: the spiral reaches every grid cell exactly once and
        # length never exceeds the number of cells.
        while self.hasNext():
            i, j = self.i, self.j
            self._advance()
            if 0 <= i < self.row and 0 <= j < self.col:
                self.iteSize += 1
                return i, j, self.source[i][j]
        return None

    def __iter__(self):
        return self

    def __next__(self):
        item = self.get()
        if item is None:
            raise StopIteration
        return item
# Usage example: walk a grid in a spiral starting from its centre.
# (In the original context this was SpiralIterator(self.lock) inside a method;
# a standalone demo grid is used here so the module can be imported or run.)
if __name__ == "__main__":
    demo = np.arange(25).reshape(5, 5)
    iterator = SpiralIterator(demo)  # spiral traversal
    while iterator.hasNext():
        x, y, data = iterator.get()  # row index x, column index y, cell value data
        print(x, y, data)