目录
# 资料
AVX/SSE 指令查找
简单代码实现数据类型转换
#include <cmath>
#include <cstdint>
#include <cstring>
#include <immintrin.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
namespace DB
{
namespace ErrorCodes
{
    /// Only the codes actually thrown in this file are declared; unused
    /// declarations are rejected by ClickHouse's style check.
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// Number of float components in each vector handled by angle128h.
constexpr size_t vector_dim = 128;
/// Size in bytes of one serialized vector: vector_dim half-precision values,
/// 2 bytes each. Derived from vector_dim so the two can never disagree.
constexpr size_t vector_length = vector_dim * sizeof(uint16_t);
/// angle128h(feature, query) -> Float32
///
/// Cosine similarity between two vector_dim-component vectors, each stored as a
/// String of exactly vector_length bytes holding half-precision floats (decoded
/// with the F16C intrinsic _cvtsh_ss).
///
/// - `query` (2nd argument) must be a constant String of exactly vector_length
///   bytes; otherwise ILLEGAL_COLUMN is thrown.
/// - For each row of `feature` (1st argument) whose length is not exactly
///   vector_length, the result is -1.
///   NOTE(review): -1 is also a legitimate cosine value, so callers cannot
///   distinguish "bad row" from "opposite vectors".
///
/// Uses AVX (_mm256_*) and F16C (_cvtsh_ss) intrinsics with no runtime CPU
/// dispatch visible here: the translation unit must be compiled with those
/// instruction sets enabled and run on CPUs that support them.
class FunctionAngle128h : public IFunction
{
private:
    /// Returns the sum of all eight float lanes of `a`.
    static inline float horizontal_add(__m256 a)
    {
        // _mm256_hadd_ps adds pairs within each 128-bit lane, so after two rounds
        // the low lane holds a0+..+a3 and the high lane holds a4+..+a7.
        __m256 t1 = _mm256_hadd_ps(a, a);
        __m256 t2 = _mm256_hadd_ps(t1, t1);
        __m128 t3 = _mm256_extractf128_ps(t2, 1);
        __m128 t4 = _mm_add_ss(_mm256_castps256_ps128(t2), t3);
        return _mm_cvtss_f32(t4);
    }

    /// Dot product of the first `dim` floats of v1 and v2.
    /// Full groups of 8 are processed with AVX; a scalar tail handles a `dim`
    /// that is not a multiple of 8 instead of reading past the arrays.
    static inline float vector_product(const float * v1, const float * v2, int dim)
    {
        __m256 sum_ps = _mm256_setzero_ps();
        int i = 0;
        for (; i + 8 <= dim; i += 8)
        {
            __m256 f1 = _mm256_loadu_ps(v1 + i);
            __m256 f2 = _mm256_loadu_ps(v2 + i);
            sum_ps = _mm256_add_ps(sum_ps, _mm256_mul_ps(f1, f2));
        }
        float sum = horizontal_add(sum_ps);
        for (; i < dim; ++i)
            sum += v1[i] * v2[i];
        return sum;
    }

    /// Cosine similarity dot(v1, v2) / (|v1| * |v2|) over `dim` components.
    /// Yields NaN (0 / 0) when either vector is all zeros.
    static inline float cosine_avx2(const float * v1, const float * v2, int dim)
    {
        float norm1 = vector_product(v1, v1, dim);
        float norm2 = vector_product(v2, v2, dim);
        float vp = vector_product(v1, v2, dim);
        // Exact square roots instead of _mm_rsqrt_ss: rsqrt is only a ~12-bit
        // approximation, which can push the similarity of identical vectors above 1.
        // This runs once per row, so its cost is negligible next to the dot products.
        // Taking the roots separately also avoids overflowing norm1 * norm2.
        return vp / (std::sqrt(norm1) * std::sqrt(norm2));
    }

    /// Decodes `len` bytes of `src` as consecutive 2-byte half-precision floats
    /// (host byte order) into vec[0 .. len / 2). A trailing odd byte is ignored
    /// rather than read past the end of `src`.
    static inline void frombuffer(const char * src, float * vec, size_t len)
    {
        for (size_t i = 0; i + 1 < len; i += 2)
        {
            uint16_t x = 0;
            memcpy(&x, &src[i], 2); // memcpy avoids an unaligned/aliasing-unsafe load.
            vec[i >> 1] = _cvtsh_ss(x);
        }
    }

public:
    static constexpr auto name = "angle128h";

    static FunctionPtr create(ContextPtr)
    {
        return std::make_shared<FunctionAngle128h>();
    }

    String getName() const override
    {
        return name;
    }

    size_t getNumberOfArguments() const override
    {
        return 2;
    }

    /// Both arguments must be Strings; the result is Float32.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        if (!isString(arguments[0]))
            throw Exception(
                "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        if (!isString(arguments[1]))
            throw Exception(
                "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(),
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        return std::make_shared<DataTypeFloat32>();
    }

    bool useDefaultImplementationForConstants() const override
    {
        return true;
    }

    /// The query vector (argument index 1) must always be a constant.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
    {
        return {1};
    }

    // Needed from ClickHouse 22.5 on, per the original author's note.
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const IColumn * feature_column = arguments[0].column.get();
        const IColumn * input_column = arguments[1].column.get();

        auto col_res = ColumnVector<Float32>::create();
        typename ColumnVector<Float32>::Container & res_data = col_res->getData();
        res_data.resize(input_rows_count);

        const auto * input_const = checkAndGetColumnConst<ColumnString>(input_column);
        if (!input_const)
            throw Exception("Illegal column " + arguments[1].column->getName()
                + " of argument of function " + getName(),
                ErrorCodes::ILLEGAL_COLUMN);

        const String input_string = input_const->getValue<String>();
        if (input_string.size() != vector_length)
            throw Exception("Illegal column input value (len: " + std::to_string(input_string.size()) + ") " + arguments[1].column->getName()
                + " of argument of function " + getName(),
                ErrorCodes::ILLEGAL_COLUMN);

        // The query is constant, so it is decoded once for all rows.
        float input_vector[vector_dim];
        frombuffer(input_string.data(), input_vector, vector_length);

        const ColumnString * col = checkAndGetColumn<ColumnString>(feature_column);
        if (!col)
            throw Exception(
                "Illegal column " + arguments[0].column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

        const auto & offsets = col->getOffsets();
        const ColumnString::Chars & vec_src = col->getChars();

        // Decode buffer reused for every row: no heap allocation inside the loop.
        float feature_vector[vector_dim];
        size_t prev_offset = 0;
        for (size_t i = 0; i < input_rows_count; ++i)
        {
            // Offsets include the terminating zero byte ColumnString stores after
            // each value, hence the -1.
            size_t vector_len = offsets[i] - prev_offset - 1;
            if (vector_len != vector_length)
            {
                res_data[i] = -1;
            }
            else
            {
                frombuffer(reinterpret_cast<const char *>(&vec_src[prev_offset]), feature_vector, vector_length);
                res_data[i] = cosine_avx2(input_vector, feature_vector, vector_dim);
            }
            prev_offset = offsets[i];
        }
        return col_res;
    }
};
/// Registers FunctionAngle128h in `factory`.
/// NOTE(review): presumably invoked from the central function-registration list
/// (not visible here); without that call the function is never available.
void registerFunctionAngle128h(FunctionFactory & factory)
{
    using Impl = FunctionAngle128h;
    factory.registerFunction<Impl>();
}
}