va_list解决可变参数的传入问题

如何处理可变参数的传入问题

目标:

编写一个函数来检查传入的张量是否符合指定的维度要求。张量维度可以动态变化,因此需要用可变参数来传递维度信息。

步骤 1:定义张量类

首先,我们需要定义一个简单的 Tensor 类,这个类包含张量的数据类型、设备类型以及各维度大小。

#include <iostream>
#include <vector>
#include <cstdarg>

namespace tensor {
    class Tensor {
    public:
        Tensor(std::vector<int32_t> dims, std::string device_type, std::string data_type)
            : dims_(dims), device_type_(device_type), data_type_(data_type) {}

        bool is_empty() const { return dims_.empty(); }
        std::string device_type() const { return device_type_; }
        std::string data_type() const { return data_type_; }
        int32_t dims_size() const { return dims_.size(); }
        int32_t get_dim(int32_t index) const { return dims_[index]; }

    private:
        std::vector<int32_t> dims_;
        std::string device_type_;
        std::string data_type_;
    };
}

namespace base {
    namespace error {
        std::string InvalidArgument(const std::string& message) {
            return "InvalidArgument: " + message;
        }

        std::string Success() {
            return "Success";
        }
    }
}

步骤 2:定义检查函数

我们需要编写一个函数,用于检查张量的维度、设备类型和数据类型是否符合要求。这个函数使用可变参数来传递期望的维度。

namespace base {
    class Layer {
    public:
        std::string check_tensor_with_dim(const tensor::Tensor& tensor,
                                          const std::string& device_type, const std::string& data_type,
                                          ...) const {
            std::va_list args;
            if (tensor.is_empty()) {
                return base::error::InvalidArgument("The tensor parameter is empty.");
            }
            if (tensor.device_type() != device_type) {
                return base::error::InvalidArgument("The tensor has a wrong device type.");
            }
            if (tensor.data_type() != data_type) {
                return base::error::InvalidArgument("The tensor has a wrong data type.");
            }

            va_start(args, data_type);
            int32_t dims = tensor.dims_size();
            for (int32_t i = 0; i < dims; ++i) {
                int32_t expected_dim = va_arg(args, int32_t);
                if (expected_dim != tensor.get_dim(i)) {
                    va_end(args);
                    return base::error::InvalidArgument("The tensor has a wrong dim in dim " + std::to_string(i));
                }
            }
            va_end(args);
            return base::error::Success();
        }
    };
}
步骤 3:编写测试用例

编写一个简单的测试用例来验证我们的函数是否能正确地检查张量的维度。

int main() {
    tensor::Tensor tensor({3, 224, 224}, "GPU", "float32");

    base::Layer layer;

    // 正确的维度
    std::string result = layer.check_tensor_with_dim(tensor, "GPU", "float32", 3, 224, 224);
    std::cout << "Test 1: " << result << std::endl;

    // 错误的维度
    result = layer.check_tensor_with_dim(tensor, "GPU", "float32", 3, 225, 224);
    std::cout << "Test 2: " << result << std::endl;

    // 错误的设备类型
    result = layer.check_tensor_with_dim(tensor, "CPU", "float32", 3, 224, 224);
    std::cout << "Test 3: " << result << std::endl;

    // 错误的数据类型
    result = layer.check_tensor_with_dim(tensor, "GPU", "int32", 3, 224, 224);
    std::cout << "Test 4: " << result << std::endl;

    return 0;
}

 结果:

Test 1: Success
Test 2: InvalidArgument: The tensor has a wrong dim in dim 1
Test 3: InvalidArgument: The tensor has a wrong device type.
Test 4: InvalidArgument: The tensor has a wrong data type.

va_list - cppreference.com

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值