#include <vector>
#include <iostream>
#include <numeric>
#include <algorithm>
/// Advances a multi-dimensional index to the next position in row-major order.
///
/// The last axis moves fastest. When an axis cannot advance inside its extent
/// it is reset to 0 and the carry moves on to the axis before it.
///
/// @param src_dims_info extent of every axis.
/// @param src_pos_info  current index per axis; updated in place.
/// @return true when src_pos_info now holds the next position. false when the
///         two vectors differ in size (src_pos_info is left untouched), or when
///         there is no next position: every axis has then been reset to 0.
bool GetNextPos(const std::vector<int>& src_dims_info, std::vector<int>& src_pos_info) {
  if (src_pos_info.size() != src_dims_info.size()) {
    return false;
  }
  for (std::size_t axis = src_pos_info.size(); axis-- > 0;) {
    // '<' rather than '==': an axis with extent 0, or an index that is already
    // out of range, carries instead of growing without bound.
    if (src_pos_info[axis] + 1 < src_dims_info[axis]) {
      src_pos_info[axis] += 1;
      return true;
    }
    src_pos_info[axis] = 0;
  }
  // Every axis carried: the index wrapped around to the origin.
  return false;
}
/// Permutes the axes of a dense, row-major int tensor.
///
/// Axis i of the destination is axis permute[i] of the source, so the
/// destination shape is {shape[permute[0]], shape[permute[1]], ...}.
/// A one-line trace of every copied element is written to std::cout.
///
/// @param dst     output buffer; needs room for product(shape) ints and must
///                not overlap src (elements are written before later ones are read).
/// @param src     input buffer holding product(shape) ints in row-major order.
/// @param shape   extent of every source axis; the element count is computed in int.
/// @param permute source axis used for each destination axis; must contain
///                every value in [0, shape.size()) exactly once.
/// @return true on success (including an empty tensor); false, after printing
///         the reason, when the arguments are inconsistent. dst is not
///         written in that case.
bool Transpose(int* dst, const int* src, std::vector<int> shape, std::vector<int> permute) {
  const std::size_t rank = shape.size();
  if (rank != permute.size()) {
    std::cout << "shape size : " << shape.size() << " != permute size : " << permute.size() << std::endl;
    return false;
  }

  // Reject duplicate or out-of-range axes up front; they would index past the
  // end of shape / the index vector below.
  std::vector<char> seen(rank, 0);
  for (const int axis : permute) {
    if (axis < 0 || static_cast<std::size_t>(axis) >= rank || seen[static_cast<std::size_t>(axis)] != 0) {
      std::cout << "permute is not a permutation of [0, " << rank << ")" << std::endl;
      return false;
    }
    seen[static_cast<std::size_t>(axis)] = 1;
  }

  int total_num = 1;
  for (const int extent : shape) {
    if (extent < 0) {
      std::cout << "negative extent in shape : " << extent << std::endl;
      return false;
    }
    total_num *= extent;
  }
  if (total_num > 0 && (dst == nullptr || src == nullptr)) {
    std::cout << "null buffer passed to Transpose" << std::endl;
    return false;
  }

  // dst_step[a]: how far the flat destination offset moves when source axis a
  // advances by one. Destination strides are accumulated from the innermost
  // destination axis outwards and filed under the source axis they belong to.
  std::vector<int> dst_step(rank, 0);
  int stride = 1;
  for (std::size_t i = rank; i-- > 0;) {
    dst_step[permute[i]] = stride;
    stride *= shape[permute[i]];
  }

  std::vector<int> src_index(rank, 0);  // multi-index of the element at src_pos
  int dst_pos = 0;                      // flat destination offset of that element
  for (int src_pos = 0; src_pos < total_num; src_pos++) {
    std::cout << src_pos << " : ";
    for (std::size_t i = 0; i < rank; i++) {
      std::cout << src_index[i] << " ";
    }
    std::cout << "dst pos : ";
    for (std::size_t i = 0; i < rank; i++) {
      std::cout << src_index[permute[i]] << " ";
    }
    std::cout << " ptr : " << dst_pos;
    dst[dst_pos] = src[src_pos];
    std::cout << '\n';

    // Advance to the next source index only AFTER the current one was used
    // (last axis fastest), keeping dst_pos in sync with every digit change.
    for (std::size_t axis = rank; axis-- > 0;) {
      if (src_index[axis] + 1 < shape[axis]) {
        src_index[axis] += 1;
        dst_pos += dst_step[axis];
        break;
      }
      dst_pos -= src_index[axis] * dst_step[axis];
      src_index[axis] = 0;
    }
  }
  return true;
}
/// Demo: transposes a 2x3x2 tensor holding 0, 1, 2, ... with axis order
/// {0, 2, 1} and prints the resulting buffer. Exits with 1 if Transpose
/// reports failure.
int main() {
  const std::vector<int> shape{2, 3, 2};
  const std::vector<int> permute{0, 2, 1};

  std::size_t count = 1;
  for (const int extent : shape) {
    count *= static_cast<std::size_t>(extent);
  }

  // Distinct values make the permutation visible; an all-zero input would
  // look the same no matter where each element ends up.
  std::vector<int> src(count);
  std::iota(src.begin(), src.end(), 0);
  std::vector<int> dst(count);

  if (!Transpose(dst.data(), src.data(), shape, permute)) {
    return 1;
  }

  std::cout << "transposed : ";
  for (const int value : dst) {
    std::cout << value << " ";
  }
  std::cout << '\n';
  return 0;
}
// Transpose 实现
// 最新推荐文章于 2024-07-25 21:13:04 发布