python实现动态for循环

背景

传入图片、张量可能存在维度不一的情况,如果想用同一个函数处理,可以用到动态for循环来实现。

实现

用pytorch中的张量tensor来作为数据的容器。

这里就没有用到tensor张量中的shape等参数来返回形状,而是选择直接传入,目的就是方便大家理解动态for循环的建立过程。

实现求和

实现动态for循环求不同维度图片、张量所有元素之和。

import torch


def dynamicFor(data, data_size):
    '''
    动态for循环之实现求和
    :param data: 传入的待求数据
    :param data_size: 数据的形状
    :return: 数据data中各元素之和
    '''
    count = 0
    place = 0
    ndim = len(data_size)  # 算出维度
    sum_num = 1
    sum_list = []
    sumx = 1
    for i in range(ndim):
        sum_num *= data_size[i]  # 计算所要计算的总次数
        if i != ndim - 1:
            sumx *= data_size[-(i + 1)]  # 计算在每个维度下的除数,方便后面用求余数的方法得到具体位置
            sum_list.append(sumx)
    sum_list.reverse()  # 记得取反,方便后面索引
    # print(sum_list)
    while place < sum_num:
        position = []  # 存储位置
        current_place = place
        for i in range(ndim):
            if i != ndim - 1:
                position.append(current_place // sum_list[i])
                # 下面这步很重要,去掉多余的位置记录,防止后面位置的小除数除出来越界的位置
                current_place = current_place - current_place // sum_list[i] * sum_list[i]
            else:
                position.append(current_place % sum_list[i - 1])
        result = torch.Tensor(data)
        for position_num in position:  # 访问具体位置以方便进行你要进行的操作,比如这里是求和
            result = result[position_num]
        count += result  # 访问到了具体位置,进行你要的操作,求和
        place += 1
    return count

检验一下是否正确:

dim3 = torch.ones(400).reshape([5, 8, 10])
print(dynamicFor(dim3, [5, 8, 10]))

dim4 = torch.ones(400).reshape([5, 5, 4, 4])
print(dynamicFor(dim4, [5, 5, 4, 4]))

dim5 = torch.ones(400).reshape([2, 4, 2, 5, 5])
print(dynamicFor(dim5, [2, 4, 2, 5, 5]))

结果如下:

tensor(400.)
tensor(400.)
tensor(400.)

实现求范数

(做这个动态for循环,起因就是我搭建网络时试着手打了一个计算网络权重L2范数的损失函数,想到其中要可能要处理不同维度大小的数据,就思考了一下。)

要搭建求不同维度范数的动态for循环,要注意保留最后两个维度不算入位置索引内,而是用位置索引定位到矩阵的位置,即只计算前面的位置索引。

import torch


def dynamicFor4Norm(data, data_size):
    '''
    动态for循环之求矩阵范数。即最后两个维度中矩阵的范数
    :param data: 待求的数据
    :param data_size: 数据data的形状
    :return: 数据data中,所有最后两个维度的矩阵的范数平方之和
    '''
    count = 0
    place = 0
    ndim = len(data_size)  # 算出总维度
    if ndim < 2:  # 如果维度小于2,那么应该就是偏置,就跳过
        return 0
        # raise Exception("Dimension of input is less than 2!")
    elif ndim == 2:  # 维度为2,直接求
        return torch.tensor(data).norm()
    else:
        sum_num = 1
        sum_list = []
        sumx = 1
        for i in range(ndim - 2):  # 因为要保留最后两个的维度,所以去掉最后2个维度的总次数和各维度除数的计算
            sum_num *= data_size[i]  # 计算前面所要计算的总次数
            if i != ndim - 3:
                sumx *= data_size[-2 - (i + 1)]  # 计算前面每个维度下的除数,方便后面用求余数的方法得到具体位置
                sum_list.append(sumx)
        sum_list.reverse()  # 记得取反,方便后面索引
        # print(sum_list)
        while place < sum_num:
            position = []  # 存储矩阵位置,方便便利所有的矩阵
            current_place = place
            for i in range(ndim - 2):
                if i != ndim - 3:
                    position.append(current_place // sum_list[i])
                    # 下面这步很重要,去掉多余的位置记录,防止后面位置的小除数除出来越界的位置
                    current_place = current_place - current_place // sum_list[i] * sum_list[i]
                else:
                    position.append(current_place % sum_list[i - 1])
            result = torch.Tensor(data)
            for position_num in position:  # 访问具体位置以方便进行你要进行的操作,比如这里是求范数的平方
                result = result[position_num]
            count += result.norm()  # 访问到了具体位置,进行你要的操作,求范数
            place += 1
    return count

检验正确:

# 用和前面同样的dim4、dim5
dim4 = torch.ones(400).reshape([5, 5, 4, 4])
dim5 = torch.ones(400).reshape([2, 4, 2, 5, 5])

print(dynamicFor4Norm(dim4, [5, 5, 4, 4]))
print(dynamicFor4Norm(dim5, [2, 4, 2, 5, 5]))

结果:

tensor(100.)
tensor(80.)

小注

一般Loss函数中,都是求权重L2范数的平方,比如:

此时,上面求范数的第43行得改为:

count += result.norm().pow(2)  # 没错,就是加了个pow(2)求平方

此时同样的输入结果为:

tensor(400.)
tensor(400.)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值