学习笔记4.3.AI相关-crypten使用gpu运行卷积conv时出现AssertionError: more than one group is unsupported on GPU问题的解决

这篇文章介绍了在使用Crypten库进行AI相关任务时,遇到在GPU上运行卷积conv操作时的AssertionError,解决方法是针对特定问题修改了cuda_tensor.py文件中的__patched_conv_ops方法,移除或处理groups参数以适应GPU环境。
摘要由CSDN通过智能技术生成


学习笔记4.3.AI相关-crypten使用gpu运行卷积conv时出现AssertionError: more than one group is unsupported on GPU问题的解决


参考链接
https://github.com/facebookresearch/CrypTen/issues/386

解决方法
找到crypten/cuda/cuda_tensor.py中的__patched_conv_ops方法,做如下修改:

    @staticmethod
    def __patched_conv_ops(op, x, y, *args, **kwargs):
        if "groups" in kwargs:
            groups = kwargs["groups"]
            #del kwargs["groups"]
        else:
            groups = 1
        bs, c, *img = x.size()
        c_out, c_in, *ks = y.size()
        kernel_elements = functools.reduce(operator.mul, ks)

        nb = 3 if kernel_elements < 256 else 4
        nb2 = nb**2

        x_encoded = CUDALongTensor.__encode_as_fp64(x, nb).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y, nb).data

        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(nb, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=nb, dim=0)

        x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, nb2 * c, *img)
        y_enc_span = y_enc_span.reshape(nb2 * c_out, c_in, *ks)

        c_z = c_out if op in ["conv1d", "conv2d"] else c_in

        if "groups" in kwargs:
            kwargs["groups"] *= nb2
        else:
            kwargs["groups"] = nb2

        z_encoded = getattr(torch, op)(
            x_enc_span, y_enc_span, *args, **kwargs
        )

        groups = kwargs["groups"] // nb2 if op in ["conv_transpose1d", "conv_transpose2d"] else 1
        z_encoded = z_encoded.reshape(bs, nb2, c_z * groups, *z_encoded.size()[2:]).transpose_(
            0, 1
        )

        return CUDALongTensor.__decode_as_int64(z_encoded, nb)


  • 8
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值