1. How does hash encoding reduce the number of encoding parameters, compared to a dense grid?
Instant-NGP uses multi-resolution feature grids, and each level's grid corresponds to a contiguous region of a hash table (the underlying structure is a single array); a sketch of the per-level resolutions follows the list below.
- coarse resolutions: the mapping is 1:1, each grid vertex has its own array entry, so there are no collisions.
- fine resolutions: a hash function maps the spatial coordinate to an array index, so collisions occur and multiple grid vertices point to the same array entry. Because of these collisions, during parameter updates an entry automatically tends to keep the information of whichever colliding vertex carries the most important detail at that scale (vertices receiving larger gradients dominate the shared entry).
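As a sketch of how the levels are spaced, the Instant-NGP paper picks per-level resolutions on a geometric progression between a minimum and maximum resolution (the variable names here are mine):

```python
import math

# N_l = floor(N_min * b**l), with growth factor b chosen so the last level hits N_max.
num_levels, min_res, max_res = 16, 16, 1024
b = math.exp((math.log(max_res) - math.log(min_res)) / (num_levels - 1))
resolutions = [math.floor(min_res * b**level) for level in range(num_levels)]
print(resolutions)  # coarse to fine: 16, 21, 27, ..., up to max_res
```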
In nerfstudio, all levels share one flat table, initialized uniformly in [-1, 1]:

```python
# One contiguous block of hash_table_size entries per level, features_per_level features each.
self.hash_table = torch.rand(size=(self.hash_table_size * num_levels, features_per_level)) * 2 - 1
```
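To make the saving from question 1 concrete, a back-of-the-envelope sketch with the paper's default sizes (the specific numbers are mine):

```python
# Per-level parameter counts, dense grid vs. hashed grid.
features_per_level = 2
hash_table_size = 2 ** 19  # T = 2^19, the default log2_hashmap_size in Instant-NGP

fine_res = 1024
dense_params = fine_res ** 3 * features_per_level  # ~2.1e9 floats for one dense level
hashed_params = min(fine_res ** 3, hash_table_size) * features_per_level  # ~1.0e6 floats

coarse_res = 16
assert coarse_res ** 3 <= hash_table_size  # 4096 <= 524288: 1:1 mapping, no collisions
```

The hashed table caps every level at hash_table_size entries, so fine levels share entries (collide) instead of growing cubically.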
2. The hash function used
Let's look at nerfstudio's hash encoding implementation.
pytorch_fwd first takes, at every resolution, the coordinates of the eight corners of the grid cell containing the point and passes them through the hash function to get their array-entry indices, hashed_0 through hashed_7, each of shape [..., num_levels], where num_levels is the number of resolutions.
It then looks up the features of the eight corners, f_0 through f_7, of shape [..., num_levels, features_per_level].
Finally it trilinearly interpolates the corner features and concatenates the levels to obtain the final encoding.
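The interpolation implements the standard trilinear weights. Writing the fractional offset as $o = \text{scaled} - \lfloor \text{scaled} \rfloor$ and indexing each corner by $c \in \{0,1\}^3$ (1 meaning ceil, 0 meaning floor), each level computes (my notation, not from the source):

$$
f = \sum_{c \in \{0,1\}^3} \Bigg( \prod_{i=1}^{3} o_i^{c_i} (1 - o_i)^{1 - c_i} \Bigg) f_c
$$

which the code factors into three successive lerps along x, then y, then z. The full pytorch_fwd from nerfstudio: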
```python
def pytorch_fwd(self, in_tensor: TensorType["bs":..., "input_dim"]) -> TensorType["bs":..., "output_dim"]:
    """Forward pass using pytorch. Significantly slower than TCNN implementation."""
    assert in_tensor.shape[-1] == 3
    in_tensor = in_tensor[..., None, :]  # [..., 1, 3]
    scaled = in_tensor * self.scalings.view(-1, 1).to(in_tensor.device)  # [..., num_levels, 3]
    scaled_c = torch.ceil(scaled).type(torch.int32)
    scaled_f = torch.floor(scaled).type(torch.int32)
    offset = scaled - scaled_f  # fractional position inside the cell

    # Hash the eight corners of the grid cell; c/f mark ceil/floor along each axis.
    hashed_0 = self.hash_fn(scaled_c)  # [..., num_levels]  corner (c, c, c)
    hashed_1 = self.hash_fn(torch.cat([scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1))  # (c, f, c)
    hashed_2 = self.hash_fn(torch.cat([scaled_f[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1))  # (f, f, c)
    hashed_3 = self.hash_fn(torch.cat([scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_c[..., 2:3]], dim=-1))  # (f, c, c)
    hashed_4 = self.hash_fn(torch.cat([scaled_c[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1))  # (c, c, f)
    hashed_5 = self.hash_fn(torch.cat([scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_f[..., 2:3]], dim=-1))  # (c, f, f)
    hashed_6 = self.hash_fn(scaled_f)  # (f, f, f)
    hashed_7 = self.hash_fn(torch.cat([scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1))  # (f, c, f)

    # Gather the corner features from the flat hash table.
    f_0 = self.hash_table[hashed_0]  # [..., num_levels, features_per_level]
    f_1 = self.hash_table[hashed_1]
    f_2 = self.hash_table[hashed_2]
    f_3 = self.hash_table[hashed_3]
    f_4 = self.hash_table[hashed_4]
    f_5 = self.hash_table[hashed_5]
    f_6 = self.hash_table[hashed_6]
    f_7 = self.hash_table[hashed_7]

    # Trilinear interpolation: lerp along x, then y, then z.
    f_03 = f_0 * offset[..., 0:1] + f_3 * (1 - offset[..., 0:1])
    f_12 = f_1 * offset[..., 0:1] + f_2 * (1 - offset[..., 0:1])
    f_56 = f_5 * offset[..., 0:1] + f_6 * (1 - offset[..., 0:1])
    f_47 = f_4 * offset[..., 0:1] + f_7 * (1 - offset[..., 0:1])

    f0312 = f_03 * offset[..., 1:2] + f_12 * (1 - offset[..., 1:2])
    f4756 = f_47 * offset[..., 1:2] + f_56 * (1 - offset[..., 1:2])

    encoded_value = f0312 * offset[..., 2:3] + f4756 * (1 - offset[..., 2:3])  # [..., num_levels, features_per_level]

    # Concatenate the levels into the final encoding.
    return torch.flatten(encoded_value, start_dim=-2, end_dim=-1)  # [..., num_levels * features_per_level]
```
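For completeness, a hypothetical usage sketch; the HashEncoding constructor arguments are assumptions based on the nerfstudio version these snippets come from and may differ in other releases:

```python
import torch
from nerfstudio.field_components.encodings import HashEncoding

# Assumed constructor signature; check your installed nerfstudio version.
encoding = HashEncoding(
    num_levels=16,
    min_res=16,
    max_res=1024,
    log2_hashmap_size=19,
    features_per_level=2,
    implementation="torch",  # take the pytorch_fwd path instead of TCNN
)
positions = torch.rand(4096, 3)  # assumed already normalized to [0, 1]^3
feats = encoding(positions)
print(feats.shape)               # [4096, 16 * 2] = [4096, 32]
```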
The key piece is still the hash function's code.
The hash function first multiplies each coordinate component by a constant, then XORs the x, y, z results together, then takes the result modulo hash_table_size so the index stays within one table block, and finally adds a per-level offset that shifts each resolution's indices into its own contiguous hash_table_size-sized chunk of the array.
```python
def hash_fn(self, in_tensor: TensorType["bs":..., "num_levels", 3]) -> TensorType["bs":..., "num_levels"]:
    """Returns hash tensor using method described in Instant-NGP

    Args:
        in_tensor: Tensor to be hashed
    """
    # min_val = torch.min(in_tensor)
    # max_val = torch.max(in_tensor)
    # assert min_val >= 0.0
    # assert max_val <= 1.0

    # Multiply each axis by a large prime (pi_1 = 1, pi_2, pi_3 in Instant-NGP)...
    in_tensor = in_tensor * torch.tensor([1, 2654435761, 805459861]).to(in_tensor.device)
    # ...XOR the three axes together...
    x = torch.bitwise_xor(in_tensor[..., 0], in_tensor[..., 1])
    x = torch.bitwise_xor(x, in_tensor[..., 2])
    # ...wrap into [0, hash_table_size)...
    x %= self.hash_table_size
    # ...and shift each level into its own contiguous block of the table.
    x += self.hash_offset.to(x.device)
    return x
```
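As a self-contained illustration, here is the same hash written as a standalone function for a single integer grid vertex. The constants mirror the snippet above; the block layout via level * HASH_TABLE_SIZE is an assumption standing in for self.hash_offset:

```python
import torch

HASH_TABLE_SIZE = 2 ** 19  # assumed table size per level
PRIMES = torch.tensor([1, 2654435761, 805459861])

def spatial_hash(vertex: torch.Tensor, level: int) -> torch.Tensor:
    """vertex: integer grid coordinates [..., 3] -> flat index into the big array."""
    v = vertex.to(torch.int64) * PRIMES
    h = v[..., 0] ^ v[..., 1] ^ v[..., 2]
    h = h % HASH_TABLE_SIZE             # index within one level's block
    return h + level * HASH_TABLE_SIZE  # shift into that level's contiguous block

idx = spatial_hash(torch.tensor([123, 45, 67]), level=3)
print(idx)  # a single index in [3 * 2**19, 4 * 2**19)
```

Collisions are deliberate here: distinct vertices can land on the same index, which is exactly the parameter-sharing mechanism discussed in question 1.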