import haiku as hk
from typing import Union, Sequence
import jax.numpy as jnp
import jax
import numbers
import numpy as np
import pickle
import copy
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
                                            dtype=np.float32)
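# Quick check (requires scipy, not needed by the rest of this script):
#   from scipy.stats import truncnorm
#   truncnorm.std(a=-2, b=2, loc=0., scale=1.)  # -> 0.8796256610342398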
def get_initializer_scale(initializer_name, input_shape):
"""Get Initializer for weights and scale to multiply activations by."""
if initializer_name == 'zeros':
w_init = hk.initializers.Constant(0.0)
else:
# fan-in scaling
scale = 1.
for channel_dim in input_shape:
      # Divide by each input dimension, so scale = 1 / fan_in.
scale /= channel_dim
if initializer_name == 'relu':
scale *= 2
noise_scale = scale
stddev = np.sqrt(noise_scale)
# Adjust stddev for truncation.
stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
    # Truncated normal: samples are restricted to (mean - 2 * stddev, mean + 2 * stddev).
w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
  # The hk.initializers module provides a family of initializers for model
  # parameters. Commonly used ones include:
  # 1. hk.initializers.Constant(value)
  # 2. hk.initializers.RandomNormal(stddev=1.0)
  # 3. hk.initializers.TruncatedNormal(stddev=1.0)
  # 4. hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal')
  # 5. hk.initializers.Orthogonal(scale=1.0)
  # 6. hk.initializers.Identity(gain=1.0)
return w_init
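# Quick illustration (a sketch, not part of the AlphaFold source): for a ReLU
# layer with fan-in 256, scale = 2 / 256, so the adjusted stddev is
# sqrt(2 / 256) / 0.8796... ≈ 0.1005. This should match
# hk.initializers.VarianceScaling(scale=2.0, mode='fan_in',
# distribution='truncated_normal'), which applies the same truncation factor.
w_init_demo = get_initializer_scale('relu', input_shape=(256,))
print(w_init_demo.stddev)  # ≈ 0.1005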
class Linear(hk.Module):
"""Protein folding specific Linear module.
This differs from the standard Haiku Linear in a few ways:
* It supports inputs and outputs of arbitrary rank
* Initializers are specified by strings
"""
def __init__(self,
num_output: Union[int, Sequence[int]],
initializer: str = 'linear',
num_input_dims: int = 1,
use_bias: bool = True,
bias_init: float = 0.,
               precision=None,
name: str = 'linear'):
"""Constructs Linear Module.
Args:
num_output: Number of output channels. Can be tuple when outputting
multiple dimensions.
initializer: What initializer to use, should be one of {'linear', 'relu',
'zeros'}
num_input_dims: Number of dimensions from the end to project.
use_bias: Whether to include trainable bias
bias_init: Value used to initialize bias.
precision: What precision to use for matrix multiplication, defaults
to None.
name: Name of module, used for name scopes.
"""
super().__init__(name=name)
if isinstance(num_output, numbers.Integral):
self.output_shape = (num_output,)
else:
self.output_shape = tuple(num_output)
self.initializer = initializer
self.use_bias = use_bias
self.bias_init = bias_init
self.num_input_dims = num_input_dims
self.num_output_dims = len(self.output_shape)
self.precision = precision
def __call__(self, inputs):
"""Connects Module.
Args:
inputs: Tensor with at least num_input_dims dimensions.
Returns:
output of shape [...] + num_output.
"""
num_input_dims = self.num_input_dims
if self.num_input_dims > 0:
in_shape = inputs.shape[-self.num_input_dims:]
else:
in_shape = ()
    # Note the distribution used to initialize the weights: fan-in scaling
    # keeps the variance of activations roughly constant across layers.
weight_init = get_initializer_scale(self.initializer, in_shape)
in_letters = 'abcde'[:self.num_input_dims]
out_letters = 'hijkl'[:self.num_output_dims]
    # The weight shape is the input shape concatenated with the output shape.
weight_shape = in_shape + self.output_shape
    # hk.get_parameter fetches (or creates) a parameter from the module's
    # parameter dictionary. Its four arguments, in order, are:
    # 1. name: a string that uniquely identifies the parameter.
    # 2. shape: the shape of the parameter.
    # 3. dtype: the data type of the parameter.
    # 4. init: the initializer that sets the initial value; any initializer
    #    from hk.initializers can be used.
weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weight_init)
equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
    # equation is a string such as '...abc, abch->...h' (here for
    # num_input_dims=3, num_output_dims=1).
    # Both jnp.einsum and np.dot can perform matrix multiplication, but
    # jnp.einsum lets you spell out the contraction rule as a string,
    # which supports far more general tensor operations.
output = jnp.einsum(equation, inputs, weights, precision=self.precision)
if self.use_bias:
bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
hk.initializers.Constant(self.bias_init))
output += bias
return output
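# A minimal sketch (names are illustrative) of the arbitrary-rank support:
# project the last two input axes into a 2-D output block.
def _rank_demo(x):
  return Linear((8, 4), num_input_dims=2, name='pair_linear')(x)
_demo = hk.transform(_rank_demo)
_x = jnp.ones((5, 10, 12))                      # batch of 5, in_shape = (10, 12)
_params = _demo.init(jax.random.PRNGKey(0), _x)
print(_params['pair_linear']['weights'].shape)  # (10, 12, 8, 4) = in_shape + output_shape
print(_demo.apply(_params, None, _x).shape)     # (5, 8, 4)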
with open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:
Human_HBB_tensor_dict = pickle.load(f)
#for k, v in Human_HBB_tensor_dict.items():
# print(v.shape)
#Human_HBB_tensor_dict.keys()
batch = copy.deepcopy(Human_HBB_tensor_dict)
print(batch['aatype'].shape)
print(batch['msa_feat'].shape)
#print(batch['msa_feat'])
msa_channel = 16  # kept small for this demo; larger values make the computation slow
input_data = batch['msa_feat'].numpy()
print(type(input_data))
# hk.transform turns the impure module call into a pure (init, apply) pair.
model = hk.transform(lambda x: Linear(msa_channel,
                                      name='preprocess_msa',
                                      num_input_dims=1)(x))
print(model)
rng = jax.random.PRNGKey(42)
# print(rng)
# Initialize the parameters; their shapes depend on the input shape and the model structure.
params = model.init(rng, input_data)
# print(params)
print("params weights shape:")
print(params['preprocess_msa']['weights'].shape)
print("params weights bias:")
print(params['preprocess_msa']['bias'].shape)
output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape)
print("Output Data shape:", output_data.shape)
#print("Output Data:", output_data)