在使用 libtorch 时,经常会遇到把 vector 和 cat 配合使用的情况,这里写了几个例子。cat 函数一般有三种使用方式,分别如下:
auto tensor1 = torch::randn({1, 3, 4, 4});
auto tensor2 = torch::randn({1, 3, 4, 4});
//method 1
auto cattensors = torch::cat({tensor1, tensor2});
cout << cattensors.sizes() << endl;
//method 2
vector<torch::Tensor> tensor_vec;
tensor_vec.push_back(tensor1);
tensor_vec.push_back(tensor2);
torch::TensorList tensorlist{tensor_vec};
cattensors = torch::cat(tensorlist);
cout << cattensors.sizes() << endl;
//method 3
vector<torch::Tensor> tensor_vec2;
tensor1 = tensor1.permute({0, 3, 1, 2}).contiguous();
tensor_vec2.push_back(tensor1);
tensor2 = tensor2.permute({0, 3, 1, 2}).contiguous();
tensor_vec2.push_back(tensor2);
auto cattensors2 = torch::cat(tensor_vec2);
cout << cattensors2.sizes() << endl;
需要注意的是,使用 cat 拼接 tensor 时,传入的必须是连续(contiguous)的 tensor,因此方法 3 在 permute 之后调用了 contiguous()。