说明:列表所有维度的深度必须是一样的。
例如:[[1,2,3],[2,3]]
是合法的,但[[1,2,3],2]
是不合法的。
def cal_len(seqs):
# 计算嵌套列表个维度的长度。例如 cal_len([[1,2,3],[1,2]])->[3,2]
if not isinstance(seqs[0], list):
return len(seqs)
else:
return [cal_len(lst) for lst in seqs]
def cal_max(seqs):
# 计算列表每个维度的最大值,例如 cal_max([[1,2,3],[1,2]])->[3,2]
if isinstance(seqs, int):
return seqs
if isinstance(seqs[0], list):
return cal_max([cal_max(lst) for lst in seqs])
else:
return max(seqs)
def cal_dims(seqs):
# 计算每个维度长度最大值,例如 cal_dims([[1,2,3],[1,2]])->[2,3]
size = []
tmp_seqs = deepcopy(seqs)
while isinstance(tmp_seqs, list):
tmp_seqs = cal_len(tmp_seqs)
size.append(cal_max(tmp_seqs))
return size[::-1]
def size_match(mat, size):
# 判断是不列表是不是已经转化成size大小
# 例如:size_match([[1,2,3],[1,2]], [2,3])->False
# size_match([[1,2,3],[1,2,3]], [2,3])->False
# size_match([[1,2,3],[1,2,3]], [3,3])->False
assert isinstance(mat, list)
if isinstance(mat[0], list):
return (len(mat) == size[0]) & all([size_match(m, size[1:]) for m in mat])
else:
return len(mat) == size[0]
def pad_seq(seqs, size):
# 将list 填充成 size大小,用0填充
# 例如: pad_seq([[1,2,3],[1,2]],[2,3])->[[1,2,3],[1,2,0]]
zero_seq = np.zeros([size[0] - len(seqs)] + size[1:], dtype=np.int).tolist()
seqs.extend(zero_seq)
if isinstance(seqs[0], list):
for seq in seqs:
pad_seq(seq, size[1:])
if __name__=="__main":
seqs = [[[1,2,3],[4,5,6]],[[6,7],[8]]]
size = cal_dims(seqs)
pad_seq(seqs, size)
pad_seq(seqs, [3,3,3])
>>> [[[1, 2, 3], [4, 5, 6]], [[6, 7, 0], [8, 0, 0]]]
>>> [[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
[[6, 7, 0], [8, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]]]