def includes():
    out = " \
    #include <torch/all.h>\n \
    #include <torch/python.h>\n \
    #include <omp.h>\n \
    #include <cmath>\n \
    #include <cstdint>\n \
    \n \
    #define mymin(a,b) ((a)<(b)?(a):(b))\n \
    #define mymax(a,b) ((a)>(b)?(a):(b))\n \
    "
    return out


def module(bits_list=[4, 2]):
    out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
    for bits in bits_list:
        out += '  m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
    for bits in bits_list:
        out += '  m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
    out += '  m.def("compute_reduction_cpp", &compute_reduction);\n'
    out += '  m.def("unquantize_sim", &unquantize_sim);\n'
    # if oracle:
    #     out += '  m.def("forward4_oracle", &forward4_oracle_cpu);\n'
    out += '  m.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
    out += '}\n'
    return out


def quant_scalar():
    out = " \
    void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){\n \
        // Quantize A column by column: code = round((A + zero) / scale),\n \
        // then pack 32/bits consecutive codes of a column into one int32.\n \
        int pack = 32/bits;\n \
        for (int j = 0; j < m; j++){\n \
            for (int i = 0; i < n; i+=pack){\n \
                uint32_t acc = 0;\n \
                for (int ii = i; ii < i+pack; ii++){\n \
                    float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
                    int temp = (int)ftemp;\n \
                    acc = acc | (temp << (bits*(ii-i)));\n \
                }\n \
                BQ[(i/pack)*m+j] = acc;\n \
            }\n \
        }\n \
    }\n \
    \n \
    void quant_scalar_cpu(\n \
        torch::Tensor in, torch::Tensor out,\n \
        torch::Tensor scales, torch::Tensor zeros, int bits\n \
    ) {\n \
        int N = in.size(0);\n \
        int M = in.size(1);\n \
        float* input = in.data_ptr<float>();\n \
        float* s = scales.data_ptr<float>();\n \
        float* z = zeros.data_ptr<float>();\n \
        int* O = out.data_ptr<int>();\n \
        quantize_scalar(input, O, s, z, N, M, bits);\n \
    }\n"
    return out
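

# A pure-NumPy reference of the packing scheme emitted by quant_scalar() above,
# handy for testing the compiled extension. This is a sketch of one reading of
# the C++ loop, not part of the original generator: each int32 of the output
# packs 32//bits consecutive rows of one column, lowest bits first, and n is
# assumed divisible by the pack width, as in the C++ version.
import numpy as np

def quantize_scalar_ref(A, scales, zeros, bits):
    n, m = A.shape
    pack = 32 // bits
    x = (A + zeros[None, :]) / scales[None, :]
    # Match C++ std::round (ties away from zero) rather than np.rint (ties to even).
    q = np.trunc(x + np.copysign(0.5, x)).astype(np.int64)
    out = np.zeros((n // pack, m), dtype=np.uint32)
    for i in range(0, n, pack):
        acc = np.zeros(m, dtype=np.uint32)
        for ii in range(pack):
            # Mask each code to `bits` bits before shifting it into place; the
            # C++ version assumes codes are already in range and does not mask.
            code = q[i + ii].astype(np.uint32) & np.uint32((1 << bits) - 1)
            acc |= code << np.uint32(bits * ii)
        out[i // pack] = acc
    return out.view(np.int32)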
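

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the generator): one way the
# emitted snippets could be assembled and JIT-compiled. It assumes the kernel
# generators elsewhere in this file supply the definitions that module() binds
# (forward{bits}_cpu, pack{bits}_w_cpu, ...); `kernels_source`, the extension
# name, and the compiler flags below are hypothetical placeholders.
#
#   from torch.utils.cpp_extension import load_inline
#
#   source = includes() + kernels_source + quant_scalar() + module([4, 2])
#   ext = load_inline(name="quant_cpu", cpp_sources=source,
#                     extra_cflags=["-O3", "-fopenmp"])
#   ext.quant_scalar_scaled(x, qweight, scales, zeros, 4)
# ---------------------------------------------------------------------------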