SAM fails on large images with RuntimeError: nonzero is not supported for tensors with more than INT_MAX elements
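Cause
The limit is in torch's nonzero() function, not in PyTorch's integer support in general: nonzero() rejects any tensor with more than INT_MAX (2**31 - 1) elements. When SAM processes a large image, the diff tensor passed to nonzero() in amg.py has more elements than that.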
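To get a feel for when the limit is hit, here is a rough size estimate. It assumes diff has one row per mask and height * width - 1 columns (the shape you get from comparing adjacent pixels of flattened masks); the function names and the example sizes are made up for illustration.

```python
INT_MAX = 2**31 - 1  # 2147483647, the element limit named in the error message


def diff_element_count(num_masks: int, height: int, width: int) -> int:
    """Element count of a (num_masks, height * width - 1) tensor (assumed shape of diff)."""
    return num_masks * (height * width - 1)


def exceeds_nonzero_limit(num_masks: int, height: int, width: int) -> bool:
    """True if a tensor of that size has more than INT_MAX elements."""
    return diff_element_count(num_masks, height, width) > INT_MAX


# Hypothetical sizes, chosen only for illustration.
print(exceeds_nonzero_limit(64, 1024, 1024))    # 64 * 1048575 = 67108800 -> False
print(exceeds_nonzero_limit(64, 10000, 10000))  # 64 * 99999999 = 6399999936 -> True
```

So a batch of 64 masks is fine at 1024 x 1024 but roughly three times over the limit at 10000 x 10000.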
Solution
Open amg.py, find the line change_indices = diff.nonzero(), comment it out, and insert the following code right after it:
# torch's nonzero() only works on tensors with at most INT_MAX elements,
# so first check whether diff is larger than that.
b, w_h = diff.shape
total_elements = b * w_h
# Maximum number of elements per nonzero() call (the function uses 32-bit indexing)
max_elements_per_chunk = 2**31 - 1
if total_elements < max_elements_per_chunk:
    # Small enough: find the change indices in a single torch call.
    change_indices = diff.nonzero()
else:
    # Rows per chunk, rounded DOWN so that chunk_size * w_h never exceeds the limit.
    # (Dividing b evenly by the number of chunks and rounding up can overshoot.)
    # This assumes a single row of diff fits within the limit.
    chunk_size = max(1, max_elements_per_chunk // w_h)
    # Number of chunks needed to cover all b rows
    num_chunks = (b + chunk_size - 1) // chunk_size
    # List to store the results from each chunk
    all_indices = []
    # Loop through the diff tensor in chunks of rows
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, b)
        chunk = diff[start:end, :]
        # Get the non-zero indices for the current chunk
        indices = chunk.nonzero()
        # Shift the row indices back to their position in the original tensor
        indices[:, 0] += start
        all_indices.append(indices)
    # Concatenate the results from all chunks
    change_indices = torch.cat(all_indices, dim=0)
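The part of this workaround that is easiest to get wrong is the chunk size: every chunk, not just the average one, has to stay within the element limit. The sketch below shows one way to pick the row ranges in plain Python, with no torch dependency. plan_row_chunks is a hypothetical helper written for this note and is not part of SAM or PyTorch.

```python
INT_MAX = 2**31 - 1


def plan_row_chunks(num_rows: int, row_len: int, max_elements: int = INT_MAX):
    """Return (start, end) row ranges such that (end - start) * row_len <= max_elements."""
    if row_len > max_elements:
        # Splitting by rows cannot help if one row alone is too large.
        raise ValueError("a single row already exceeds the element limit")
    # Round down so that a full chunk never exceeds the limit.
    rows_per_chunk = max(1, max_elements // max(row_len, 1))
    return [
        (start, min(start + rows_per_chunk, num_rows))
        for start in range(0, num_rows, rows_per_chunk)
    ]


# Tiny example: 5 rows of 3 elements, at most 7 elements per chunk.
print(plan_row_chunks(5, 3, max_elements=7))  # [(0, 2), (2, 4), (4, 5)]

# Hypothetical large case: 64 rows of 10000 * 10000 - 1 elements each.
chunks = plan_row_chunks(64, 10000 * 10000 - 1)
print(len(chunks), chunks[0], chunks[-1])  # 4 (0, 21) (63, 64)
```

Each (start, end) pair corresponds to one diff[start:end, :] slice. In the large case, 21 rows per chunk is the most that fits, since 22 rows would already be 2199999978 elements.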