【小论文代码】TensorFlow与pytorch的转换工作

文章详细列举了将TensorFlow代码转换为PyTorch代码的过程中涉及的关键操作和等效函数,包括张量初始化、操作转换如one_hot、matmul、slice、transpose等,以及动态分区的实现。作者在转换过程中遇到了一些挑战,如某些TensorFlow功能在PyTorch中的对应实现。
摘要由CSDN通过智能技术生成

因为baseline是很新的文章,目前paper with code只有作者一份代码,还是TensorFlow版本的,本来打算就这么用下去,可以无奈修改模型的过程实在太痛苦,我想用的很多新的模型TensorFlow版的代码都有一些问题,无奈硬着头皮把TensorFlow改为pytorch吧。

TensorFlowpytorch
tf.placeholder("float", [None, num_inputs])删掉
tf.Variable()torch.tensor(... ,requires_grad=True)可学习的参数在外面再套一层nn.Parameter(,requires_grad=True)
tf.random_normal([])np.random.normal(size=[])种子pytorch要np.random.seed(seed)
tf.one_hottorch.one_hottorch的功能相对较少,我需要转换的代码是tf.one_hot([1]*hidden, depth=num_inputs, on_value=0.0, off_value=1.0, axis=-1))得到某一列全零,其余列全1的矩阵,然后进行后续的mask,我改为torch.tensor(np.insert(np.zeros([25,40]),1,[1]*hidden,0)).bool(),一行全true,其余false,然后用masked_fill将true的行置为0
tf.matmultorch.matmul
tf.slice(x, [a, b], [c,d])x[a:c,b:d]
tf.linalg.trace(x)torch.trace(x)测试的两维数组是等价的
tf.cast(d, tf.float32)torch.Tensor()或者把非float32的Tensor用x.float()torch.Tensor(a),a如果是一个数字,会返回一个包含a个元素的Tensor
tf.reduce_sum(x,axis=1,keepdims=True)torch.sum(x,dim=1,keepdim=True)
tf.square()torch.square()
tf.concat(x,axis=1)torch.cat(x, dim=1)
tf.norm(x,axis=1)默认x的2范数x.norm(p=2,dim=1)修改的代码用的tf.norm默认的范数,所以这里torch用的2范数进行对应
tf.transpose(x)x.permute(1,0)默认情况下,tf.transpose会在2D输入张量上执行常规矩阵转置,因为修改的代码为2维,无参数,因此转换为x.permute(1,0)
tf.dynamic_partitionpytorch中没有自己实现,见下面代码data与partitions都是Tensor,num_partitions为数字
def dynamic_partition(data, partitions, num_partitions):
    # Create a list of indices for each partition
    indices = [torch.nonzero(partitions == i)[:, 0] for i in range(num_partitions)]
    # Split the data tensor into a list of tensors based on the indices
    partitions = [torch.index_select(data, dim=0, index=index) for index in indices]
    return partitions

输入的值,返回的值都与TensorFlow的dynamic_partition函数相同

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>