hk.Module must be initialized inside an hk.transform
Workaround: construct the module inside a function wrapped by hk.transform:
def test_output_dimension(self, final_endpoint, expected_shape):
    input_shape = (2, 32, 224, 224, 3)

    def f():
        data = jnp.zeros(input_shape)
        # The module is built inside f, so it exists only under hk.transform.
        net = tsm_resnet.TSMResNetV2()
        return net(data, final_endpoint=final_endpoint)

    init_fn, apply_fn = hk.transform(f)
    # init_fn creates the parameters; apply_fn takes (params, rng).
    out = apply_fn(init_fn(jax.random.PRNGKey(42)), None)
    self.assertEqual(out.shape, expected_shape)
Reference:
https://ask.csdn.net/questions/1221077
class Parameter():
    def __init__(self, init_value: float, name: Tex