def includes():
    """Return the common C++ headers and helper macros for the generated source."""
    out = """
#include <torch/all.h>
#include <torch/python.h>
#include <omp.h>
#include <cmath>
#include <immintrin.h>

#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
"""
    return out


def module(bits_list=[4, 2]):
    """Return the PYBIND11_MODULE block that registers every generated kernel."""
    out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
    # Per-bit-width bindings: forward, unpack_zeros, forward_gs, and pack.
    for bits in bits_list:
        out += '  m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
    # Bit-width-independent helpers.
    out += '  m.def("compute_reduction_cpp", &compute_reduction);\n'
    out += '  m.def("unquantize_sim", &unquantize_sim);\n'
    # if oracle:
    #     out += '  m.def("forward4_oracle", &forward4_oracle_cpu);\n'
    out += '  m.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
    out += '}\n'
    return out
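

# For the default bits_list=[4, 2], module() emits a binding block along these
# lines (illustrative; the bound C++ functions themselves are generated elsewhere):
#
#   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#     m.def("forward4", &forward4_cpu);
#     m.def("forward2", &forward2_cpu);
#     m.def("unpack_zeros4", &unpack_zeros4);
#     m.def("unpack_zeros2", &unpack_zeros2);
#     m.def("forward_gs4", &forward4_gs_cpu);
#     m.def("forward_gs2", &forward2_gs_cpu);
#     m.def("pack4", &pack4_w_cpu);
#     m.def("pack2", &pack2_w_cpu);
#     m.def("compute_reduction_cpp", &compute_reduction);
#     m.def("unquantize_sim", &unquantize_sim);
#     m.def("quant_scalar_scaled", &quant_scalar_cpu);
#   }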


def quant_scalar():
    """Return the C++ source for the scalar (non-vectorized) quantization kernel."""
    out = """
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){
    // Quantize the n x m input column by column using the per-column scales
    // and zeros, packing 32/bits values into each 32-bit output word.
    int pack = 32/bits;
    for (int j = 0; j < m; j++){
        for (int i = 0; i < n; i+=pack){
            uint32_t acc = 0;
            for (int ii = i; ii < i+pack; ii++){
                float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);
                int temp = (int)ftemp;
                acc = acc | (temp << (bits*(ii-i)));
            }
            BQ[(i/pack)*m+j] = acc;
        }
    }
}

void quant_scalar_cpu(
    torch::Tensor in, torch::Tensor out,
    torch::Tensor scales, torch::Tensor zeros, int bits
) {
    int N = in.size(0);
    int M = in.size(1);

    float* input = in.data_ptr<float>();
    float* s = scales.data_ptr<float>();
    float* z = zeros.data_ptr<float>();
    int* O = out.data_ptr<int>();

    quantize_scalar(input, O, s, z, N, M, bits);
}
"""
    return out
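

# Minimal usage sketch (not part of the original generator): assuming the
# generated strings are compiled at runtime with
# torch.utils.cpp_extension.load_inline, this binds only quant_scalar_cpu,
# since the forward*/pack*/unpack* kernels registered by module() are produced
# elsewhere. The extension name, compiler flags, and the hand-written
# PYBIND11 block are illustrative assumptions.
if __name__ == "__main__":
    import torch
    from torch.utils.cpp_extension import load_inline

    source = includes() + quant_scalar()
    # Hypothetical binding block for the single kernel defined in this file.
    source += (
        'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { '
        'm.def("quant_scalar_scaled", &quant_scalar_cpu); }\n'
    )

    ext = load_inline(
        name="quant_scalar_ext",
        cpp_sources=[source],
        extra_cflags=["-O3", "-fopenmp", "-march=native"],
        verbose=False,
    )

    # Pack an 8x4 float matrix to 4 bits: 32/4 = 8 values share one int32 word,
    # so the packed tensor has n*bits/32 rows and m columns.
    n, m, bits = 8, 4, 4
    A = torch.rand(n, m)
    scales = torch.full((m,), 1.0 / 15.0)  # map [0, 1) onto the 4-bit range 0..15
    zeros = torch.zeros(m)
    packed = torch.zeros(n * bits // 32, m, dtype=torch.int32)
    ext.quant_scalar_scaled(A, packed, scales, zeros, bits)
    print(packed)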