From 07021b9a1c60fff6151d65caf3a44b0f0e49d877 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Thu, 26 Oct 2023 10:26:42 -0400 Subject: [PATCH] Generated files so that when they fail to work in pipeline then it still continues with what should be some ok defaults --- autogptq_extension/qigen/backend.cpp | 2428 ++++++++++++++++++++++++++ autogptq_extension/qigen/foo | 0 autogptq_extension/qigen/forward.h | 480 +++++ autogptq_extension/qigen/mmm | Bin 0 -> 41056 bytes autogptq_extension/qigen/tmp.csv | 37 + 5 files changed, 2945 insertions(+) create mode 100644 autogptq_extension/qigen/backend.cpp create mode 100644 autogptq_extension/qigen/foo create mode 100644 autogptq_extension/qigen/forward.h create mode 100755 autogptq_extension/qigen/mmm create mode 100644 autogptq_extension/qigen/tmp.csv diff --git a/autogptq_extension/qigen/backend.cpp b/autogptq_extension/qigen/backend.cpp new file mode 100644 index 0000000..00b9aff --- /dev/null +++ b/autogptq_extension/qigen/backend.cpp @@ -0,0 +1,2428 @@ + #include <torch/all.h> + #include <torch/python.h> + #include <omp.h> + #include <cmath> + #include <immintrin.h> + + #define mymin(a,b) ((a)<(b)?(a):(b)) + #define mymax(a,b) ((a)>(b)?(a):(b)) + void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ + //find scales and zeros arrays + //quantize + int pack = 32/bits; + for (int j = 0; j < m; j++){ + for (int i = 0; i < n; i+=pack){ + uint32_t acc = 0; + for (int ii = i; ii < i+pack; ii++){ + float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]); + int temp = (int)ftemp; + acc = acc | (temp << (bits*(ii-i))); + } + BQ[(i/pack)*m+j] = acc; + //BQ[0] = acc; + } + } + } + + void quant_scalar_cpu( + torch::Tensor in, torch::Tensor out, + torch::Tensor scales, torch::Tensor zeros, int bits + ) { + + int N = in.size(0); + int M = in.size(1); + + float* input = in.data_ptr<float>(); + float* s = scales.data_ptr<float>(); + float* z = zeros.data_ptr<float>(); + int* O = out.data_ptr<int>(); + + quantize_scalar(input, O, s, z, N, M, bits); + + } +void compute_reduction_cpu(const float* in, float* out, int n, int m, int gs){ +#pragma omp parallel num_threads(8) +{ +#pragma omp for collapse(2) +for(int i = 0; i < n; i++){ +for(int j0 = 0; j0 < m; j0+=gs){ +__m256 acc = _mm256_setzero_ps(); +for(int j1 = j0; j1 < j0+gs; j1+=8){ +__m256 x = _mm256_loadu_ps(&in[i*m + j1]); +acc = _mm256_add_ps(acc, x); +} +const __m128 hiQuad = _mm256_extractf128_ps(acc, 1); +const __m128 loQuad = _mm256_castps256_ps128(acc); +const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad); +const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad); +const __m128 sumDual = _mm_add_ps(sumQuad, hiDual); +const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1); +const __m128 sum = _mm_add_ss(hi, sumDual); +out[(i*m + j0)/gs] = _mm_cvtss_f32(sum); +} +} +} +} +void compute_reduction(torch::Tensor in, torch::Tensor out, int N, int M, int gs){ +float* I = in.data_ptr<float>(); +float* O = out.data_ptr<float>(); +compute_reduction_cpu(I, O, N, M, gs); +} +void unquantize_sim_cpu(const int* in, float* out, float* s, float* z, int n, int m, int bits, int gs){ +#pragma omp parallel num_threads(8) +{ +int packed = 32/bits; +int mask = (1<<bits)-1; +#pragma omp for +for(int i0 = 0; i0 < n; i0+=gs){ +int row = i0/gs; +for(int j0 = 0; j0 < m; j0++){ +for(int i1 = i0; i1 < i0+gs; i1+=packed){ +for(int k = 0; k < packed; k++){ +out[(i1+k)*m + j0] = ((float)((in[(i1/packed)*m + j0] >> (bits*k)) & mask) - z[(row)*m + j0]) * s[(row)*m + j0]; +} +} +} +} +} +} +void unquantize_sim(torch::Tensor in, torch::Tensor out, torch::Tensor s, torch::Tensor z, int N, int M, int bits, int gs){ +int* I = in.data_ptr<int>(); +float* O = out.data_ptr<float>(); +float* S = s.data_ptr<float>(); +float* Z = z.data_ptr<float>(); +unquantize_sim_cpu(I, O, S, Z, N, M, bits, gs); +} +inline +void q2gemm(const float* __restrict__ input, +const int* 
__restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 16; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(3); +tid = omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16: + tid*tt*m/16; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +for(; j1 < tb-tu+1; j1+=tu) { +__m256 acc0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 acc0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +for(int k1 = 0; k1 < mb; k1+=mu) { +for(int k2 = k1; k2 < k1+mu; k2+=16){ +__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]); +__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]); +__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]); +__m256i ws0_15 = _mm256_srli_epi32(w0, 30); +__m256i ws8_15 = _mm256_srli_epi32(w8, 30); +__m256i wsa0_15= _mm256_and_si256(ws0_15, mask); +__m256i wsa8_15= _mm256_and_si256(ws8_15, mask); +__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15); +__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15); +acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8); +__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]); +__m256i ws0_14 = _mm256_srli_epi32(w0, 28); +__m256i ws8_14 = _mm256_srli_epi32(w8, 28); +__m256i wsa0_14= _mm256_and_si256(ws0_14, mask); +__m256i wsa8_14= _mm256_and_si256(ws8_14, mask); +__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14); +__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14); +acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8); +__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]); +__m256i ws0_13 = _mm256_srli_epi32(w0, 26); +__m256i ws8_13 = _mm256_srli_epi32(w8, 26); +__m256i wsa0_13= _mm256_and_si256(ws0_13, mask); +__m256i wsa8_13= _mm256_and_si256(ws8_13, mask); +__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13); +__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13); +acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8); +__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]); +__m256i ws0_12 = _mm256_srli_epi32(w0, 24); +__m256i ws8_12 = _mm256_srli_epi32(w8, 24); +__m256i wsa0_12= _mm256_and_si256(ws0_12, mask); +__m256i wsa8_12= _mm256_and_si256(ws8_12, mask); +__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12); +__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12); +acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8); +__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]); +__m256i ws0_11 = _mm256_srli_epi32(w0, 22); +__m256i ws8_11 = _mm256_srli_epi32(w8, 22); +__m256i wsa0_11= _mm256_and_si256(ws0_11, mask); +__m256i wsa8_11= _mm256_and_si256(ws8_11, mask); +__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11); +__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11); +acc0_0 = 
_mm256_fmadd_ps(v0_11, l0_11, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8); +__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]); +__m256i ws0_10 = _mm256_srli_epi32(w0, 20); +__m256i ws8_10 = _mm256_srli_epi32(w8, 20); +__m256i wsa0_10= _mm256_and_si256(ws0_10, mask); +__m256i wsa8_10= _mm256_and_si256(ws8_10, mask); +__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10); +__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10); +acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8); +__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]); +__m256i ws0_9 = _mm256_srli_epi32(w0, 18); +__m256i ws8_9 = _mm256_srli_epi32(w8, 18); +__m256i wsa0_9= _mm256_and_si256(ws0_9, mask); +__m256i wsa8_9= _mm256_and_si256(ws8_9, mask); +__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9); +__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9); +acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8); +__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]); +__m256i ws0_8 = _mm256_srli_epi32(w0, 16); +__m256i ws8_8 = _mm256_srli_epi32(w8, 16); +__m256i wsa0_8= _mm256_and_si256(ws0_8, mask); +__m256i wsa8_8= _mm256_and_si256(ws8_8, mask); +__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8); +__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8); +acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256i ws0_7 = _mm256_srli_epi32(w0, 14); +__m256i ws8_7 = _mm256_srli_epi32(w8, 14); +__m256i wsa0_7= _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7= _mm256_and_si256(ws8_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256i ws0_6 = _mm256_srli_epi32(w0, 12); +__m256i ws8_6 = _mm256_srli_epi32(w8, 12); +__m256i wsa0_6= _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6= _mm256_and_si256(ws8_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256i ws0_5 = _mm256_srli_epi32(w0, 10); +__m256i ws8_5 = _mm256_srli_epi32(w8, 10); +__m256i wsa0_5= _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5= _mm256_and_si256(ws8_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256i ws0_4 = _mm256_srli_epi32(w0, 8); +__m256i ws8_4 = _mm256_srli_epi32(w8, 8); +__m256i wsa0_4= _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4= _mm256_and_si256(ws8_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256i ws0_3 = _mm256_srli_epi32(w0, 6); +__m256i ws8_3 = _mm256_srli_epi32(w8, 6); +__m256i wsa0_3= _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3= _mm256_and_si256(ws8_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); 
+acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256i ws0_2 = _mm256_srli_epi32(w0, 4); +__m256i ws8_2 = _mm256_srli_epi32(w8, 4); +__m256i wsa0_2= _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2= _mm256_and_si256(ws8_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256i ws0_1 = _mm256_srli_epi32(w0, 2); +__m256i ws8_1 = _mm256_srli_epi32(w8, 2); +__m256i wsa0_1= _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1= _mm256_and_si256(ws8_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w8, 0); +__m256i wsa0_0= _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0= _mm256_and_si256(ws8_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +} +} +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], acc0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], acc0_8); +} +} +} +} +} +#pragma omp barrier +for (int i = 0; i < n; i++) { +__m256 r = _mm256_set1_ps(sums[i]); +for (int j = 0; j < tt; j+=16){ +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + j + 8]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + j + 8]); +__m256 zr0 = _mm256_fnmadd_ps(z0, r, o0); +__m256 zr8 = _mm256_fnmadd_ps(z8, r, o8); +__m256 o20 = _mm256_fmadd_ps(zr0, s0, b0); +__m256 o28 = _mm256_fmadd_ps(zr8, s8, b8); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +} +} +} +} +inline void forward2_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){ +int* W = weight.data_ptr(); +float* input = in.data_ptr(); +float* b = bias.data_ptr(); +float* s = scales.data_ptr(); +float* z = zeros.data_ptr(); +float* r = sums.data_ptr(); +float* O = out.data_ptr(); + +q2gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff); +} +inline +void q2gemm_gs(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int gs, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 32; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(3); +tid = 
omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16: + tid*tt*m/16; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +for(; j1 < tb-tu+1; j1+=tu) { +for(int k1 = 0; k1 < mb; k1+=gs) { +__m256 acc0_0 = _mm256_setzero_ps(); +__m256 acc0_8 = _mm256_setzero_ps(); +__m256 acc0_16 = _mm256_setzero_ps(); +__m256 acc0_24 = _mm256_setzero_ps(); +for(int k2 = k1; k2 < k1+gs; k2+=16) +{ +__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]); +__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]); +__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]); +__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]); +__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]); +__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]); +__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]); +__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]); +__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]); +__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]); +__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]); +__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]); +__m256i ws0_8 = _mm256_srli_epi32(w0, 16); +__m256i ws8_8 = _mm256_srli_epi32(w8, 16); +__m256i ws16_8 = _mm256_srli_epi32(w16, 16); +__m256i ws24_8 = _mm256_srli_epi32(w24, 16); +__m256i wsa0_8= _mm256_and_si256(ws0_8, mask); +__m256i wsa8_8= _mm256_and_si256(ws8_8, mask); +__m256i wsa16_8= _mm256_and_si256(ws16_8, mask); +__m256i wsa24_8= _mm256_and_si256(ws24_8, mask); +__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8); +__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8); +__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8); +__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8); +acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24); +__m256i ws0_9 = _mm256_srli_epi32(w0, 18); +__m256i ws8_9 = _mm256_srli_epi32(w8, 18); +__m256i ws16_9 = _mm256_srli_epi32(w16, 18); +__m256i ws24_9 = _mm256_srli_epi32(w24, 18); +__m256i wsa0_9= _mm256_and_si256(ws0_9, mask); +__m256i wsa8_9= _mm256_and_si256(ws8_9, mask); +__m256i wsa16_9= _mm256_and_si256(ws16_9, mask); +__m256i wsa24_9= _mm256_and_si256(ws24_9, mask); +__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9); +__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9); +__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9); +__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9); +acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24); +__m256i ws0_10 = _mm256_srli_epi32(w0, 20); +__m256i ws8_10 = _mm256_srli_epi32(w8, 20); +__m256i ws16_10 = _mm256_srli_epi32(w16, 20); +__m256i ws24_10 = _mm256_srli_epi32(w24, 20); +__m256i wsa0_10= _mm256_and_si256(ws0_10, mask); +__m256i wsa8_10= _mm256_and_si256(ws8_10, mask); +__m256i wsa16_10= _mm256_and_si256(ws16_10, mask); +__m256i wsa24_10= _mm256_and_si256(ws24_10, mask); +__m256 l0_10 = 
_mm256_cvtepi32_ps(wsa0_10); +__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10); +__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10); +__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10); +acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24); +__m256i ws0_11 = _mm256_srli_epi32(w0, 22); +__m256i ws8_11 = _mm256_srli_epi32(w8, 22); +__m256i ws16_11 = _mm256_srli_epi32(w16, 22); +__m256i ws24_11 = _mm256_srli_epi32(w24, 22); +__m256i wsa0_11= _mm256_and_si256(ws0_11, mask); +__m256i wsa8_11= _mm256_and_si256(ws8_11, mask); +__m256i wsa16_11= _mm256_and_si256(ws16_11, mask); +__m256i wsa24_11= _mm256_and_si256(ws24_11, mask); +__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11); +__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11); +__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11); +__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11); +acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24); +__m256i ws0_12 = _mm256_srli_epi32(w0, 24); +__m256i ws8_12 = _mm256_srli_epi32(w8, 24); +__m256i ws16_12 = _mm256_srli_epi32(w16, 24); +__m256i ws24_12 = _mm256_srli_epi32(w24, 24); +__m256i wsa0_12= _mm256_and_si256(ws0_12, mask); +__m256i wsa8_12= _mm256_and_si256(ws8_12, mask); +__m256i wsa16_12= _mm256_and_si256(ws16_12, mask); +__m256i wsa24_12= _mm256_and_si256(ws24_12, mask); +__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12); +__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12); +__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12); +__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12); +acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24); +__m256i ws0_13 = _mm256_srli_epi32(w0, 26); +__m256i ws8_13 = _mm256_srli_epi32(w8, 26); +__m256i ws16_13 = _mm256_srli_epi32(w16, 26); +__m256i ws24_13 = _mm256_srli_epi32(w24, 26); +__m256i wsa0_13= _mm256_and_si256(ws0_13, mask); +__m256i wsa8_13= _mm256_and_si256(ws8_13, mask); +__m256i wsa16_13= _mm256_and_si256(ws16_13, mask); +__m256i wsa24_13= _mm256_and_si256(ws24_13, mask); +__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13); +__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13); +__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13); +__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13); +acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24); +__m256i ws0_14 = _mm256_srli_epi32(w0, 28); +__m256i ws8_14 = _mm256_srli_epi32(w8, 28); +__m256i ws16_14 = _mm256_srli_epi32(w16, 28); +__m256i ws24_14 = _mm256_srli_epi32(w24, 28); +__m256i wsa0_14= _mm256_and_si256(ws0_14, mask); +__m256i wsa8_14= _mm256_and_si256(ws8_14, mask); +__m256i wsa16_14= _mm256_and_si256(ws16_14, mask); +__m256i wsa24_14= _mm256_and_si256(ws24_14, mask); +__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14); +__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14); +__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14); +__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14); +acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24); +__m256i ws0_15 = _mm256_srli_epi32(w0, 30); +__m256i ws8_15 = _mm256_srli_epi32(w8, 
30); +__m256i ws16_15 = _mm256_srli_epi32(w16, 30); +__m256i ws24_15 = _mm256_srli_epi32(w24, 30); +__m256i wsa0_15= _mm256_and_si256(ws0_15, mask); +__m256i wsa8_15= _mm256_and_si256(ws8_15, mask); +__m256i wsa16_15= _mm256_and_si256(ws16_15, mask); +__m256i wsa24_15= _mm256_and_si256(ws24_15, mask); +__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15); +__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15); +__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15); +__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15); +acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w8, 0); +__m256i ws16_0 = _mm256_srli_epi32(w16, 0); +__m256i ws24_0 = _mm256_srli_epi32(w24, 0); +__m256i wsa0_0= _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0= _mm256_and_si256(ws8_0, mask); +__m256i wsa16_0= _mm256_and_si256(ws16_0, mask); +__m256i wsa24_0= _mm256_and_si256(ws24_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0); +__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24); +__m256i ws0_1 = _mm256_srli_epi32(w0, 2); +__m256i ws8_1 = _mm256_srli_epi32(w8, 2); +__m256i ws16_1 = _mm256_srli_epi32(w16, 2); +__m256i ws24_1 = _mm256_srli_epi32(w24, 2); +__m256i wsa0_1= _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1= _mm256_and_si256(ws8_1, mask); +__m256i wsa16_1= _mm256_and_si256(ws16_1, mask); +__m256i wsa24_1= _mm256_and_si256(ws24_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1); +__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24); +__m256i ws0_2 = _mm256_srli_epi32(w0, 4); +__m256i ws8_2 = _mm256_srli_epi32(w8, 4); +__m256i ws16_2 = _mm256_srli_epi32(w16, 4); +__m256i ws24_2 = _mm256_srli_epi32(w24, 4); +__m256i wsa0_2= _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2= _mm256_and_si256(ws8_2, mask); +__m256i wsa16_2= _mm256_and_si256(ws16_2, mask); +__m256i wsa24_2= _mm256_and_si256(ws24_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2); +__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24); +__m256i ws0_3 = _mm256_srli_epi32(w0, 6); +__m256i 
ws8_3 = _mm256_srli_epi32(w8, 6); +__m256i ws16_3 = _mm256_srli_epi32(w16, 6); +__m256i ws24_3 = _mm256_srli_epi32(w24, 6); +__m256i wsa0_3= _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3= _mm256_and_si256(ws8_3, mask); +__m256i wsa16_3= _mm256_and_si256(ws16_3, mask); +__m256i wsa24_3= _mm256_and_si256(ws24_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3); +__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24); +__m256i ws0_4 = _mm256_srli_epi32(w0, 8); +__m256i ws8_4 = _mm256_srli_epi32(w8, 8); +__m256i ws16_4 = _mm256_srli_epi32(w16, 8); +__m256i ws24_4 = _mm256_srli_epi32(w24, 8); +__m256i wsa0_4= _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4= _mm256_and_si256(ws8_4, mask); +__m256i wsa16_4= _mm256_and_si256(ws16_4, mask); +__m256i wsa24_4= _mm256_and_si256(ws24_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4); +__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24); +__m256i ws0_5 = _mm256_srli_epi32(w0, 10); +__m256i ws8_5 = _mm256_srli_epi32(w8, 10); +__m256i ws16_5 = _mm256_srli_epi32(w16, 10); +__m256i ws24_5 = _mm256_srli_epi32(w24, 10); +__m256i wsa0_5= _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5= _mm256_and_si256(ws8_5, mask); +__m256i wsa16_5= _mm256_and_si256(ws16_5, mask); +__m256i wsa24_5= _mm256_and_si256(ws24_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5); +__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24); +__m256i ws0_6 = _mm256_srli_epi32(w0, 12); +__m256i ws8_6 = _mm256_srli_epi32(w8, 12); +__m256i ws16_6 = _mm256_srli_epi32(w16, 12); +__m256i ws24_6 = _mm256_srli_epi32(w24, 12); +__m256i wsa0_6= _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6= _mm256_and_si256(ws8_6, mask); +__m256i wsa16_6= _mm256_and_si256(ws16_6, mask); +__m256i wsa24_6= _mm256_and_si256(ws24_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6); +__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24); +__m256i ws0_7 = _mm256_srli_epi32(w0, 14); +__m256i ws8_7 = _mm256_srli_epi32(w8, 14); +__m256i ws16_7 = _mm256_srli_epi32(w16, 14); +__m256i ws24_7 = _mm256_srli_epi32(w24, 14); +__m256i wsa0_7= _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7= _mm256_and_si256(ws8_7, mask); +__m256i wsa16_7= _mm256_and_si256(ws16_7, mask); +__m256i wsa24_7= _mm256_and_si256(ws24_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7); +__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, 
l8_7, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24); +} +__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]); +__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]); +__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]); +__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]); +__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]); +__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]); +__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0); +__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8); +__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16); +__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24); +} +} +} +} +} +} +#pragma omp barrier +const int ngs = m/gs; +for (int i = 0; i < n; i++) { +for (int j = 0; j < tt; j+=32){ +__m256 acc0 = _mm256_setzero_ps(); +__m256 acc8 = _mm256_setzero_ps(); +__m256 acc16 = _mm256_setzero_ps(); +__m256 acc24 = _mm256_setzero_ps(); +for (int i1 = 0; i1 < ngs; i1++){ +__m256 r = _mm256_set1_ps(sums[i*ngs + i1]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]); +__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]); +__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]); +__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]); +__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]); +__m256 zs0 = _mm256_mul_ps(z0, s0); +__m256 zs8 = _mm256_mul_ps(z8, s8); +__m256 zs16 = _mm256_mul_ps(z16, s16); +__m256 zs24 = _mm256_mul_ps(z24, s24); +acc0 = _mm256_fmadd_ps(zs0, r, acc0); +acc8 = _mm256_fmadd_ps(zs8, r, acc8); +acc16 = _mm256_fmadd_ps(zs16, r, acc16); +acc24 = _mm256_fmadd_ps(zs24, r, acc24); +} +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]); +__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]); +__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]); +__m256 o10 = _mm256_sub_ps(o0, acc0); +__m256 o18 = _mm256_sub_ps(o8, acc8); +__m256 o116 = _mm256_sub_ps(o16, acc16); +__m256 o124 = _mm256_sub_ps(o24, acc24); +__m256 o20 = _mm256_add_ps(o10, b0); +__m256 o28 = _mm256_add_ps(o18, b8); +__m256 o216 = _mm256_add_ps(o116, b16); +__m256 o224 = _mm256_add_ps(o124, b24); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216); +_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224); 
+} +} +} +} +inline void forward2_gs_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){ +int* W = weight.data_ptr<int>(); +float* input = in.data_ptr<float>(); +float* b = bias.data_ptr<float>(); +float* s = scales.data_ptr<float>(); +float* z = zeros.data_ptr<float>(); +float* r = sums.data_ptr<float>(); +float* O = out.data_ptr<float>(); + +q2gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff); +} +inline void pack2_qw_inner(int* A, int* B, const int N, const int M, const int nb, int mb, int cutoff){ +// copy the full matrix A in blocked format into B +uint64_t idx = 0; +for(int j = 0, tid = 0; j < M; j+=mb, tid++){ +for(int i = 0; i < N; i+=nb){ + for(int ii = i; ii < mymin(i+nb, N); ii++){ + for(int jj = j; jj < mymin(j+mb, M); jj++){ + B[idx] = A[ii*M+jj]; + idx++; + } + } + } +} +} +inline void pack2_w_cpu( +torch::Tensor in, torch::Tensor out, +int N, int M, int nb, int mb, int cutoff){ +int* input = in.data_ptr<int>(); +int* O = out.data_ptr<int>(); + pack2_qw_inner(input, O, N, M, nb, mb, cutoff); +} +void unpack_zeros2_cpu(const int* zv, float* ov, int n, int m){ +const __m256i ones = _mm256_set1_epi32(1); +const __m256i mask = _mm256_set1_epi32(3); +const __m256i shift0 = _mm256_set_epi32(30,28,26,24,22,20,18,16); +const __m256i shift1 = _mm256_set_epi32(14,12,10,8,6,4,2,0); +for(int i = 0; i < n; i++){ +for (int j = 0; j < m; j+=16){ +for (int k = 0; k < 16; k++){ +ov[i*m + j+k] = (((zv[j/16] >> (2*k)) & 3)+1); +} +} +} +} +void unpack_zeros2(torch::Tensor zeros, torch::Tensor out, int N, int M){ +int* Z = zeros.data_ptr<int>(); +float* O = out.data_ptr<float>(); +unpack_zeros2_cpu(Z, O, N, M); +} +inline +void q3gemm(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 16; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(7); +const __m256i mask4 = _mm256_set1_epi32(4); +const __m256i mask6 = _mm256_set1_epi32(6); +tid = omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? 
+ ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/32*3: + tid*tt*m/32*3; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +int jw = 0; +for(; j1 < tb-tu+1; j1+=tu, jw+=48){ +__m256 acc0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 acc0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +for(int k1 = 0; k1 < mb; k1+=mu) { +for(int k2 = k1; k2 < k1+mu; k2+=32){ +__m256i w0_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0]); +__m256i w1_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0+8]); +__m256i w2_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0+16]); +__m256i w0_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24]); +__m256i w1_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24+8]); +__m256i w2_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24+16]); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0_0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w0_8, 0); +__m256i wsa0_0 = _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0 = _mm256_and_si256(ws8_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +__m256i ws0_1 = _mm256_srli_epi32(w0_0, 3); +__m256i ws8_1 = _mm256_srli_epi32(w0_8, 3); +__m256i wsa0_1 = _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1 = _mm256_and_si256(ws8_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256i ws0_2 = _mm256_srli_epi32(w0_0, 6); +__m256i ws8_2 = _mm256_srli_epi32(w0_8, 6); +__m256i wsa0_2 = _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2 = _mm256_and_si256(ws8_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +__m256i ws0_3 = _mm256_srli_epi32(w0_0, 9); +__m256i ws8_3 = _mm256_srli_epi32(w0_8, 9); +__m256i wsa0_3 = _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3 = _mm256_and_si256(ws8_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256i ws0_4 = _mm256_srli_epi32(w0_0, 12); +__m256i ws8_4 = _mm256_srli_epi32(w0_8, 12); +__m256i wsa0_4 = _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4 = _mm256_and_si256(ws8_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +__m256i ws0_5 = _mm256_srli_epi32(w0_0, 15); +__m256i ws8_5 = _mm256_srli_epi32(w0_8, 15); +__m256i wsa0_5 = _mm256_and_si256(ws0_5, mask); +__m256i 
wsa8_5 = _mm256_and_si256(ws8_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256i ws0_6 = _mm256_srli_epi32(w0_0, 18); +__m256i ws8_6 = _mm256_srli_epi32(w0_8, 18); +__m256i wsa0_6 = _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6 = _mm256_and_si256(ws8_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +__m256i ws0_7 = _mm256_srli_epi32(w0_0, 21); +__m256i ws8_7 = _mm256_srli_epi32(w0_8, 21); +__m256i wsa0_7 = _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7 = _mm256_and_si256(ws8_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]); +__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]); +__m256i ws0_8 = _mm256_srli_epi32(w0_0, 24); +__m256i ws8_8 = _mm256_srli_epi32(w0_8, 24); +__m256i wsa0_8 = _mm256_and_si256(ws0_8, mask); +__m256i wsa8_8 = _mm256_and_si256(ws8_8, mask); +__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8); +__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8); +acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8); +__m256i ws0_9 = _mm256_srli_epi32(w0_0, 27); +__m256i ws8_9 = _mm256_srli_epi32(w0_8, 27); +__m256i wsa0_9 = _mm256_and_si256(ws0_9, mask); +__m256i wsa8_9 = _mm256_and_si256(ws8_9, mask); +__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9); +__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9); +acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8); +__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]); +__m256i ws0_10 = _mm256_srli_epi32(w0_0, 30); +__m256i temp0_0 = _mm256_slli_epi32(w1_0, 2); +temp0_0 = _mm256_and_si256(temp0_0, mask); +ws0_10 = _mm256_or_si256(ws0_10, temp0_0); +__m256i wsa0_10 = _mm256_and_si256(ws0_10, mask); +__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10); +acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0); +__m256i ws8_10 = _mm256_srli_epi32(w0_8, 30); +__m256i temp0_8 = _mm256_slli_epi32(w1_8, 2); +temp0_8 = _mm256_and_si256(temp0_8, mask); +ws8_10 = _mm256_or_si256(ws8_10, temp0_8); +__m256i wsa8_10 = _mm256_and_si256(ws8_10, mask); +__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10); +acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8); +__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]); +__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]); +__m256i ws0_11 = _mm256_srli_epi32(w1_0, 1); +__m256i ws8_11 = _mm256_srli_epi32(w1_8, 1); +__m256i wsa0_11 = _mm256_and_si256(ws0_11, mask); +__m256i wsa8_11 = _mm256_and_si256(ws8_11, mask); +__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11); +__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11); +acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8); +__m256i ws0_12 = _mm256_srli_epi32(w1_0, 4); +__m256i ws8_12 = _mm256_srli_epi32(w1_8, 4); +__m256i wsa0_12 = _mm256_and_si256(ws0_12, mask); +__m256i wsa8_12 = _mm256_and_si256(ws8_12, mask); +__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12); +__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12); +acc0_0 = 
_mm256_fmadd_ps(v0_12, l0_12, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8); +__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]); +__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]); +__m256i ws0_13 = _mm256_srli_epi32(w1_0, 7); +__m256i ws8_13 = _mm256_srli_epi32(w1_8, 7); +__m256i wsa0_13 = _mm256_and_si256(ws0_13, mask); +__m256i wsa8_13 = _mm256_and_si256(ws8_13, mask); +__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13); +__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13); +acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8); +__m256i ws0_14 = _mm256_srli_epi32(w1_0, 10); +__m256i ws8_14 = _mm256_srli_epi32(w1_8, 10); +__m256i wsa0_14 = _mm256_and_si256(ws0_14, mask); +__m256i wsa8_14 = _mm256_and_si256(ws8_14, mask); +__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14); +__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14); +acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8); +__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]); +__m256 v0_16 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+16)*nb + i1+0]); +__m256i ws0_15 = _mm256_srli_epi32(w1_0, 13); +__m256i ws8_15 = _mm256_srli_epi32(w1_8, 13); +__m256i wsa0_15 = _mm256_and_si256(ws0_15, mask); +__m256i wsa8_15 = _mm256_and_si256(ws8_15, mask); +__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15); +__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15); +acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8); +__m256i ws0_16 = _mm256_srli_epi32(w1_0, 16); +__m256i ws8_16 = _mm256_srli_epi32(w1_8, 16); +__m256i wsa0_16 = _mm256_and_si256(ws0_16, mask); +__m256i wsa8_16 = _mm256_and_si256(ws8_16, mask); +__m256 l0_16 = _mm256_cvtepi32_ps(wsa0_16); +__m256 l8_16 = _mm256_cvtepi32_ps(wsa8_16); +acc0_0 = _mm256_fmadd_ps(v0_16, l0_16, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_16, l8_16, acc0_8); +__m256 v0_17 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+17)*nb + i1+0]); +__m256 v0_18 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+18)*nb + i1+0]); +__m256i ws0_17 = _mm256_srli_epi32(w1_0, 19); +__m256i ws8_17 = _mm256_srli_epi32(w1_8, 19); +__m256i wsa0_17 = _mm256_and_si256(ws0_17, mask); +__m256i wsa8_17 = _mm256_and_si256(ws8_17, mask); +__m256 l0_17 = _mm256_cvtepi32_ps(wsa0_17); +__m256 l8_17 = _mm256_cvtepi32_ps(wsa8_17); +acc0_0 = _mm256_fmadd_ps(v0_17, l0_17, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_17, l8_17, acc0_8); +__m256i ws0_18 = _mm256_srli_epi32(w1_0, 22); +__m256i ws8_18 = _mm256_srli_epi32(w1_8, 22); +__m256i wsa0_18 = _mm256_and_si256(ws0_18, mask); +__m256i wsa8_18 = _mm256_and_si256(ws8_18, mask); +__m256 l0_18 = _mm256_cvtepi32_ps(wsa0_18); +__m256 l8_18 = _mm256_cvtepi32_ps(wsa8_18); +acc0_0 = _mm256_fmadd_ps(v0_18, l0_18, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_18, l8_18, acc0_8); +__m256 v0_19 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+19)*nb + i1+0]); +__m256 v0_20 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+20)*nb + i1+0]); +__m256i ws0_19 = _mm256_srli_epi32(w1_0, 25); +__m256i ws8_19 = _mm256_srli_epi32(w1_8, 25); +__m256i wsa0_19 = _mm256_and_si256(ws0_19, mask); +__m256i wsa8_19 = _mm256_and_si256(ws8_19, mask); +__m256 l0_19 = _mm256_cvtepi32_ps(wsa0_19); +__m256 l8_19 = _mm256_cvtepi32_ps(wsa8_19); +acc0_0 = _mm256_fmadd_ps(v0_19, l0_19, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_19, l8_19, acc0_8); +__m256i ws0_20 = _mm256_srli_epi32(w1_0, 28); +__m256i ws8_20 = _mm256_srli_epi32(w1_8, 28); +__m256i wsa0_20 = _mm256_and_si256(ws0_20, mask); +__m256i wsa8_20 = 
_mm256_and_si256(ws8_20, mask); +__m256 l0_20 = _mm256_cvtepi32_ps(wsa0_20); +__m256 l8_20 = _mm256_cvtepi32_ps(wsa8_20); +acc0_0 = _mm256_fmadd_ps(v0_20, l0_20, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_20, l8_20, acc0_8); +__m256 v0_21 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+21)*nb + i1+0]); +__m256i ws0_21 = _mm256_srli_epi32(w1_0, 31); +__m256i temp1_0 = _mm256_slli_epi32(w2_0, 1); +temp1_0 = _mm256_and_si256(temp1_0, mask); +ws0_21 = _mm256_or_si256(ws0_21, temp1_0); +__m256i wsa0_21 = _mm256_and_si256(ws0_21, mask); +__m256 l0_21 = _mm256_cvtepi32_ps(wsa0_21); +acc0_0 = _mm256_fmadd_ps(v0_21, l0_21, acc0_0); +__m256i ws8_21 = _mm256_srli_epi32(w1_8, 31); +__m256i temp1_8 = _mm256_slli_epi32(w2_8, 1); +temp1_8 = _mm256_and_si256(temp1_8, mask); +ws8_21 = _mm256_or_si256(ws8_21, temp1_8); +__m256i wsa8_21 = _mm256_and_si256(ws8_21, mask); +__m256 l8_21 = _mm256_cvtepi32_ps(wsa8_21); +acc0_8 = _mm256_fmadd_ps(v0_21, l8_21, acc0_8); +__m256 v0_22 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+22)*nb + i1+0]); +__m256 v0_23 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+23)*nb + i1+0]); +__m256i ws0_22 = _mm256_srli_epi32(w2_0, 2); +__m256i ws8_22 = _mm256_srli_epi32(w2_8, 2); +__m256i wsa0_22 = _mm256_and_si256(ws0_22, mask); +__m256i wsa8_22 = _mm256_and_si256(ws8_22, mask); +__m256 l0_22 = _mm256_cvtepi32_ps(wsa0_22); +__m256 l8_22 = _mm256_cvtepi32_ps(wsa8_22); +acc0_0 = _mm256_fmadd_ps(v0_22, l0_22, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_22, l8_22, acc0_8); +__m256i ws0_23 = _mm256_srli_epi32(w2_0, 5); +__m256i ws8_23 = _mm256_srli_epi32(w2_8, 5); +__m256i wsa0_23 = _mm256_and_si256(ws0_23, mask); +__m256i wsa8_23 = _mm256_and_si256(ws8_23, mask); +__m256 l0_23 = _mm256_cvtepi32_ps(wsa0_23); +__m256 l8_23 = _mm256_cvtepi32_ps(wsa8_23); +acc0_0 = _mm256_fmadd_ps(v0_23, l0_23, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_23, l8_23, acc0_8); +__m256 v0_24 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+24)*nb + i1+0]); +__m256 v0_25 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+25)*nb + i1+0]); +__m256i ws0_24 = _mm256_srli_epi32(w2_0, 8); +__m256i ws8_24 = _mm256_srli_epi32(w2_8, 8); +__m256i wsa0_24 = _mm256_and_si256(ws0_24, mask); +__m256i wsa8_24 = _mm256_and_si256(ws8_24, mask); +__m256 l0_24 = _mm256_cvtepi32_ps(wsa0_24); +__m256 l8_24 = _mm256_cvtepi32_ps(wsa8_24); +acc0_0 = _mm256_fmadd_ps(v0_24, l0_24, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_24, l8_24, acc0_8); +__m256i ws0_25 = _mm256_srli_epi32(w2_0, 11); +__m256i ws8_25 = _mm256_srli_epi32(w2_8, 11); +__m256i wsa0_25 = _mm256_and_si256(ws0_25, mask); +__m256i wsa8_25 = _mm256_and_si256(ws8_25, mask); +__m256 l0_25 = _mm256_cvtepi32_ps(wsa0_25); +__m256 l8_25 = _mm256_cvtepi32_ps(wsa8_25); +acc0_0 = _mm256_fmadd_ps(v0_25, l0_25, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_25, l8_25, acc0_8); +__m256 v0_26 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+26)*nb + i1+0]); +__m256 v0_27 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+27)*nb + i1+0]); +__m256i ws0_26 = _mm256_srli_epi32(w2_0, 14); +__m256i ws8_26 = _mm256_srli_epi32(w2_8, 14); +__m256i wsa0_26 = _mm256_and_si256(ws0_26, mask); +__m256i wsa8_26 = _mm256_and_si256(ws8_26, mask); +__m256 l0_26 = _mm256_cvtepi32_ps(wsa0_26); +__m256 l8_26 = _mm256_cvtepi32_ps(wsa8_26); +acc0_0 = _mm256_fmadd_ps(v0_26, l0_26, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_26, l8_26, acc0_8); +__m256i ws0_27 = _mm256_srli_epi32(w2_0, 17); +__m256i ws8_27 = _mm256_srli_epi32(w2_8, 17); +__m256i wsa0_27 = _mm256_and_si256(ws0_27, mask); +__m256i wsa8_27 = _mm256_and_si256(ws8_27, mask); +__m256 l0_27 = _mm256_cvtepi32_ps(wsa0_27); 
+__m256 l8_27 = _mm256_cvtepi32_ps(wsa8_27); +acc0_0 = _mm256_fmadd_ps(v0_27, l0_27, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_27, l8_27, acc0_8); +__m256 v0_28 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+28)*nb + i1+0]); +__m256 v0_29 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+29)*nb + i1+0]); +__m256i ws0_28 = _mm256_srli_epi32(w2_0, 20); +__m256i ws8_28 = _mm256_srli_epi32(w2_8, 20); +__m256i wsa0_28 = _mm256_and_si256(ws0_28, mask); +__m256i wsa8_28 = _mm256_and_si256(ws8_28, mask); +__m256 l0_28 = _mm256_cvtepi32_ps(wsa0_28); +__m256 l8_28 = _mm256_cvtepi32_ps(wsa8_28); +acc0_0 = _mm256_fmadd_ps(v0_28, l0_28, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_28, l8_28, acc0_8); +__m256i ws0_29 = _mm256_srli_epi32(w2_0, 23); +__m256i ws8_29 = _mm256_srli_epi32(w2_8, 23); +__m256i wsa0_29 = _mm256_and_si256(ws0_29, mask); +__m256i wsa8_29 = _mm256_and_si256(ws8_29, mask); +__m256 l0_29 = _mm256_cvtepi32_ps(wsa0_29); +__m256 l8_29 = _mm256_cvtepi32_ps(wsa8_29); +acc0_0 = _mm256_fmadd_ps(v0_29, l0_29, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_29, l8_29, acc0_8); +__m256 v0_30 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+30)*nb + i1+0]); +__m256 v0_31 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+31)*nb + i1+0]); +__m256i ws0_30 = _mm256_srli_epi32(w2_0, 26); +__m256i ws8_30 = _mm256_srli_epi32(w2_8, 26); +__m256i wsa0_30 = _mm256_and_si256(ws0_30, mask); +__m256i wsa8_30 = _mm256_and_si256(ws8_30, mask); +__m256 l0_30 = _mm256_cvtepi32_ps(wsa0_30); +__m256 l8_30 = _mm256_cvtepi32_ps(wsa8_30); +acc0_0 = _mm256_fmadd_ps(v0_30, l0_30, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_30, l8_30, acc0_8); +__m256i ws0_31 = _mm256_srli_epi32(w2_0, 29); +__m256i ws8_31 = _mm256_srli_epi32(w2_8, 29); +__m256i wsa0_31 = _mm256_and_si256(ws0_31, mask); +__m256i wsa8_31 = _mm256_and_si256(ws8_31, mask); +__m256 l0_31 = _mm256_cvtepi32_ps(wsa0_31); +__m256 l8_31 = _mm256_cvtepi32_ps(wsa8_31); +acc0_0 = _mm256_fmadd_ps(v0_31, l0_31, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_31, l8_31, acc0_8); +} +} +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], acc0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], acc0_8); +} +} +} +} +} +#pragma omp barrier +for (int i = 0; i < n; i++) { +__m256 r = _mm256_set1_ps(sums[i]); +for (int j = 0; j < tt; j+=16){ +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + j + 8]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + j + 8]); +__m256 os0 = _mm256_mul_ps(o0, s0); +__m256 os8 = _mm256_mul_ps(o8, s8); +__m256 zr0 = _mm256_fnmadd_ps(z0, r, os0); +__m256 zr8 = _mm256_fnmadd_ps(z8, r, os8); +__m256 o20 = _mm256_add_ps(zr0, b0); +__m256 o28 = _mm256_add_ps(zr8, b8); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +} +} +} +} +inline void forward3_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){ +int* W = weight.data_ptr(); +float* input = in.data_ptr(); +float* b = bias.data_ptr(); +float* s = scales.data_ptr(); +float* z = zeros.data_ptr(); +float* r = sums.data_ptr(); +float* O = 
out.data_ptr(); + +q3gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff); +} +inline +void q3gemm_gs(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int gs, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 32; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(7); +const __m256i mask4 = _mm256_set1_epi32(4); +const __m256i mask6 = _mm256_set1_epi32(6); +tid = omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/32*3: + tid*tt*m/32*3; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +int jw = 0; +for(; j1 < tb-tu+1; j1+=tu, jw+=96){ +for(int k1 = 0; k1 < mb; k1+=gs) { +__m256 acc0_0 = _mm256_setzero_ps(); +__m256 acc0_8 = _mm256_setzero_ps(); +__m256 acc0_16 = _mm256_setzero_ps(); +__m256 acc0_24 = _mm256_setzero_ps(); +for(int k2 = k1; k2 < k1+gs; k2+=32) +{ +__m256i w0_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0]); +__m256i w1_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0+8]); +__m256i w2_0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+0+16]); +__m256i w0_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24]); +__m256i w1_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24+8]); +__m256i w2_8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+24+16]); +__m256i w0_16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+48]); +__m256i w1_16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+48+8]); +__m256i w2_16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+48+16]); +__m256i w0_24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+72]); +__m256i w1_24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+72+8]); +__m256i w2_24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/32*3 + k*mb*tb/32*3 + k2*tb/32*3 + jw+72+16]); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0_0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w0_8, 0); +__m256i ws16_0 = _mm256_srli_epi32(w0_16, 0); +__m256i ws24_0 = _mm256_srli_epi32(w0_24, 0); +__m256i wsa0_0 = _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0 = _mm256_and_si256(ws8_0, mask); +__m256i wsa16_0 = _mm256_and_si256(ws16_0, mask); +__m256i wsa24_0 = _mm256_and_si256(ws24_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0); +__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +acc0_16 = 
_mm256_fmadd_ps(v0_0, l16_0, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24); +__m256i ws0_1 = _mm256_srli_epi32(w0_0, 3); +__m256i ws8_1 = _mm256_srli_epi32(w0_8, 3); +__m256i ws16_1 = _mm256_srli_epi32(w0_16, 3); +__m256i ws24_1 = _mm256_srli_epi32(w0_24, 3); +__m256i wsa0_1 = _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1 = _mm256_and_si256(ws8_1, mask); +__m256i wsa16_1 = _mm256_and_si256(ws16_1, mask); +__m256i wsa24_1 = _mm256_and_si256(ws24_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1); +__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256i ws0_2 = _mm256_srli_epi32(w0_0, 6); +__m256i ws8_2 = _mm256_srli_epi32(w0_8, 6); +__m256i ws16_2 = _mm256_srli_epi32(w0_16, 6); +__m256i ws24_2 = _mm256_srli_epi32(w0_24, 6); +__m256i wsa0_2 = _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2 = _mm256_and_si256(ws8_2, mask); +__m256i wsa16_2 = _mm256_and_si256(ws16_2, mask); +__m256i wsa24_2 = _mm256_and_si256(ws24_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2); +__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24); +__m256i ws0_3 = _mm256_srli_epi32(w0_0, 9); +__m256i ws8_3 = _mm256_srli_epi32(w0_8, 9); +__m256i ws16_3 = _mm256_srli_epi32(w0_16, 9); +__m256i ws24_3 = _mm256_srli_epi32(w0_24, 9); +__m256i wsa0_3 = _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3 = _mm256_and_si256(ws8_3, mask); +__m256i wsa16_3 = _mm256_and_si256(ws16_3, mask); +__m256i wsa24_3 = _mm256_and_si256(ws24_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3); +__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256i ws0_4 = _mm256_srli_epi32(w0_0, 12); +__m256i ws8_4 = _mm256_srli_epi32(w0_8, 12); +__m256i ws16_4 = _mm256_srli_epi32(w0_16, 12); +__m256i ws24_4 = _mm256_srli_epi32(w0_24, 12); +__m256i wsa0_4 = _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4 = _mm256_and_si256(ws8_4, mask); +__m256i wsa16_4 = _mm256_and_si256(ws16_4, mask); +__m256i wsa24_4 = _mm256_and_si256(ws24_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4); +__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24); +__m256i ws0_5 = _mm256_srli_epi32(w0_0, 15); +__m256i ws8_5 = _mm256_srli_epi32(w0_8, 15); +__m256i ws16_5 = _mm256_srli_epi32(w0_16, 15); +__m256i 
ws24_5 = _mm256_srli_epi32(w0_24, 15); +__m256i wsa0_5 = _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5 = _mm256_and_si256(ws8_5, mask); +__m256i wsa16_5 = _mm256_and_si256(ws16_5, mask); +__m256i wsa24_5 = _mm256_and_si256(ws24_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5); +__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256i ws0_6 = _mm256_srli_epi32(w0_0, 18); +__m256i ws8_6 = _mm256_srli_epi32(w0_8, 18); +__m256i ws16_6 = _mm256_srli_epi32(w0_16, 18); +__m256i ws24_6 = _mm256_srli_epi32(w0_24, 18); +__m256i wsa0_6 = _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6 = _mm256_and_si256(ws8_6, mask); +__m256i wsa16_6 = _mm256_and_si256(ws16_6, mask); +__m256i wsa24_6 = _mm256_and_si256(ws24_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6); +__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24); +__m256i ws0_7 = _mm256_srli_epi32(w0_0, 21); +__m256i ws8_7 = _mm256_srli_epi32(w0_8, 21); +__m256i ws16_7 = _mm256_srli_epi32(w0_16, 21); +__m256i ws24_7 = _mm256_srli_epi32(w0_24, 21); +__m256i wsa0_7 = _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7 = _mm256_and_si256(ws8_7, mask); +__m256i wsa16_7 = _mm256_and_si256(ws16_7, mask); +__m256i wsa24_7 = _mm256_and_si256(ws24_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7); +__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24); +__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]); +__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]); +__m256i ws0_8 = _mm256_srli_epi32(w0_0, 24); +__m256i ws8_8 = _mm256_srli_epi32(w0_8, 24); +__m256i ws16_8 = _mm256_srli_epi32(w0_16, 24); +__m256i ws24_8 = _mm256_srli_epi32(w0_24, 24); +__m256i wsa0_8 = _mm256_and_si256(ws0_8, mask); +__m256i wsa8_8 = _mm256_and_si256(ws8_8, mask); +__m256i wsa16_8 = _mm256_and_si256(ws16_8, mask); +__m256i wsa24_8 = _mm256_and_si256(ws24_8, mask); +__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8); +__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8); +__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8); +__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8); +acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24); +__m256i ws0_9 = _mm256_srli_epi32(w0_0, 27); +__m256i ws8_9 = _mm256_srli_epi32(w0_8, 27); +__m256i ws16_9 = _mm256_srli_epi32(w0_16, 27); +__m256i ws24_9 = _mm256_srli_epi32(w0_24, 27); +__m256i wsa0_9 = _mm256_and_si256(ws0_9, mask); +__m256i wsa8_9 = _mm256_and_si256(ws8_9, mask); +__m256i wsa16_9 = _mm256_and_si256(ws16_9, mask); +__m256i wsa24_9 = 
_mm256_and_si256(ws24_9, mask); +__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9); +__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9); +__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9); +__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9); +acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24); +__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]); +__m256i ws0_10 = _mm256_srli_epi32(w0_0, 30); +__m256i temp0_0 = _mm256_slli_epi32(w1_0, 2); +temp0_0 = _mm256_and_si256(temp0_0, mask); +ws0_10 = _mm256_or_si256(ws0_10, temp0_0); +__m256i wsa0_10 = _mm256_and_si256(ws0_10, mask); +__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10); +acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0); +__m256i ws8_10 = _mm256_srli_epi32(w0_8, 30); +__m256i temp0_8 = _mm256_slli_epi32(w1_8, 2); +temp0_8 = _mm256_and_si256(temp0_8, mask); +ws8_10 = _mm256_or_si256(ws8_10, temp0_8); +__m256i wsa8_10 = _mm256_and_si256(ws8_10, mask); +__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10); +acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8); +__m256i ws16_10 = _mm256_srli_epi32(w0_16, 30); +__m256i temp0_16 = _mm256_slli_epi32(w1_16, 2); +temp0_16 = _mm256_and_si256(temp0_16, mask); +ws16_10 = _mm256_or_si256(ws16_10, temp0_16); +__m256i wsa16_10 = _mm256_and_si256(ws16_10, mask); +__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10); +acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16); +__m256i ws24_10 = _mm256_srli_epi32(w0_24, 30); +__m256i temp0_24 = _mm256_slli_epi32(w1_24, 2); +temp0_24 = _mm256_and_si256(temp0_24, mask); +ws24_10 = _mm256_or_si256(ws24_10, temp0_24); +__m256i wsa24_10 = _mm256_and_si256(ws24_10, mask); +__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10); +acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24); +__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]); +__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]); +__m256i ws0_11 = _mm256_srli_epi32(w1_0, 1); +__m256i ws8_11 = _mm256_srli_epi32(w1_8, 1); +__m256i ws16_11 = _mm256_srli_epi32(w1_16, 1); +__m256i ws24_11 = _mm256_srli_epi32(w1_24, 1); +__m256i wsa0_11 = _mm256_and_si256(ws0_11, mask); +__m256i wsa8_11 = _mm256_and_si256(ws8_11, mask); +__m256i wsa16_11 = _mm256_and_si256(ws16_11, mask); +__m256i wsa24_11 = _mm256_and_si256(ws24_11, mask); +__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11); +__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11); +__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11); +__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11); +acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24); +__m256i ws0_12 = _mm256_srli_epi32(w1_0, 4); +__m256i ws8_12 = _mm256_srli_epi32(w1_8, 4); +__m256i ws16_12 = _mm256_srli_epi32(w1_16, 4); +__m256i ws24_12 = _mm256_srli_epi32(w1_24, 4); +__m256i wsa0_12 = _mm256_and_si256(ws0_12, mask); +__m256i wsa8_12 = _mm256_and_si256(ws8_12, mask); +__m256i wsa16_12 = _mm256_and_si256(ws16_12, mask); +__m256i wsa24_12 = _mm256_and_si256(ws24_12, mask); +__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12); +__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12); +__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12); +__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12); +acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24); +__m256 v0_13 = 
_mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]); +__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]); +__m256i ws0_13 = _mm256_srli_epi32(w1_0, 7); +__m256i ws8_13 = _mm256_srli_epi32(w1_8, 7); +__m256i ws16_13 = _mm256_srli_epi32(w1_16, 7); +__m256i ws24_13 = _mm256_srli_epi32(w1_24, 7); +__m256i wsa0_13 = _mm256_and_si256(ws0_13, mask); +__m256i wsa8_13 = _mm256_and_si256(ws8_13, mask); +__m256i wsa16_13 = _mm256_and_si256(ws16_13, mask); +__m256i wsa24_13 = _mm256_and_si256(ws24_13, mask); +__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13); +__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13); +__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13); +__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13); +acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24); +__m256i ws0_14 = _mm256_srli_epi32(w1_0, 10); +__m256i ws8_14 = _mm256_srli_epi32(w1_8, 10); +__m256i ws16_14 = _mm256_srli_epi32(w1_16, 10); +__m256i ws24_14 = _mm256_srli_epi32(w1_24, 10); +__m256i wsa0_14 = _mm256_and_si256(ws0_14, mask); +__m256i wsa8_14 = _mm256_and_si256(ws8_14, mask); +__m256i wsa16_14 = _mm256_and_si256(ws16_14, mask); +__m256i wsa24_14 = _mm256_and_si256(ws24_14, mask); +__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14); +__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14); +__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14); +__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14); +acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24); +__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]); +__m256 v0_16 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+16)*nb + i1+0]); +__m256i ws0_15 = _mm256_srli_epi32(w1_0, 13); +__m256i ws8_15 = _mm256_srli_epi32(w1_8, 13); +__m256i ws16_15 = _mm256_srli_epi32(w1_16, 13); +__m256i ws24_15 = _mm256_srli_epi32(w1_24, 13); +__m256i wsa0_15 = _mm256_and_si256(ws0_15, mask); +__m256i wsa8_15 = _mm256_and_si256(ws8_15, mask); +__m256i wsa16_15 = _mm256_and_si256(ws16_15, mask); +__m256i wsa24_15 = _mm256_and_si256(ws24_15, mask); +__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15); +__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15); +__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15); +__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15); +acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24); +__m256i ws0_16 = _mm256_srli_epi32(w1_0, 16); +__m256i ws8_16 = _mm256_srli_epi32(w1_8, 16); +__m256i ws16_16 = _mm256_srli_epi32(w1_16, 16); +__m256i ws24_16 = _mm256_srli_epi32(w1_24, 16); +__m256i wsa0_16 = _mm256_and_si256(ws0_16, mask); +__m256i wsa8_16 = _mm256_and_si256(ws8_16, mask); +__m256i wsa16_16 = _mm256_and_si256(ws16_16, mask); +__m256i wsa24_16 = _mm256_and_si256(ws24_16, mask); +__m256 l0_16 = _mm256_cvtepi32_ps(wsa0_16); +__m256 l8_16 = _mm256_cvtepi32_ps(wsa8_16); +__m256 l16_16 = _mm256_cvtepi32_ps(wsa16_16); +__m256 l24_16 = _mm256_cvtepi32_ps(wsa24_16); +acc0_0 = _mm256_fmadd_ps(v0_16, l0_16, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_16, l8_16, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_16, l16_16, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_16, l24_16, acc0_24); +__m256 v0_17 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+17)*nb + i1+0]); +__m256 v0_18 = _mm256_set1_ps(input[(i*om+k)*mb*nb + 
(k2+18)*nb + i1+0]); +__m256i ws0_17 = _mm256_srli_epi32(w1_0, 19); +__m256i ws8_17 = _mm256_srli_epi32(w1_8, 19); +__m256i ws16_17 = _mm256_srli_epi32(w1_16, 19); +__m256i ws24_17 = _mm256_srli_epi32(w1_24, 19); +__m256i wsa0_17 = _mm256_and_si256(ws0_17, mask); +__m256i wsa8_17 = _mm256_and_si256(ws8_17, mask); +__m256i wsa16_17 = _mm256_and_si256(ws16_17, mask); +__m256i wsa24_17 = _mm256_and_si256(ws24_17, mask); +__m256 l0_17 = _mm256_cvtepi32_ps(wsa0_17); +__m256 l8_17 = _mm256_cvtepi32_ps(wsa8_17); +__m256 l16_17 = _mm256_cvtepi32_ps(wsa16_17); +__m256 l24_17 = _mm256_cvtepi32_ps(wsa24_17); +acc0_0 = _mm256_fmadd_ps(v0_17, l0_17, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_17, l8_17, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_17, l16_17, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_17, l24_17, acc0_24); +__m256i ws0_18 = _mm256_srli_epi32(w1_0, 22); +__m256i ws8_18 = _mm256_srli_epi32(w1_8, 22); +__m256i ws16_18 = _mm256_srli_epi32(w1_16, 22); +__m256i ws24_18 = _mm256_srli_epi32(w1_24, 22); +__m256i wsa0_18 = _mm256_and_si256(ws0_18, mask); +__m256i wsa8_18 = _mm256_and_si256(ws8_18, mask); +__m256i wsa16_18 = _mm256_and_si256(ws16_18, mask); +__m256i wsa24_18 = _mm256_and_si256(ws24_18, mask); +__m256 l0_18 = _mm256_cvtepi32_ps(wsa0_18); +__m256 l8_18 = _mm256_cvtepi32_ps(wsa8_18); +__m256 l16_18 = _mm256_cvtepi32_ps(wsa16_18); +__m256 l24_18 = _mm256_cvtepi32_ps(wsa24_18); +acc0_0 = _mm256_fmadd_ps(v0_18, l0_18, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_18, l8_18, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_18, l16_18, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_18, l24_18, acc0_24); +__m256 v0_19 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+19)*nb + i1+0]); +__m256 v0_20 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+20)*nb + i1+0]); +__m256i ws0_19 = _mm256_srli_epi32(w1_0, 25); +__m256i ws8_19 = _mm256_srli_epi32(w1_8, 25); +__m256i ws16_19 = _mm256_srli_epi32(w1_16, 25); +__m256i ws24_19 = _mm256_srli_epi32(w1_24, 25); +__m256i wsa0_19 = _mm256_and_si256(ws0_19, mask); +__m256i wsa8_19 = _mm256_and_si256(ws8_19, mask); +__m256i wsa16_19 = _mm256_and_si256(ws16_19, mask); +__m256i wsa24_19 = _mm256_and_si256(ws24_19, mask); +__m256 l0_19 = _mm256_cvtepi32_ps(wsa0_19); +__m256 l8_19 = _mm256_cvtepi32_ps(wsa8_19); +__m256 l16_19 = _mm256_cvtepi32_ps(wsa16_19); +__m256 l24_19 = _mm256_cvtepi32_ps(wsa24_19); +acc0_0 = _mm256_fmadd_ps(v0_19, l0_19, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_19, l8_19, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_19, l16_19, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_19, l24_19, acc0_24); +__m256i ws0_20 = _mm256_srli_epi32(w1_0, 28); +__m256i ws8_20 = _mm256_srli_epi32(w1_8, 28); +__m256i ws16_20 = _mm256_srli_epi32(w1_16, 28); +__m256i ws24_20 = _mm256_srli_epi32(w1_24, 28); +__m256i wsa0_20 = _mm256_and_si256(ws0_20, mask); +__m256i wsa8_20 = _mm256_and_si256(ws8_20, mask); +__m256i wsa16_20 = _mm256_and_si256(ws16_20, mask); +__m256i wsa24_20 = _mm256_and_si256(ws24_20, mask); +__m256 l0_20 = _mm256_cvtepi32_ps(wsa0_20); +__m256 l8_20 = _mm256_cvtepi32_ps(wsa8_20); +__m256 l16_20 = _mm256_cvtepi32_ps(wsa16_20); +__m256 l24_20 = _mm256_cvtepi32_ps(wsa24_20); +acc0_0 = _mm256_fmadd_ps(v0_20, l0_20, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_20, l8_20, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_20, l16_20, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_20, l24_20, acc0_24); +__m256 v0_21 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+21)*nb + i1+0]); +__m256i ws0_21 = _mm256_srli_epi32(w1_0, 31); +__m256i temp1_0 = _mm256_slli_epi32(w2_0, 1); +temp1_0 = _mm256_and_si256(temp1_0, mask); +ws0_21 = 
_mm256_or_si256(ws0_21, temp1_0); +__m256i wsa0_21 = _mm256_and_si256(ws0_21, mask); +__m256 l0_21 = _mm256_cvtepi32_ps(wsa0_21); +acc0_0 = _mm256_fmadd_ps(v0_21, l0_21, acc0_0); +__m256i ws8_21 = _mm256_srli_epi32(w1_8, 31); +__m256i temp1_8 = _mm256_slli_epi32(w2_8, 1); +temp1_8 = _mm256_and_si256(temp1_8, mask); +ws8_21 = _mm256_or_si256(ws8_21, temp1_8); +__m256i wsa8_21 = _mm256_and_si256(ws8_21, mask); +__m256 l8_21 = _mm256_cvtepi32_ps(wsa8_21); +acc0_8 = _mm256_fmadd_ps(v0_21, l8_21, acc0_8); +__m256i ws16_21 = _mm256_srli_epi32(w1_16, 31); +__m256i temp1_16 = _mm256_slli_epi32(w2_16, 1); +temp1_16 = _mm256_and_si256(temp1_16, mask); +ws16_21 = _mm256_or_si256(ws16_21, temp1_16); +__m256i wsa16_21 = _mm256_and_si256(ws16_21, mask); +__m256 l16_21 = _mm256_cvtepi32_ps(wsa16_21); +acc0_16 = _mm256_fmadd_ps(v0_21, l16_21, acc0_16); +__m256i ws24_21 = _mm256_srli_epi32(w1_24, 31); +__m256i temp1_24 = _mm256_slli_epi32(w2_24, 1); +temp1_24 = _mm256_and_si256(temp1_24, mask); +ws24_21 = _mm256_or_si256(ws24_21, temp1_24); +__m256i wsa24_21 = _mm256_and_si256(ws24_21, mask); +__m256 l24_21 = _mm256_cvtepi32_ps(wsa24_21); +acc0_24 = _mm256_fmadd_ps(v0_21, l24_21, acc0_24); +__m256 v0_22 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+22)*nb + i1+0]); +__m256 v0_23 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+23)*nb + i1+0]); +__m256i ws0_22 = _mm256_srli_epi32(w2_0, 2); +__m256i ws8_22 = _mm256_srli_epi32(w2_8, 2); +__m256i ws16_22 = _mm256_srli_epi32(w2_16, 2); +__m256i ws24_22 = _mm256_srli_epi32(w2_24, 2); +__m256i wsa0_22 = _mm256_and_si256(ws0_22, mask); +__m256i wsa8_22 = _mm256_and_si256(ws8_22, mask); +__m256i wsa16_22 = _mm256_and_si256(ws16_22, mask); +__m256i wsa24_22 = _mm256_and_si256(ws24_22, mask); +__m256 l0_22 = _mm256_cvtepi32_ps(wsa0_22); +__m256 l8_22 = _mm256_cvtepi32_ps(wsa8_22); +__m256 l16_22 = _mm256_cvtepi32_ps(wsa16_22); +__m256 l24_22 = _mm256_cvtepi32_ps(wsa24_22); +acc0_0 = _mm256_fmadd_ps(v0_22, l0_22, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_22, l8_22, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_22, l16_22, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_22, l24_22, acc0_24); +__m256i ws0_23 = _mm256_srli_epi32(w2_0, 5); +__m256i ws8_23 = _mm256_srli_epi32(w2_8, 5); +__m256i ws16_23 = _mm256_srli_epi32(w2_16, 5); +__m256i ws24_23 = _mm256_srli_epi32(w2_24, 5); +__m256i wsa0_23 = _mm256_and_si256(ws0_23, mask); +__m256i wsa8_23 = _mm256_and_si256(ws8_23, mask); +__m256i wsa16_23 = _mm256_and_si256(ws16_23, mask); +__m256i wsa24_23 = _mm256_and_si256(ws24_23, mask); +__m256 l0_23 = _mm256_cvtepi32_ps(wsa0_23); +__m256 l8_23 = _mm256_cvtepi32_ps(wsa8_23); +__m256 l16_23 = _mm256_cvtepi32_ps(wsa16_23); +__m256 l24_23 = _mm256_cvtepi32_ps(wsa24_23); +acc0_0 = _mm256_fmadd_ps(v0_23, l0_23, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_23, l8_23, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_23, l16_23, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_23, l24_23, acc0_24); +__m256 v0_24 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+24)*nb + i1+0]); +__m256 v0_25 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+25)*nb + i1+0]); +__m256i ws0_24 = _mm256_srli_epi32(w2_0, 8); +__m256i ws8_24 = _mm256_srli_epi32(w2_8, 8); +__m256i ws16_24 = _mm256_srli_epi32(w2_16, 8); +__m256i ws24_24 = _mm256_srli_epi32(w2_24, 8); +__m256i wsa0_24 = _mm256_and_si256(ws0_24, mask); +__m256i wsa8_24 = _mm256_and_si256(ws8_24, mask); +__m256i wsa16_24 = _mm256_and_si256(ws16_24, mask); +__m256i wsa24_24 = _mm256_and_si256(ws24_24, mask); +__m256 l0_24 = _mm256_cvtepi32_ps(wsa0_24); +__m256 l8_24 = _mm256_cvtepi32_ps(wsa8_24); +__m256 
l16_24 = _mm256_cvtepi32_ps(wsa16_24); +__m256 l24_24 = _mm256_cvtepi32_ps(wsa24_24); +acc0_0 = _mm256_fmadd_ps(v0_24, l0_24, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_24, l8_24, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_24, l16_24, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_24, l24_24, acc0_24); +__m256i ws0_25 = _mm256_srli_epi32(w2_0, 11); +__m256i ws8_25 = _mm256_srli_epi32(w2_8, 11); +__m256i ws16_25 = _mm256_srli_epi32(w2_16, 11); +__m256i ws24_25 = _mm256_srli_epi32(w2_24, 11); +__m256i wsa0_25 = _mm256_and_si256(ws0_25, mask); +__m256i wsa8_25 = _mm256_and_si256(ws8_25, mask); +__m256i wsa16_25 = _mm256_and_si256(ws16_25, mask); +__m256i wsa24_25 = _mm256_and_si256(ws24_25, mask); +__m256 l0_25 = _mm256_cvtepi32_ps(wsa0_25); +__m256 l8_25 = _mm256_cvtepi32_ps(wsa8_25); +__m256 l16_25 = _mm256_cvtepi32_ps(wsa16_25); +__m256 l24_25 = _mm256_cvtepi32_ps(wsa24_25); +acc0_0 = _mm256_fmadd_ps(v0_25, l0_25, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_25, l8_25, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_25, l16_25, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_25, l24_25, acc0_24); +__m256 v0_26 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+26)*nb + i1+0]); +__m256 v0_27 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+27)*nb + i1+0]); +__m256i ws0_26 = _mm256_srli_epi32(w2_0, 14); +__m256i ws8_26 = _mm256_srli_epi32(w2_8, 14); +__m256i ws16_26 = _mm256_srli_epi32(w2_16, 14); +__m256i ws24_26 = _mm256_srli_epi32(w2_24, 14); +__m256i wsa0_26 = _mm256_and_si256(ws0_26, mask); +__m256i wsa8_26 = _mm256_and_si256(ws8_26, mask); +__m256i wsa16_26 = _mm256_and_si256(ws16_26, mask); +__m256i wsa24_26 = _mm256_and_si256(ws24_26, mask); +__m256 l0_26 = _mm256_cvtepi32_ps(wsa0_26); +__m256 l8_26 = _mm256_cvtepi32_ps(wsa8_26); +__m256 l16_26 = _mm256_cvtepi32_ps(wsa16_26); +__m256 l24_26 = _mm256_cvtepi32_ps(wsa24_26); +acc0_0 = _mm256_fmadd_ps(v0_26, l0_26, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_26, l8_26, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_26, l16_26, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_26, l24_26, acc0_24); +__m256i ws0_27 = _mm256_srli_epi32(w2_0, 17); +__m256i ws8_27 = _mm256_srli_epi32(w2_8, 17); +__m256i ws16_27 = _mm256_srli_epi32(w2_16, 17); +__m256i ws24_27 = _mm256_srli_epi32(w2_24, 17); +__m256i wsa0_27 = _mm256_and_si256(ws0_27, mask); +__m256i wsa8_27 = _mm256_and_si256(ws8_27, mask); +__m256i wsa16_27 = _mm256_and_si256(ws16_27, mask); +__m256i wsa24_27 = _mm256_and_si256(ws24_27, mask); +__m256 l0_27 = _mm256_cvtepi32_ps(wsa0_27); +__m256 l8_27 = _mm256_cvtepi32_ps(wsa8_27); +__m256 l16_27 = _mm256_cvtepi32_ps(wsa16_27); +__m256 l24_27 = _mm256_cvtepi32_ps(wsa24_27); +acc0_0 = _mm256_fmadd_ps(v0_27, l0_27, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_27, l8_27, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_27, l16_27, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_27, l24_27, acc0_24); +__m256 v0_28 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+28)*nb + i1+0]); +__m256 v0_29 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+29)*nb + i1+0]); +__m256i ws0_28 = _mm256_srli_epi32(w2_0, 20); +__m256i ws8_28 = _mm256_srli_epi32(w2_8, 20); +__m256i ws16_28 = _mm256_srli_epi32(w2_16, 20); +__m256i ws24_28 = _mm256_srli_epi32(w2_24, 20); +__m256i wsa0_28 = _mm256_and_si256(ws0_28, mask); +__m256i wsa8_28 = _mm256_and_si256(ws8_28, mask); +__m256i wsa16_28 = _mm256_and_si256(ws16_28, mask); +__m256i wsa24_28 = _mm256_and_si256(ws24_28, mask); +__m256 l0_28 = _mm256_cvtepi32_ps(wsa0_28); +__m256 l8_28 = _mm256_cvtepi32_ps(wsa8_28); +__m256 l16_28 = _mm256_cvtepi32_ps(wsa16_28); +__m256 l24_28 = _mm256_cvtepi32_ps(wsa24_28); +acc0_0 = 
_mm256_fmadd_ps(v0_28, l0_28, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_28, l8_28, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_28, l16_28, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_28, l24_28, acc0_24); +__m256i ws0_29 = _mm256_srli_epi32(w2_0, 23); +__m256i ws8_29 = _mm256_srli_epi32(w2_8, 23); +__m256i ws16_29 = _mm256_srli_epi32(w2_16, 23); +__m256i ws24_29 = _mm256_srli_epi32(w2_24, 23); +__m256i wsa0_29 = _mm256_and_si256(ws0_29, mask); +__m256i wsa8_29 = _mm256_and_si256(ws8_29, mask); +__m256i wsa16_29 = _mm256_and_si256(ws16_29, mask); +__m256i wsa24_29 = _mm256_and_si256(ws24_29, mask); +__m256 l0_29 = _mm256_cvtepi32_ps(wsa0_29); +__m256 l8_29 = _mm256_cvtepi32_ps(wsa8_29); +__m256 l16_29 = _mm256_cvtepi32_ps(wsa16_29); +__m256 l24_29 = _mm256_cvtepi32_ps(wsa24_29); +acc0_0 = _mm256_fmadd_ps(v0_29, l0_29, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_29, l8_29, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_29, l16_29, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_29, l24_29, acc0_24); +__m256 v0_30 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+30)*nb + i1+0]); +__m256 v0_31 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+31)*nb + i1+0]); +__m256i ws0_30 = _mm256_srli_epi32(w2_0, 26); +__m256i ws8_30 = _mm256_srli_epi32(w2_8, 26); +__m256i ws16_30 = _mm256_srli_epi32(w2_16, 26); +__m256i ws24_30 = _mm256_srli_epi32(w2_24, 26); +__m256i wsa0_30 = _mm256_and_si256(ws0_30, mask); +__m256i wsa8_30 = _mm256_and_si256(ws8_30, mask); +__m256i wsa16_30 = _mm256_and_si256(ws16_30, mask); +__m256i wsa24_30 = _mm256_and_si256(ws24_30, mask); +__m256 l0_30 = _mm256_cvtepi32_ps(wsa0_30); +__m256 l8_30 = _mm256_cvtepi32_ps(wsa8_30); +__m256 l16_30 = _mm256_cvtepi32_ps(wsa16_30); +__m256 l24_30 = _mm256_cvtepi32_ps(wsa24_30); +acc0_0 = _mm256_fmadd_ps(v0_30, l0_30, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_30, l8_30, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_30, l16_30, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_30, l24_30, acc0_24); +__m256i ws0_31 = _mm256_srli_epi32(w2_0, 29); +__m256i ws8_31 = _mm256_srli_epi32(w2_8, 29); +__m256i ws16_31 = _mm256_srli_epi32(w2_16, 29); +__m256i ws24_31 = _mm256_srli_epi32(w2_24, 29); +__m256i wsa0_31 = _mm256_and_si256(ws0_31, mask); +__m256i wsa8_31 = _mm256_and_si256(ws8_31, mask); +__m256i wsa16_31 = _mm256_and_si256(ws16_31, mask); +__m256i wsa24_31 = _mm256_and_si256(ws24_31, mask); +__m256 l0_31 = _mm256_cvtepi32_ps(wsa0_31); +__m256 l8_31 = _mm256_cvtepi32_ps(wsa8_31); +__m256 l16_31 = _mm256_cvtepi32_ps(wsa16_31); +__m256 l24_31 = _mm256_cvtepi32_ps(wsa24_31); +acc0_0 = _mm256_fmadd_ps(v0_31, l0_31, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_31, l8_31, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_31, l16_31, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_31, l24_31, acc0_24); +} +__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]); +__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]); +__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]); +__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]); +__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]); +__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]); +__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0); +__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8); +__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16); +__m256 f0_24 = _mm256_fmadd_ps(acc0_24, 
s0_24, o0_24); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24); +} +} +} +} +} +} +#pragma omp barrier +const int ngs = m/gs; +for (int i = 0; i < n; i++) { +for (int j = 0; j < tt; j+=32){ +__m256 acc0 = _mm256_setzero_ps(); +__m256 acc8 = _mm256_setzero_ps(); +__m256 acc16 = _mm256_setzero_ps(); +__m256 acc24 = _mm256_setzero_ps(); +for (int i1 = 0; i1 < ngs; i1++){ +__m256 r = _mm256_set1_ps(sums[i*ngs + i1]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]); +__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]); +__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]); +acc0 = _mm256_fmadd_ps(z0, r, acc0); +acc8 = _mm256_fmadd_ps(z8, r, acc8); +acc16 = _mm256_fmadd_ps(z16, r, acc16); +acc24 = _mm256_fmadd_ps(z24, r, acc24); +} +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]); +__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]); +__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]); +__m256 o10 = _mm256_sub_ps(o0, acc0); +__m256 o18 = _mm256_sub_ps(o8, acc8); +__m256 o116 = _mm256_sub_ps(o16, acc16); +__m256 o124 = _mm256_sub_ps(o24, acc24); +__m256 o20 = _mm256_add_ps(o10, b0); +__m256 o28 = _mm256_add_ps(o18, b8); +__m256 o216 = _mm256_add_ps(o116, b16); +__m256 o224 = _mm256_add_ps(o124, b24); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216); +_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224); +} +} +} +} +inline void forward3_gs_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){ +int* W = weight.data_ptr(); +float* input = in.data_ptr(); +float* b = bias.data_ptr(); +float* s = scales.data_ptr(); +float* z = zeros.data_ptr(); +float* r = sums.data_ptr(); +float* O = out.data_ptr(); + +q3gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff); +} +inline void pack3_qw_inner(int* A, int* B, const int N, const int M, const int nb, const int mb, int cutoff){ +// copy the full matrix A in blocked format into B +uint64_t idx = 0; +for(int j = 0, tid = 0; j < M; j+=mb, tid++){ +for(int i = 0; i < N; i+=nb){ + for(int ii = i; ii < mymin(i+nb, N); ii+=3){ + for(int jj = j; jj < mymin(j+mb, M); jj+=8){ + for(int iii = ii; iii < ii + 3; iii++){ + for(int jjj = jj; jjj < jj + 8; jjj++){ + B[idx] = A[iii*M+jjj]; + idx++; + } + } + } + } + } + } + } +inline void pack3_w_cpu( +torch::Tensor in, torch::Tensor out, +int N, int M, int nb, int mb, int cutoff){ +int* input = in.data_ptr(); +int* O = out.data_ptr(); +pack3_qw_inner(input, O, N, M, nb, mb, cutoff); +} +void unpack_zeros3_cpu(const int* zv, float* ov, int n, int m){ +const __m256i ones = _mm256_set1_epi32(1); +const __m256i mask 
= _mm256_set1_epi32(7); +for(int i = 0; i < n; i++){ +for(int j = 0; j < m; j+=32){ +std::cout<<"not yet implemented"<(); +float* O = out.data_ptr(); +unpack_zeros3_cpu(Z, O, N, M); +} +inline +void q4gemm(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 32; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(15); +tid = omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/8: + tid*tt*m/8; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +for(; j1 < tb-tu+1; j1+=tu) { +__m256 acc0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 acc0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +__m256 acc0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]); +__m256 acc0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]); +for(int k1 = 0; k1 < mb; k1+=mu) { +for(int k2 = k1; k2 < k1+mu; k2+=8){ +__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+0]); +__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+8]); +__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+16]); +__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+24]); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256i ws0_7 = _mm256_srli_epi32(w0, 28); +__m256i ws8_7 = _mm256_srli_epi32(w8, 28); +__m256i ws16_7 = _mm256_srli_epi32(w16, 28); +__m256i ws24_7 = _mm256_srli_epi32(w24, 28); +__m256i wsa0_7= _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7= _mm256_and_si256(ws8_7, mask); +__m256i wsa16_7= _mm256_and_si256(ws16_7, mask); +__m256i wsa24_7= _mm256_and_si256(ws24_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7); +__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256i ws0_6 = _mm256_srli_epi32(w0, 24); +__m256i ws8_6 = _mm256_srli_epi32(w8, 24); +__m256i ws16_6 = _mm256_srli_epi32(w16, 24); +__m256i ws24_6 = _mm256_srli_epi32(w24, 24); +__m256i wsa0_6= _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6= _mm256_and_si256(ws8_6, mask); +__m256i wsa16_6= _mm256_and_si256(ws16_6, mask); +__m256i wsa24_6= _mm256_and_si256(ws24_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6); +__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16); +acc0_24 = 
_mm256_fmadd_ps(v0_6, l24_6, acc0_24); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256i ws0_5 = _mm256_srli_epi32(w0, 20); +__m256i ws8_5 = _mm256_srli_epi32(w8, 20); +__m256i ws16_5 = _mm256_srli_epi32(w16, 20); +__m256i ws24_5 = _mm256_srli_epi32(w24, 20); +__m256i wsa0_5= _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5= _mm256_and_si256(ws8_5, mask); +__m256i wsa16_5= _mm256_and_si256(ws16_5, mask); +__m256i wsa24_5= _mm256_and_si256(ws24_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5); +__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256i ws0_4 = _mm256_srli_epi32(w0, 16); +__m256i ws8_4 = _mm256_srli_epi32(w8, 16); +__m256i ws16_4 = _mm256_srli_epi32(w16, 16); +__m256i ws24_4 = _mm256_srli_epi32(w24, 16); +__m256i wsa0_4= _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4= _mm256_and_si256(ws8_4, mask); +__m256i wsa16_4= _mm256_and_si256(ws16_4, mask); +__m256i wsa24_4= _mm256_and_si256(ws24_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4); +__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256i ws0_3 = _mm256_srli_epi32(w0, 12); +__m256i ws8_3 = _mm256_srli_epi32(w8, 12); +__m256i ws16_3 = _mm256_srli_epi32(w16, 12); +__m256i ws24_3 = _mm256_srli_epi32(w24, 12); +__m256i wsa0_3= _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3= _mm256_and_si256(ws8_3, mask); +__m256i wsa16_3= _mm256_and_si256(ws16_3, mask); +__m256i wsa24_3= _mm256_and_si256(ws24_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3); +__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256i ws0_2 = _mm256_srli_epi32(w0, 8); +__m256i ws8_2 = _mm256_srli_epi32(w8, 8); +__m256i ws16_2 = _mm256_srli_epi32(w16, 8); +__m256i ws24_2 = _mm256_srli_epi32(w24, 8); +__m256i wsa0_2= _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2= _mm256_and_si256(ws8_2, mask); +__m256i wsa16_2= _mm256_and_si256(ws16_2, mask); +__m256i wsa24_2= _mm256_and_si256(ws24_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2); +__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256i ws0_1 = _mm256_srli_epi32(w0, 4); +__m256i ws8_1 = _mm256_srli_epi32(w8, 4); +__m256i ws16_1 = _mm256_srli_epi32(w16, 4); +__m256i ws24_1 = 
_mm256_srli_epi32(w24, 4); +__m256i wsa0_1= _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1= _mm256_and_si256(ws8_1, mask); +__m256i wsa16_1= _mm256_and_si256(ws16_1, mask); +__m256i wsa24_1= _mm256_and_si256(ws24_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1); +__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w8, 0); +__m256i ws16_0 = _mm256_srli_epi32(w16, 0); +__m256i ws24_0 = _mm256_srli_epi32(w24, 0); +__m256i wsa0_0= _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0= _mm256_and_si256(ws8_0, mask); +__m256i wsa16_0= _mm256_and_si256(ws16_0, mask); +__m256i wsa24_0= _mm256_and_si256(ws24_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0); +__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24); +} +} +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], acc0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], acc0_8); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], acc0_16); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], acc0_24); +} +} +} +} +} +#pragma omp barrier +for (int i = 0; i < n; i++) { +__m256 r = _mm256_set1_ps(sums[i]); +for (int j = 0; j < tt; j+=32){ +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]); +__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + j + 8]); +__m256 z16 = _mm256_loadu_ps(&zeros[base_output + j + 16]); +__m256 z24 = _mm256_loadu_ps(&zeros[base_output + j + 24]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]); +__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + j + 8]); +__m256 s16 = _mm256_loadu_ps(&scales[base_output + j + 16]); +__m256 s24 = _mm256_loadu_ps(&scales[base_output + j + 24]); +__m256 zr0 = _mm256_fnmadd_ps(z0, r, o0); +__m256 zr8 = _mm256_fnmadd_ps(z8, r, o8); +__m256 zr16 = _mm256_fnmadd_ps(z16, r, o16); +__m256 zr24 = _mm256_fnmadd_ps(z24, r, o24); +__m256 o20 = _mm256_fmadd_ps(zr0, s0, b0); +__m256 o28 = _mm256_fmadd_ps(zr8, s8, b8); +__m256 o216 = _mm256_fmadd_ps(zr16, s16, b16); +__m256 o224 = _mm256_fmadd_ps(zr24, s24, b24); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216); +_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224); +} +} +} +} +inline void forward4_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, 
torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){ +int* W = weight.data_ptr(); +float* input = in.data_ptr(); +float* b = bias.data_ptr(); +float* s = scales.data_ptr(); +float* z = zeros.data_ptr(); +float* r = sums.data_ptr(); +float* O = out.data_ptr(); + +q4gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff); +} +inline +void q4gemm_gs(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int gs, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 16; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(15); +tid = omp_get_thread_num(); +int tt = ogtt; +if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/8: + tid*tt*m/8; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +for(; j1 < tb-tu+1; j1+=tu) { +for(int k1 = 0; k1 < mb; k1+=gs) { +__m256 acc0_0 = _mm256_setzero_ps(); +__m256 acc0_8 = _mm256_setzero_ps(); +for(int k2 = k1; k2 < k1+gs; k2+=8) +{ +__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+0]); +__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/8 + k*mb*tb/8 + k2*tb/8 + j1+8]); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256i ws0_4 = _mm256_srli_epi32(w0, 16); +__m256i ws8_4 = _mm256_srli_epi32(w8, 16); +__m256i wsa0_4= _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4= _mm256_and_si256(ws8_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +__m256i ws0_5 = _mm256_srli_epi32(w0, 20); +__m256i ws8_5 = _mm256_srli_epi32(w8, 20); +__m256i wsa0_5= _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5= _mm256_and_si256(ws8_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +__m256i ws0_6 = _mm256_srli_epi32(w0, 24); +__m256i ws8_6 = _mm256_srli_epi32(w8, 24); +__m256i wsa0_6= _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6= _mm256_and_si256(ws8_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +__m256i ws0_7 = _mm256_srli_epi32(w0, 28); +__m256i ws8_7 = _mm256_srli_epi32(w8, 28); +__m256i wsa0_7= _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7= _mm256_and_si256(ws8_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +__m256 v0_3 = 
_mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w8, 0); +__m256i wsa0_0= _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0= _mm256_and_si256(ws8_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +__m256i ws0_1 = _mm256_srli_epi32(w0, 4); +__m256i ws8_1 = _mm256_srli_epi32(w8, 4); +__m256i wsa0_1= _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1= _mm256_and_si256(ws8_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +__m256i ws0_2 = _mm256_srli_epi32(w0, 8); +__m256i ws8_2 = _mm256_srli_epi32(w8, 8); +__m256i wsa0_2= _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2= _mm256_and_si256(ws8_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +__m256i ws0_3 = _mm256_srli_epi32(w0, 12); +__m256i ws8_3 = _mm256_srli_epi32(w8, 12); +__m256i wsa0_3= _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3= _mm256_and_si256(ws8_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +} +__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]); +__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]); +__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0); +__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8); +} +} +} +} +} +} +#pragma omp barrier +const int ngs = m/gs; +for (int i = 0; i < n; i++) { +for (int j = 0; j < tt; j+=16){ +__m256 acc0 = _mm256_setzero_ps(); +__m256 acc8 = _mm256_setzero_ps(); +for (int i1 = 0; i1 < ngs; i1++){ +__m256 r = _mm256_set1_ps(sums[i*ngs + i1]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]); +__m256 zs0 = _mm256_mul_ps(z0, s0); +__m256 zs8 = _mm256_mul_ps(z8, s8); +acc0 = _mm256_fmadd_ps(zs0, r, acc0); +acc8 = _mm256_fmadd_ps(zs8, r, acc8); +} +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 o10 = _mm256_sub_ps(o0, acc0); +__m256 o18 = _mm256_sub_ps(o8, acc8); +__m256 o20 = _mm256_add_ps(o10, b0); +__m256 o28 = _mm256_add_ps(o18, b8); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +} +} +} +} +inline void 
forward4_gs_cpu( +torch::Tensor in, torch::Tensor weight, torch::Tensor out, +torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums, +int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){ +int* W = weight.data_ptr(); +float* input = in.data_ptr(); +float* b = bias.data_ptr(); +float* s = scales.data_ptr(); +float* z = zeros.data_ptr(); +float* r = sums.data_ptr(); +float* O = out.data_ptr(); + +q4gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff); +} +inline void pack4_qw_inner(int* A, int* B, const int N, const int M, const int nb, int mb, int cutoff){ +// copy the full matrix A in blocked format into B +uint64_t idx = 0; +for(int j = 0, tid = 0; j < M; j+=mb, tid++){ +for(int i = 0; i < N; i+=nb){ + for(int ii = i; ii < mymin(i+nb, N); ii++){ + for(int jj = j; jj < mymin(j+mb, M); jj++){ + B[idx] = A[ii*M+jj]; + idx++; + } + } + } +} +} +inline void pack4_w_cpu( +torch::Tensor in, torch::Tensor out, +int N, int M, int nb, int mb, int cutoff){ +int* input = in.data_ptr(); +int* O = out.data_ptr(); + pack4_qw_inner(input, O, N, M, nb, mb, cutoff); +} +void unpack_zeros4_cpu(const int* zv, float* ov, int n, int m){ +const __m256i ones = _mm256_set1_epi32(1); +const __m256i mask = _mm256_set1_epi32(15); +const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0); +for(int i = 0; i < n; i++){ +for(int j = 0; j < m; j+=8){ +__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]); +__m256i z0 = _mm256_srlv_epi32(z, shift); +__m256i z1 = _mm256_and_si256(z0, mask); +__m256i z2 = _mm256_add_epi32(z1, ones); +__m256 z3 = _mm256_cvtepi32_ps(z2); +_mm256_storeu_ps(&ov[i*m +j], z3); +} +} +} +void unpack_zeros4(torch::Tensor zeros, torch::Tensor out, int N, int M){ +int* Z = zeros.data_ptr(); +float* O = out.data_ptr(); +unpack_zeros4_cpu(Z, O, N, M); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward2", &forward2_cpu); + m.def("forward3", &forward3_cpu); + m.def("forward4", &forward4_cpu); + m.def("unpack_zeros2", &unpack_zeros2); + m.def("unpack_zeros3", &unpack_zeros3); + m.def("unpack_zeros4", &unpack_zeros4); + m.def("forward_gs2", &forward2_gs_cpu); + m.def("forward_gs3", &forward3_gs_cpu); + m.def("forward_gs4", &forward4_gs_cpu); + m.def("pack2", &pack2_w_cpu); + m.def("pack3", &pack3_w_cpu); + m.def("pack4", &pack4_w_cpu); +m.def("compute_reduction_cpp", &compute_reduction); +m.def("unquantize_sim", &unquantize_sim); +m.def("quant_scalar_scaled", &quant_scalar_cpu); +} diff --git a/autogptq_extension/qigen/foo b/autogptq_extension/qigen/foo new file mode 100644 index 0000000..e69de29 diff --git a/autogptq_extension/qigen/forward.h b/autogptq_extension/qigen/forward.h new file mode 100644 index 0000000..9904896 --- /dev/null +++ b/autogptq_extension/qigen/forward.h @@ -0,0 +1,480 @@ +#include +#include +#include + +#define mymin(a,b) ((a)<(b)?(a):(b)) +#define mymax(a,b) ((a)>(b)?(a):(b)) +inline +void q2gemm_gs(const float* __restrict__ input, +const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, + const float* __restrict__ sums, + float* __restrict__ output, +const int n, +const int m, +const int t, +const int nb, +const int mb, +const int tb, +int ogtt, +const int gs, +const int cutoff){ +#pragma omp parallel num_threads(8) +{ +int tid; +const int mu = 16; +const int nu = 1; +const int tu = 32; +const int on = n / nb; +const int om = m / mb; +const __m256i mask = _mm256_set1_epi32(3); +tid = omp_get_thread_num(); +int tt = ogtt; 
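+// Per-thread work split for the 2-bit group-size kernel: each of the 8 OpenMP threads
+// owns a contiguous slice of the t output columns. Threads with tid >= cutoff take one
+// tile block (tb columns) less, which the following if-block applies to tt. base_output
+// is then the first output column this thread writes, and base_W is the matching offset
+// into the packed weight buffer (m/16 packed 32-bit words per output column at 2 bits per weight).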
+if(tid >= cutoff){ +tt -= tb; +} +const int base_output = tid >= cutoff ? + (tid-cutoff)*tt + (tt+tb)*cutoff: + tid*tt; +const int base_W = tid >= cutoff ? + ((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16: + tid*tt*m/16; +for(int j = 0; j < tt; j+=tb){ +for(int i = 0; i < on; i++) { +for(int k = 0; k < om; k++) { +for(int i1 = 0; i1 < nb; i1+=nu) { +int j1 = 0; +for(; j1 < tb-tu+1; j1+=tu) { +for(int k1 = 0; k1 < mb; k1+=gs) { +__m256 acc0_0 = _mm256_setzero_ps(); +__m256 acc0_8 = _mm256_setzero_ps(); +__m256 acc0_16 = _mm256_setzero_ps(); +__m256 acc0_24 = _mm256_setzero_ps(); +for(int k2 = k1; k2 < k1+gs; k2+=16) +{ +__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]); +__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]); +__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]); +__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]); +__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]); +__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]); +__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]); +__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]); +__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]); +__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]); +__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]); +__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]); +__m256i ws0_8 = _mm256_srli_epi32(w0, 16); +__m256i ws8_8 = _mm256_srli_epi32(w8, 16); +__m256i ws16_8 = _mm256_srli_epi32(w16, 16); +__m256i ws24_8 = _mm256_srli_epi32(w24, 16); +__m256i wsa0_8= _mm256_and_si256(ws0_8, mask); +__m256i wsa8_8= _mm256_and_si256(ws8_8, mask); +__m256i wsa16_8= _mm256_and_si256(ws16_8, mask); +__m256i wsa24_8= _mm256_and_si256(ws24_8, mask); +__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8); +__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8); +__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8); +__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8); +acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24); +__m256i ws0_9 = _mm256_srli_epi32(w0, 18); +__m256i ws8_9 = _mm256_srli_epi32(w8, 18); +__m256i ws16_9 = _mm256_srli_epi32(w16, 18); +__m256i ws24_9 = _mm256_srli_epi32(w24, 18); +__m256i wsa0_9= _mm256_and_si256(ws0_9, mask); +__m256i wsa8_9= _mm256_and_si256(ws8_9, mask); +__m256i wsa16_9= _mm256_and_si256(ws16_9, mask); +__m256i wsa24_9= _mm256_and_si256(ws24_9, mask); +__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9); +__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9); +__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9); +__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9); +acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24); +__m256i ws0_10 = _mm256_srli_epi32(w0, 20); +__m256i ws8_10 = _mm256_srli_epi32(w8, 20); +__m256i ws16_10 = _mm256_srli_epi32(w16, 20); +__m256i ws24_10 = _mm256_srli_epi32(w24, 20); +__m256i wsa0_10= _mm256_and_si256(ws0_10, mask); +__m256i wsa8_10= _mm256_and_si256(ws8_10, mask); +__m256i wsa16_10= _mm256_and_si256(ws16_10, mask); +__m256i wsa24_10= _mm256_and_si256(ws24_10, mask); +__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10); +__m256 l8_10 = 
_mm256_cvtepi32_ps(wsa8_10); +__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10); +__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10); +acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24); +__m256i ws0_11 = _mm256_srli_epi32(w0, 22); +__m256i ws8_11 = _mm256_srli_epi32(w8, 22); +__m256i ws16_11 = _mm256_srli_epi32(w16, 22); +__m256i ws24_11 = _mm256_srli_epi32(w24, 22); +__m256i wsa0_11= _mm256_and_si256(ws0_11, mask); +__m256i wsa8_11= _mm256_and_si256(ws8_11, mask); +__m256i wsa16_11= _mm256_and_si256(ws16_11, mask); +__m256i wsa24_11= _mm256_and_si256(ws24_11, mask); +__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11); +__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11); +__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11); +__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11); +acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24); +__m256i ws0_12 = _mm256_srli_epi32(w0, 24); +__m256i ws8_12 = _mm256_srli_epi32(w8, 24); +__m256i ws16_12 = _mm256_srli_epi32(w16, 24); +__m256i ws24_12 = _mm256_srli_epi32(w24, 24); +__m256i wsa0_12= _mm256_and_si256(ws0_12, mask); +__m256i wsa8_12= _mm256_and_si256(ws8_12, mask); +__m256i wsa16_12= _mm256_and_si256(ws16_12, mask); +__m256i wsa24_12= _mm256_and_si256(ws24_12, mask); +__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12); +__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12); +__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12); +__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12); +acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24); +__m256i ws0_13 = _mm256_srli_epi32(w0, 26); +__m256i ws8_13 = _mm256_srli_epi32(w8, 26); +__m256i ws16_13 = _mm256_srli_epi32(w16, 26); +__m256i ws24_13 = _mm256_srli_epi32(w24, 26); +__m256i wsa0_13= _mm256_and_si256(ws0_13, mask); +__m256i wsa8_13= _mm256_and_si256(ws8_13, mask); +__m256i wsa16_13= _mm256_and_si256(ws16_13, mask); +__m256i wsa24_13= _mm256_and_si256(ws24_13, mask); +__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13); +__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13); +__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13); +__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13); +acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24); +__m256i ws0_14 = _mm256_srli_epi32(w0, 28); +__m256i ws8_14 = _mm256_srli_epi32(w8, 28); +__m256i ws16_14 = _mm256_srli_epi32(w16, 28); +__m256i ws24_14 = _mm256_srli_epi32(w24, 28); +__m256i wsa0_14= _mm256_and_si256(ws0_14, mask); +__m256i wsa8_14= _mm256_and_si256(ws8_14, mask); +__m256i wsa16_14= _mm256_and_si256(ws16_14, mask); +__m256i wsa24_14= _mm256_and_si256(ws24_14, mask); +__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14); +__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14); +__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14); +__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14); +acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24); +__m256i ws0_15 = _mm256_srli_epi32(w0, 30); +__m256i ws8_15 = _mm256_srli_epi32(w8, 30); +__m256i ws16_15 = _mm256_srli_epi32(w16, 
30); +__m256i ws24_15 = _mm256_srli_epi32(w24, 30); +__m256i wsa0_15= _mm256_and_si256(ws0_15, mask); +__m256i wsa8_15= _mm256_and_si256(ws8_15, mask); +__m256i wsa16_15= _mm256_and_si256(ws16_15, mask); +__m256i wsa24_15= _mm256_and_si256(ws24_15, mask); +__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15); +__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15); +__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15); +__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15); +acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24); +__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]); +__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]); +__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]); +__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]); +__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]); +__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]); +__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]); +__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]); +__m256i ws0_0 = _mm256_srli_epi32(w0, 0); +__m256i ws8_0 = _mm256_srli_epi32(w8, 0); +__m256i ws16_0 = _mm256_srli_epi32(w16, 0); +__m256i ws24_0 = _mm256_srli_epi32(w24, 0); +__m256i wsa0_0= _mm256_and_si256(ws0_0, mask); +__m256i wsa8_0= _mm256_and_si256(ws8_0, mask); +__m256i wsa16_0= _mm256_and_si256(ws16_0, mask); +__m256i wsa24_0= _mm256_and_si256(ws24_0, mask); +__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0); +__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0); +__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0); +__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0); +acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24); +__m256i ws0_1 = _mm256_srli_epi32(w0, 2); +__m256i ws8_1 = _mm256_srli_epi32(w8, 2); +__m256i ws16_1 = _mm256_srli_epi32(w16, 2); +__m256i ws24_1 = _mm256_srli_epi32(w24, 2); +__m256i wsa0_1= _mm256_and_si256(ws0_1, mask); +__m256i wsa8_1= _mm256_and_si256(ws8_1, mask); +__m256i wsa16_1= _mm256_and_si256(ws16_1, mask); +__m256i wsa24_1= _mm256_and_si256(ws24_1, mask); +__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1); +__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1); +__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1); +__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1); +acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24); +__m256i ws0_2 = _mm256_srli_epi32(w0, 4); +__m256i ws8_2 = _mm256_srli_epi32(w8, 4); +__m256i ws16_2 = _mm256_srli_epi32(w16, 4); +__m256i ws24_2 = _mm256_srli_epi32(w24, 4); +__m256i wsa0_2= _mm256_and_si256(ws0_2, mask); +__m256i wsa8_2= _mm256_and_si256(ws8_2, mask); +__m256i wsa16_2= _mm256_and_si256(ws16_2, mask); +__m256i wsa24_2= _mm256_and_si256(ws24_2, mask); +__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2); +__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2); +__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2); +__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2); +acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24); +__m256i ws0_3 = _mm256_srli_epi32(w0, 6); +__m256i ws8_3 = _mm256_srli_epi32(w8, 6); +__m256i ws16_3 
= _mm256_srli_epi32(w16, 6); +__m256i ws24_3 = _mm256_srli_epi32(w24, 6); +__m256i wsa0_3= _mm256_and_si256(ws0_3, mask); +__m256i wsa8_3= _mm256_and_si256(ws8_3, mask); +__m256i wsa16_3= _mm256_and_si256(ws16_3, mask); +__m256i wsa24_3= _mm256_and_si256(ws24_3, mask); +__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3); +__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3); +__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3); +__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3); +acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24); +__m256i ws0_4 = _mm256_srli_epi32(w0, 8); +__m256i ws8_4 = _mm256_srli_epi32(w8, 8); +__m256i ws16_4 = _mm256_srli_epi32(w16, 8); +__m256i ws24_4 = _mm256_srli_epi32(w24, 8); +__m256i wsa0_4= _mm256_and_si256(ws0_4, mask); +__m256i wsa8_4= _mm256_and_si256(ws8_4, mask); +__m256i wsa16_4= _mm256_and_si256(ws16_4, mask); +__m256i wsa24_4= _mm256_and_si256(ws24_4, mask); +__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4); +__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4); +__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4); +__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4); +acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24); +__m256i ws0_5 = _mm256_srli_epi32(w0, 10); +__m256i ws8_5 = _mm256_srli_epi32(w8, 10); +__m256i ws16_5 = _mm256_srli_epi32(w16, 10); +__m256i ws24_5 = _mm256_srli_epi32(w24, 10); +__m256i wsa0_5= _mm256_and_si256(ws0_5, mask); +__m256i wsa8_5= _mm256_and_si256(ws8_5, mask); +__m256i wsa16_5= _mm256_and_si256(ws16_5, mask); +__m256i wsa24_5= _mm256_and_si256(ws24_5, mask); +__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5); +__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5); +__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5); +__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5); +acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24); +__m256i ws0_6 = _mm256_srli_epi32(w0, 12); +__m256i ws8_6 = _mm256_srli_epi32(w8, 12); +__m256i ws16_6 = _mm256_srli_epi32(w16, 12); +__m256i ws24_6 = _mm256_srli_epi32(w24, 12); +__m256i wsa0_6= _mm256_and_si256(ws0_6, mask); +__m256i wsa8_6= _mm256_and_si256(ws8_6, mask); +__m256i wsa16_6= _mm256_and_si256(ws16_6, mask); +__m256i wsa24_6= _mm256_and_si256(ws24_6, mask); +__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6); +__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6); +__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6); +__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6); +acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24); +__m256i ws0_7 = _mm256_srli_epi32(w0, 14); +__m256i ws8_7 = _mm256_srli_epi32(w8, 14); +__m256i ws16_7 = _mm256_srli_epi32(w16, 14); +__m256i ws24_7 = _mm256_srli_epi32(w24, 14); +__m256i wsa0_7= _mm256_and_si256(ws0_7, mask); +__m256i wsa8_7= _mm256_and_si256(ws8_7, mask); +__m256i wsa16_7= _mm256_and_si256(ws16_7, mask); +__m256i wsa24_7= _mm256_and_si256(ws24_7, mask); +__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7); +__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7); +__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7); +__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7); +acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0); +acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8); +acc0_16 = _mm256_fmadd_ps(v0_7, 
l16_7, acc0_16); +acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24); +} +__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]); +__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]); +__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]); +__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]); +__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]); +__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]); +__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]); +__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]); +__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0); +__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8); +__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16); +__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16); +_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24); +} +} +} +} +} +} +#pragma omp barrier +const int ngs = m/gs; +for (int i = 0; i < n; i++) { +for (int j = 0; j < tt; j+=32){ +__m256 acc0 = _mm256_setzero_ps(); +__m256 acc8 = _mm256_setzero_ps(); +__m256 acc16 = _mm256_setzero_ps(); +__m256 acc24 = _mm256_setzero_ps(); +for (int i1 = 0; i1 < ngs; i1++){ +__m256 r = _mm256_set1_ps(sums[i*ngs + i1]); +__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]); +__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]); +__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]); +__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]); +__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]); +__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]); +__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]); +__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]); +__m256 zs0 = _mm256_mul_ps(z0, s0); +__m256 zs8 = _mm256_mul_ps(z8, s8); +__m256 zs16 = _mm256_mul_ps(z16, s16); +__m256 zs24 = _mm256_mul_ps(z24, s24); +acc0 = _mm256_fmadd_ps(zs0, r, acc0); +acc8 = _mm256_fmadd_ps(zs8, r, acc8); +acc16 = _mm256_fmadd_ps(zs16, r, acc16); +acc24 = _mm256_fmadd_ps(zs24, r, acc24); +} +__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]); +__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]); +__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]); +__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]); +__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]); +__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]); +__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]); +__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]); +__m256 o10 = _mm256_add_ps(o0, acc0); +__m256 o18 = _mm256_add_ps(o8, acc8); +__m256 o116 = _mm256_add_ps(o16, acc16); +__m256 o124 = _mm256_add_ps(o24, acc24); +__m256 o20 = _mm256_add_ps(o10, b0); +__m256 o28 = _mm256_add_ps(o18, b8); +__m256 o216 = _mm256_add_ps(o116, b16); +__m256 o224 = _mm256_add_ps(o124, b24); +_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20); +_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28); +_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216); +_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224); +} +} +} +} +inline void qforward(const float* 
__restrict__ input, + const int* __restrict__ W, +const float* __restrict__ scales, +const float* __restrict__ zeros, +const float* __restrict__ bias, +const float* __restrict__ sums, +float* __restrict__ output, +int n, + int m, + int t) { +q2gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, 1, 1024, 32, 512, 64, 9); +} +inline void pack_input(float* A, float* B){ + // copy the full matrix A in blocked format into B + uint64_t idx = 0; + const int N = 1; + const int M = 4096; + const int nb = 1; + const int mb = 1024; + for(int i = 0; i < N; i+=nb){ + for(int j = 0; j < M; j+=mb){ + for(int jj = j; jj < mymin(j+mb, M); jj++){ + for(int ii = i; ii < mymin(i+nb, N); ii++){ + B[idx] = A[ii*M+jj]; + idx++; + } + } + } + } + } +inline void pack_qw_inner(int* A, int* B, int cutoff){ + // copy the full matrix A in blocked format into B + uint64_t idx = 0; + const int N = 256; + const int M = 4096; + const int nb = 64; +int mb = 32; + for(int j = 0, tid = 0; j < M; j+=mb, tid++){ + for(int i = 0; i < N; i+=nb){ + for(int ii = i; ii < mymin(i+nb, N); ii++){ + for(int jj = j; jj < mymin(j+mb, M); jj++){ + B[idx] = A[ii*M+jj]; + idx++; + } + } + } +} +} +inline void pack_qw(int* A, int* B){ + pack_qw_inner(A, B, 65); +} +inline void pack_output(float* A, float* B){ + // copy the full matrix A in blocked format into B + uint64_t idx = 0; + const int N = 1; + const int M = 4096; + const int nb = 1; + const int mb = 32; + for(int i = 0; i < N; i+=nb){ + for(int j = 0; j < M; j+=mb){ + for(int ii = i; ii < mymin(i+nb, N); ii++){ + for(int jj = j; jj < mymin(j+mb, M); jj++){ + B[idx] = A[ii*M+jj]; + idx++; + } + } + } + } + } +void print_parameters(){ +std::ofstream outfile; +outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app); +outfile << 2 << "," << 1 << "," << 16 << "," << 32 << "," << 8 << "," << 8 << "," << 64 << ","; +} diff --git a/autogptq_extension/qigen/mmm b/autogptq_extension/qigen/mmm new file mode 100755 index 0000000000000000000000000000000000000000..51127a3803872caaebe7b242c0d40429d272b2be GIT binary patch literal 41056 zcmeHwe|%F_w(m)sLc!Vur(jjoz(fZGYXgO8M!3^Jq9>YawM8nTp`X+8 zUt9G|2r_SU?9Av~=IP93-sqM4J~$V>qDh-TaU5-7#47%%_$w8WiVEe2=6%;b`=qBK zahQ4UeeU~%qsiWD?X}ikYwfjvpOTF>#|;@8ji5Z4!aoTFmCTYcCA`4U$umewgks@B zd`}S!!Wh6>hLg)n^%??mdpL57u7&By9%;Pgr9wd4RiN~-3gqHSwgv9Rr@qtgbaj38|1{0bkr zi%^y8WctcYqf-O-r;;lEBpvBf%>9|17Rq7-B~9k5<$TGhmGdbn)j?1BNB-)6q`orF z*U$Z!oIWLs7nD?~?>6u$uHq!$ndo`x?YzEp?NG_vK}ji2=(*18Sw6SuI&by0UQc~< z%e5`VbFZCSRM6xrkUIwHWx$8nZ@guZfMy91j=NFFXq90_9+l#-Pc4zdJ_-P&_nWglPBQokRuw|AJ3I&gj+R)Bh6;x)9Ze`JK8?Mr5HzX$ z_E;ME_B8UJNP}OLM*mCF)cZi1dT&Uh|C4Ftyp{%^n@0XUY3Tn09VL~2Ca003OGCdT z4gI%i>W!rF+u}5SSeT~X{b}fFoG{>%{LD+E=XGiDku>cNbyCWonFjwaY4pDt^-fJz z67UtMcZ!fNv%D$d?TkH;P_UKpAP?%oMl|TLO%`k zQ!&bx)yRZnIeirEI037Tk}Wbq4?6ZZoh#~m_0A@LRioeO6rA>oQfIZh(Y?acfIGp%e`*7EOp6}Tgv@&y*^i!*KIY~RtZi*n3q>Id0bA9ugUHz_nSQqZ;BZ2A323ZtE;2uG(9+s!lGN>#A+^)%#{U z@0ewp)wBj`);V3M``+30zSX>LpSQ{8k;zT7Wi@I%Uib3mnslo8n%oMn*;(rJ)Hk^s z{q|~`?e=nC1^lw)mYd7{b6oy44Q?o3Tva0jP^;03)sWHT=TVWcW})1&XzX(_q{!C*Z+!E~t}6UHTIIVWFJX>5hX_J_vdsu$YJVkwQU(6)BNlUNdxf*2 z&}l25>nyvue74iewD-+!X!gtTr4BthSoL&W27z{RRDv0v#cmh!fwO8gFw` zEhDU7O`V90`Wu~7_TS$iDq;$h!JHJ)Y;xLUA&CZO;miBLIkvh?&4O_xn5r5ZtJXN{ z-K(8mcl`=K#b&v`$mMJH%N3qekjn01lHK(bpcLdc-dbAbY^Z9i@_OA~!B^MdT;cXp z7`qX}_04rM!Ez|zaW@LiMfIya;Jw}5)LchlNPVWpLovA4E!4T|>U^sRjit7W53qGr zo_e9Fv8ukBlViYZ@%RMh3~-^E1mu#5Zd|y~IlG`x*4(6IiYVi-FIecDRZx`3nMSg+N3ydD<_HTb zD=iD`PE*0`f?}DxvJzs83QVR^pjl0+21Zet%xVzRPzL)i6W>~v%VOymVJu4tmyORj 
zVZ3kwD<@12J{K}57kDjpDp?H5z}`WJ8IV#9zGaDW4Ft;%mF>mct*kT3u3mG20KWHu zqep|C$qT%^1aR-Pt1*M>g+T>pBsf?2l*7H>C-_3)Sb|?Gyur(8j?eRWE)-zFX^wmN zXTQMqW~`8fZ_;50gU1MZ4mV%w8Am(b%K+1(0qxVAm*AGZXP|G6@B>a?HT)<{yFe)7 zu(Xj@2;+r?9A3ldvxVC@y#JYdAuUU|hr=GmpDB1a?4SPMv4zeMS^(3d?C+I*wShKb zc((I(iGtV9lJPr}@BFl_F)`MtJTrR-~-k2w$WS zBURRmgfCR#mGvUw$v?`Ytmg@LVD!h7smaoEpPepG~;prSjd8VoG6BQ)V zLKU7L39`~+75-useTfRM^erkDRrpI(^kpjir7C=-3V)djzf6T!&NQjGR)xP@Mc<&p zU!lUcsPI>+@U1GmL4|)%g`c9rZ&Bf=s_>mE{P$J(tt$LgD*QGT{%RF|y9)mU6@I4* zKV5|nsqojR@I5O03>E%-_Phtqd*HkW&U@gz2hMxoya&#E;2-nAm%1xXiEW={i-D{+ zeOT^xg!~!tUa>8b-OU8X=Paoggk$lm?n0uwQXu+TDmgazemoxE!eH7!9UJURz_f8X zHW*I8v|&0n`0ojrHcH0^f17}5gLG{0X9<`#M#l#4OTe_LI5xO40nHaH;x(}w5Rpe6z982rVzioLYKf&B@X zHa4(70n>&C_9tN4$iV&tOdA;3pMYuO0{as%ZCGG`0;Y`$>`%b7L4o}Vm^LP`KLOK* z1okIj+K9mZ1WX$c*q?xD;{p2pMYsY0s9j$ zZ6siS0;Yoj*q?xDLjn5}Fl{7Ye*&fr1ngH}MB~Idibkuz;t_E4-qG-`B>34R_~|70 zcS-Q$N$?{{a4-q(NP^cV!Ocl3?p_>m+y zsKC+T)v;j&W~>dP#o$#Xg5Z!&5Hb_6L;584N0uHXl!0XrVW|DwjiEMh;ZaH?tI;3^ zN@8~DucGA88XQtwlpZtYV^EE4hGY^_Ga}@V=#_*jA3=2jWtZL&1H1A;|6&VzOYB3& zc{DZ?WU&<>&|TRIJnJuxV1+STZ2NG)5^o${BuXDy7F+JHEV5KsZZ8+_`>X`Rb+s7K zUPb4FqGanag*xK8_Akbh&Kt5t>C5Ml6->utM^T9w_-DQ8xLw-AkfPKvbS55`I$pqc z>@}1}!fWb@C_nOnW$KW|bez#ePab%-n7E}5Jrb#-5Z@MQckD4j2L9Q=>7~%`)U!%MzdUhL>L>v&z;qG=k5g6lm*5O~uGBr# z?Qji99RjeW<}APRvaEPi}nH04Aw}LQd=i={A%xVL_EJ zQQAeO#}-kY0hULo#BLY?!PyVa*leI)qJ;zyj3tANXN}Vl+OQ1F)DY2A`Z?o75Cydl zfN(q2DLuWFl}vb6KvmQcq|>n?h>M06$D>`>gv7RPZS>TgOjh*NgoS6s>mMi1<~JSE zVMlfheOqZx!@r3y=ZUiL3$EE6W0`kBNaG#a&ynmF?&of?zZEtzRa(`W%`r zF_`!B=qcY+1ni%pr#2dp|A}5~+l|b|T0Njf$~6#vZ6R_k`0_0&L9P;CzT3-?E5n!X zre(;9`0_1iK(3_WzA#y5ey#aumitcA^@y&+ixJb_7BblH8>TW{2d%cz^~jSp(}$*m zrP4QW5_M&V^!ZJ}d0U@vg}QCshS)gRXGdGx+m5tC9JrfbM^n&f9S9cE_aii0pdor# z!a6#-C5etLMT6{w21YSFFa|w-2au?T<_;@5K0#7Pr;IE?5Rz1_oVqB{m+srA>REr(FhYL0CAXXKQ6bqKAtff09y+fvrem z(_moW@qA=E4(c|q!|o=yFy1yiMt8pp`S9UP=`HlXEK?{tyf_pIuhY*Pi^+&-8(Er@ zDZ;JMukP%No~rDM#&*BF2w!iZmdbE6Ht}UjBOk24q)!ta+TFJZ;KAsz%5FwHG_eoJ zy+qqZbfUSd`CnvBq_^4*4NJY|w{-0XFrQ=j5eSlo)BA!e<89w&tom^z+@jBWR8b_h zH0S5i%Wa3wM2GI&tLa4*y$~Iy`nsaY^_XAwFA-fmk?>l5&O=;6;8_wMJ+?F_fWQpG zcS2CqL=LTdE^K~v!wjNzT&!&iWq`!-J*`}@d&5O?K{mr?Gi-EV@xkb!YnZzKNm~*7 zOC|WGOlqS{e`&;D8D2WJ$YS^?*47qZwA;x0qs2JEsG#X9b2Icj}stw83-H| zD5kxmLxeSVQhZ3C^Vr})!9rmdh2U8c@hp#sT`VFdeU0AnXI%%a+g@VpG3!@QU=K`S zp-HY1gJoERg<4RBK%lwHpYt5;%tYx8ssGDAqD~?kF&>1hOPIf8wow0J?n#(_)q5pp zoBs6wbJKsvO@Er1ei;l^Opk3o236&CmpFsLHo~%CzMD);-?$~zb$UU)Q?6--i3@}?gQ`JvE^~3qk z)7Bhv-uWj==I1&wUO&f-vh`GpGvqb0uH$En%H4aHWr<~}`LgxO)#< z&5$~SeniEH!m*X8PSNub$R`y_NzZM#z($m36>{aXxt?2~3=wUE0)uPeG4$%eS21Rq z_Q_+*U_JmEWCnL46T`7hoX%&k3>|zCGUR-exOzE=b>rZlkgrkZ7ou4t2Ctn1OGVVR zlo>n!FLx98xfqt)HV>97WTm4Kc^ zoz}pLvVhYX*jN$}=SeRZt3e?KS_%VY#Uj8$8Js5VFfOAkP!E!(ffj?b$+&|lz0SA{ zaK7}g@nH^E0@f34rOb|VJjS{ey3o$DOzQpRiPOk5NIkzi@hvjm*dK-)C&bO@6$1`@wdQOk&p4n~FJ@Z-MVOB3R+D`gwfbW&sjb2n0*u-!h`+^H}I(t`m zSc`&AP<8CpZO($cfGjLfVQdKA@P5ZW-R2yevH|o`9xnn3G*=GeB(5W*+dPC6BFS(P z@IHaFR-*`7QtCuMl-PnYsRw&IQr2p8NJF}3 z!jMf3B2^fzz?B5GzYb~_(yqm7%nxesl5++|3K3;QN+QHgM^MW)O`;1M%rMbsVgF(KP5hFbNf!7012JNvLX?Bq;Ov{8-k81C)yZcHA5om8=~y=e##E>?DTgD5{OnV$YyLij0Ub- zu?P~h2ziVQr`N$@FCqr*jXb)bh@@r++`)NY(ieGqX264tpXHiAW& zB;T%5N0lmp3!4=Y&SL#7GjyoplTr)cs)r4o~&{No|H8MKAYp2 z;i!^X0zBSPWAp=z{xR+#$mR0MsSfmX*#nTv@zm4~G#hggD z3#gT_(FCTSW2gxcb)5j|p`u=16yk6XU>%Lu7U@Lh4i?iE^9fz>2##BLb+gNUGz9C300@+hjBl)?zz&F#2p`uRdZ~ zch^Az_eQ$1b{>vC-nd^MeY&xa73_^HnBVytm3~NH`$+yp!~t;KUHgdRL;5-v-PPEO z0tltAkSuh;qRw|o-bEyjv4tWF=5+2Jk*5;1;O@?kBMVA94@CdcxGyjIPGfIgG zuABB0GnrHbL6VB4o$nA66uY}KO3F?mWha3Poga=OZSqGHCcW}e-Ep<%bD#!neH|4u 
zePPDVcj4;eKztZEG{5sJQaqm&XQE_%NlVlW7cbg*fb@M8mCB8Z%*RymrO2kWb3Oa>`bq$N`_$dn8+B_pjVEE+P5Nh({@GlAK7^o4PTct=;r7uN3nV^LESR`6LZzQ3JU%!Q?r_X_5;vZU zQ^z1R@3%PRxa)nYdVJ*2M06+A&uEpPO$|6pkGtZ!t?;2ZzaEe)VM!4dZck*ZuINM zzTD_HjlH=Lbr7F@_)r%oi3jo7hYt-(p-7i|=f{z*x}Bd#x|(+$heLPn&fLTnd?FTU zAq?O&TW#8ilT!y~CYq?CXQurXpLgJe2-44iriB@;fGUw{a4qoyo9C#(9|XpMb$Bad z)Mq%bf@xv-X@7m-%)5_djis#PJ_D4BMh>|1f?|WP)Y%BmNz`Q zhfd-ct*xCcrwmX+1dmx>E$TWRIgW#bxV(aekJ#P`w8fy&YW!HPf`P5ofGi;-#s*rD ziDsS62+DvAZiYXn!Zx0A3@WD zPcJ^m>kA@X>+=oKzuq|=N5nYVTeK5lrLQN9$3P%@gucFnz!CyqA#j;PbDS_pzaV79 zf!1z?2F*`LKdUUmk!Yoe2(2tx7X7di+YrBTnWZgkKv~5kDs)VuQtKoto;*oxi?t#e z%eLaHl5%1QUnK)#+o!G8>*d|5ZqrmWs3jiFly~12&D&Pe<X63WbvJnw$130t=A1w!&4UgpM6kW>Gh9p^62qVV*r7Qj z_MH;Vy}AxM+!BLUL4$qf9x*UWluS@`(8YY!3NvvI2E7%*HLCL^jK}TNgDiL0rQ>v| zZI{k0Di?#!Y?$o#S$1Wk@bFCKJj!l<+kcTLz2%sG=;q+OnRMUthxbr9&M58DOrsq< z*NcG}_UmUF*Xz2F z@TtrZ$K&kWr#0Gx^CDtkY~+KJle#qaK*O*-{+GxHYji!?xW(OoOM%`Ee$hOn+q4&T zisobM8o;+P$4I0O3@Xy$!xLZ(jaW1c+>BFc9MGEk)_o}kZ-lJ}HY`DANU;~~QOL|S zAK3VXXnt>%LzHHVl8$i&Ru7AT3+#bBTn`Aj6ibeR+T1I|V%1V0wr4rki~BnIzJT;Q1g5gkrT zpOTMF$BA+yD2YgH`zCwi>tb6Zn@4hY!e5pO94gZ5bGXteiD5%NmNw?49UuBDEv@T} z#pp2K_%FBUUi2Fmvh&n~Br=kHHw?Ed9ef8M)zthpEFx8c6Q93=TA`%yMS95sjpy%S zIZ(xWbuSJbelAuCV#Qx{zNx&vLFcw!5&t#|G@2$#eq%Ncy!2w=hhktcg7hfPptP~m zs1XOkbWR9`O9M}7WrdIE+Fz$=Y+YY0=-O*2t9!9|At|@*EozB9u|lxoL`P32co8Uv zGaRkWbUfaJLm(XSV5Y@ngewg#oeJeUf>%9(8(4mBO0>9Lzlo9AT^J?iZ2=B!4r4wI zBf-h{(RdL3Vlggejz+?bOUA}@FccNDLvr*?d)ppu?0WK;`E>IEF<@gR;Cv;EOb{_` zwO<6iTmQLO@UPar*u)%u_)R7QqkgObc33n0My=KCHR^R8wG=IZEHhod%Z z)K}yN)U^jFpJ>5fk-)adG^J}(LplfZAX^?E-;@0n$P|MYem;3rTlVPMJyfUk9&5_M zr2yzae>IrxWKMA63!^!Db?s~E9TDlM%u&aHPhvL4!Q$noXK~!Y#LD3^Xa^9x8su{S z9ThJaZ~R6aj5&-0bT0atjg}qz<(4g`JY86c!5dkF4qJkjxWiR#ET-ng$=YdAdfg_y zd{eOKcOb*U0T;o?5dqPm#d{O|Lkz4l4%=OOFcGzhIAYM2kfrh4n#n|0|-+h zo*hyb=A-EFw6LvX--bzB@O}jDv}sY{c`ifL923pq=EJy@wC=*x0d#n6rS>^|PtZOm z#`oCGXPS4Rsi~EVZJIZzp+D1g(D}X6Toz}yYPw+WuUZz{rPCJaW%(s2Ot4>WEMW`Xox8wt>rU!@0!HN!H4;DsL zkM;r!M$}6a{)H>>l0!&~cT+M%b5HY8QS+w7d~D-?SmF_z=5?Fqq}5z){HpmOmHpCd z7WSsiU?KdXat6gOyZ^CDAyPNR9S5x+ML4Hdl-3zz6h1o>;q!Le`uw=gX3GyO(o44d zo#=`Jn(ag=-V;4`O|KPUR?Q;pS6Q&c^(NM&cUh&^9Ko4UvQ78Q`h3&Aj)Qic__{+H zq*@H3W}i5{kCwfTdq0?rmBr&k$qv)W&)YGp!X=j`7Qo^6v#{8KY6#Dz#u6e7?LmVP z*B-4=X0{r898$G$n?v#%w>n@#Cqh>lPp#7MFHfw`$1qoIg&eI_I&IBdpHB_FM~kXm z4&zt4XRzLJ;7C3EK5hbDqkJumV;KII8V5wpNwgUv94n-O3~K<32g0)75Oc?%tH&Myg;e> zRsSV+DNG8W_Kh7<@8OaD8<=ObnB%(5FXJTwGH)j>#B1Rd^wR04*$F!?4`_u6myunV zvc-6h9X_(6t5jo9=m?2AaTtb9CXr`Bd!+Y>z7SU@u}oz0L9T84WLPr{Ku_zW z8CK%}>OcZ#p;w_694EIDKOSJRAukVtbEv77Q4^Fo0%%#fvj{oR1ZWYoF*U|^a&jlu zU#O*pY$#@%Dcm4pP%3n2`W>2gaRqPDJ&Fa>0VE;YxE|5`YKiXA-57HKkK1jXvb$}a zQOp2P(!N{Q9%I`k>;o|bVw!~?#OX|H18qB}>Ds8%5(nM{(Zhe(fahOBUsIS9iKD?v z^B-Mik5bYN0HBPhfAZc0(H9;2hcs4z%s zFltWPjaNnTyEV}>OK~oDPbB&+J1fFgs1op<5sVUJK!34zl7=OBWI)ix+JDNVH%vkr z78Ji=C39thf(!;-jl0%&ZRQ?b2TfVfI}Zj<*0payABQhYl-ETjbChGJ6fohOg0U1~Rbm9D7SJXHw)oSUUk$ToGx5t+(kT-BnXA(U%6bnXnb_ z%PftOgGg*ybQt!Wf<43Zfjz@_Mj|xuXWkstmVy6btzN^DGFfSTrd8@`8#trcZ9xOJ z?}Jj%yqc~MT9LD4w8g zvj%U_SYiA!-RAosL=gtVs%0r8Q<0>n2A4A)=TVUBEs+tl#xX+ z{!=rDFpffLt`FJX7UxWDh(4`|?*^?HhnKmLVp|umVJ$SS)^)_s&`WUWDuKK6HIt3# zB}1B}#;>gAb;h@K4{1)P+5so#D|A3KgO+;rGA!6QKB4PaMb|hqOyxe-R1Qrq=C9_x zgH}0?CZiWuLZTLFnJF~*GG69T%}6zDeqMz6C}PQ9{9%d-sh0Ynw25DLJuF`?;<~ID zcv{nRT2WX@{g6*ix<8ewW~Y^GcDnSMft?1lCPsce?wgi^oJ~~!iFb~y(w}wD+?~H5 zs11TbaJc#x2DNsYkzAPOyD@op5dNm1md{P?>_j_Y%eT9}2x?2%r8KQi5scx1tdTon zI=jLYxQgxSusg@ZuocpNB;I?Vh=6Sv*dF5uEV&#s+wH*ULtKn%9w9OPVw(mS`H)@u zT9m#J!yn^rw&u~?huyupzMARC||_IwVk93h?6O_#aByg#AP&rTXpIKWEoJy5|p3sKlYTrply 
zi2{50qb$N+FIWo7b%d3LlLZ(>X!R-K`jYIIa*QITb!>Im1!F#~G~}!m#q@yV z;piIQ<-72wFZ8CsBS`VGZ>jXIUHS%ngXnOu#ggWHHse3c!UWMm=ZBgvkfBnV_yc2U zv}kAL4$RACOb6WRB9+1jn90D))9A}fqgcqZ=^umUJ@Ux;yhY0UPpp~!m~Qt_^{8`Q z5c7Wz9r}SCo+6v*h?9Gg>kHB8!C?b16FdEZ)7N-9!%)nU~Ix`+Ehfa@2 zV}{Sd2vAE!(`gS{I}jkyTUMX>99q(eGouUUmBf2y$fv}xl6q!1mW7T=w$OPrdQGrs ze6V7|GWNED6;;y80*eY--0WbBNl}A$L9p7Og6-^rAB-GXNywcXSVlmI18kAl!+{1W z>E!@cB-EWSzQ!14ld%sER46YQDIehFF-oy+ zqmyWQC{LkDPh|O;{6%^!i7K1*Hq$<=d)wCQGj;9suQ{S|s3-EyiWxlEj5QQM*4ub5 zv*P4He(Q!-a`J5fwf;0}4-6L+wKPYwo{7ezE^CwiD!r<#x1=GwcEARF(_Ztdy3MZ; zBf^~srDM^xdL2O|Zu-lB#s=5M+u~TyJp)~=!8~J??!Li4y=E#_Whd(gq9?swk?uRE_GwC8Jrp?xh_(iA%XHvkiQE$DUODw; z&82qC_eyVotG_gucL@=bIpL#Ntu+_-Xz0*=kiG`Ufd}!@cy#b6eAvo7dmq*h zjSrxm4dT7vNqZvQ z>+^TZ%gBkR4o8+xEW!QgJrPU1JNosvuh}(k+BI*{;u6*07>yn|j9LA3a{#MIb1%$1 z+!g(zsXrT;FlDr4#Nja3_DYk$Qyaq8VBT2s8@l#8(aO_LS%XXBmbTLwy8EAi{#ZG_ zz~;JASHC0kF?k5oyjL1rx=X-0?QE=6_1Yhj1s7m}Iz_u$ANgSA6rG-$;81jE5?-?| z+`SjCJueMM!o|CLL@QM%#ul)AU$i0r^rc*CRx>lCpHES4iXw8jG;StT5 z3D{;?#+{VjX#3=hMRST&{ykJgbI=fJllDt*;hJEdWqLQJ)w}hU;JiHikq+sGLkc9M z!rFU6(Ni;rBZsD7cR%i=CAd~&KCyne<}F)ra$c$HAgvg!j|LL(@;}vyAIQ@KSW+ zP=PLZijl>M06Izua>|@$)eMMOC*}ON>vym@O1J4_EEJV^vq^_6@h%YyO$tx3O}_WU zs#&N4KJvy1?FEJK4c%jlI1jou7HqXy3F6sg?rEA9p$ly_G0fAwnT#ys$Oan3 zWmF|cwXm_0u`XjMdB@JFYB}mb>J2hu14nHcLA7ucn*(I3R*u>_LiHd=v8h0&+QL!W zN2oeEicJVI)mDxQ5fz40fnd348!w{i!67{*zo#Y_ZRbTaL(nT~6m^_yCoiH&!-|*X zs0Z;PHe@4Q2bjPYQp#BEF`h_sM7(M$tq~ zM;<||;f+P~*&YHuUD_A@^d2mxW*(ygjS7L0K%bXMl)Y-owKMS-7Q7(D3oPM*OfF_$ zqOhMAh7XOgj_YE@YTf}}oT@A~B6pl+T<=+;hZ3T5Qn3$@uQ+Fdn!-pha7t};>G!1z?`%M$Dtred!hVb`Z(Z%D9zFBN;s2>YZ|><=c`FX8MKtn2%R z;jxi0+maCS{$vc;ipLTrb|!>mq|#w)f_)$>l@8ky?AfWoDqyo;(`Doz!0-BKpln<@Mbb9JS-m#2rCDR_KAO z|3IljbAsNsmH$wj*3?7brbm><(%U5qLzu#YSw+)G0WbFFOn`J0mj%enE8^IzsGHr2 z#$(?D{qyt>S>Uf-#^U{Zx|i~yn9g-k=h_%5#Gh#B+UYh9e=Yh1=wE~dAUqNK69C-z zJ$nKEw2{=nWgT{uRHZ)lQzm4cQU4;9!~#@LGEGf_Mo%;=Qqe3(&=`n@>^n=&6$zSb zqPa2^&8Zj3DAI?~oCeETat&b!ZJr)5m^7;;Wk5Iwj z{BLnJco6J#=>xk4uY)p_r#~2)!}Bt8A3)mw z@9{Y8>taX^NNab*<3`BZhxB2j`Ojh{jkGu%kH3Pn7HJRoD!bzGTG03Gi^sPj9o`?0 ze~7g6ukrXy=umqg9`_=RA?-w(k2eHAL^=&`-A#ki#du}00cj83|K5tU7JpMegtP^J zl|Y+NdPq-tG;41cG%b3~WjSNB>12TLc)?HT#8w8o37ndZr7aTsxAlvVF8eq9qkO1K z$kpGFn}3t;g4Nlr!jCWg(KWM;Q;ClFm*Ug2JsxL_gv3L5KRzMgTZx}PM6(H>V*Fvl zSdPo*sb*~%Q*O+!-@057WDQ9N13oe6S4YgYT>bqSH(WTT;-SpV+F;hEF>PZrU(=0; zIhG4D9F*FyOB71*`3fIHUp#&l`Cvh=KA2(8&EJ%HW3J(T?Jc>}+Ojf!KR(xB$<0UE zf?V)r+^PcM3b2v?a?mcnL0-5ySHC&KmYW~UT##$nq_yTwyFbg4Ti7-xW6k*7X_j2W z0!5bP0$Xl&_wHBz@+#Qo;nNB^)sVA?bg3X+AZJsim2|O?E;r^DX8uaUq$D)5kOZ_d zSVG|Y3USebJlV1-1621D|H|>X`IcNg>LbHN6n_R<5wUX(`Td4m{X-ehYcuH_)Go-K zwuyAUe~g%0+&0#lyDYOxGd{N%D5wDv*#t6# zhwC6?7&0^o8Da{3tW3uK@n;FKP0rR85}$eK13y8k{RCe*WAtrf@6XsYCYZHZ`%ose*E)Qt|6H5{8hSm} zo-=*MKN5^mCEuBsfP*JZ3~(dTBYYS3Spv8q;|ld)=D4e)yj`SPA=qNWsRJnfAX3h<#)HR z7GMu8CFoIplSlcTEoJ?n%$Z+*8}||2cB~cX;nfNTyz)s8En(;x=H;{(qNkk~(E5QM zWoo-m`wO3!@_G#%?vbcm^G`QuxL6*D;%j)4bWQQi+FM=#jge>n|OaPa02?L z@>9`M|IQHfLg8$gknJ^*y_Ms)E|l?XFNrL+WsJU;Hi>xH{t)R6P>AF-*ce)1`#og4 zIlkmZ8PE1{$Yy28^d(oxc-kk>W8nDz$?^Q)L7+Vy$yeqnwq6JQt-x!ri@EJ*j6E?P z-UECpdp%rE1OClI^>Z)BZ%x8);P~Wu@nPVp-m;{6w=(?2f-*iTb2^Y$3-El%RZuJ~PlLag;q^j@#|hiRgJVY;`Ya5vspOj(UN3CDU9Ohxfswrpc>3LH z<@XW|aOY3zl9U>WF{t%RQ2|NenQ6CetSEOd<{CP0rG_;`KQy+Uy6xz z0@cjh1!U}*4ty$oZUBBNr%@+Wj9xDcT_s~_?@dol8aZo#pOA5xkUWmv&+up4`%B=- z&&uylwle8B3nV<`D8FN=%%`6NpGyC1tOrv0^K#%X%9tQ1>lH3u_&(tkK$VsNpX&F@ z>Gwrb(f^d;^*HmAg=lD?#q+PgPenu||Gs%w8aW5k;AffuB>w6|GNG=|4LSzAO#ijB%T_qf4%l z&dli9kcR#_;7QM9KS$HhzYRRav+{ebtxUQgd=CrhRQmrAc%oN+$F+>p`_j<=Aq_r~ z27fpWKAr}DDHizTpCP-f2s>Xzb}qwnEgMTv({wVMUm0Trz0YX2O-=?O) 
z7pK8HflnpBJ`MiCG;)M;SEJul;KPq_S1tEA{Z%W3y1KdoS3?7SW8A!Q)(ZT(w{t~P z+08X&H+#wpor(X;YCOtQ;Pl}a#B1vDTifi{;b%iu^-r3c{LWf;RfGK&Gk&CptgG@mJ-DN)^7|U?W!`fC+*NKD<-O%qcCQVO*Lw?Un&Wi(8k$UVoOTf7 zN6?#^7@N(uiv5m!xy4ywa#qZ87I^Se<4Wywo%}cHal_?o@-_PZAqD zbS_E3!`k0bj27^D>hW{$h^;1nGhVfH3DxdK_XGJyO-AxF(YM*n3*S8!2 zQH=m-a#l6B2rh(^7k8M|)TF5>=*Q#F>Xo|aq6{pb$VD@>+Y zaton>JoPJ5wIezP8q-F-j#c0{74b%`;jNmg@r!&7?s{9<%~UqqY4xSM5&H z47$v9)i(O-eY2f+%rebtT7#Bi&E&dwcD--4ZIzm#06%Q+TvEKUxvJjpS?eyVDf5&& z%bheH0AF!Ox!)`+=kXy1Qv{`27&nKS#U9Qa6%jR!dc0jXeWQ7*m_p6{Z?m_zuHspY;JNx7+R{zxg0TE?{ro9 zU9}|(D=RGv>`qg`Y+i_{tn$JivUSc`1+xm~unr{Dx$BzT{!w@Yc9YLpTZMk>R%ySg z3Y|b!p8RN#lp9uG)K@HB}yOqIS58 z4X{pUMd?Ci>{qEQeU96mGn=w2V#NUauRv5`=R;NM}RHIgG5MeKKlXEV9F$ zX9uP%`dl8BSj0rDc&ZVXS6_jMX7USDTcv=7Gf3hXuqg_wFjQK-QCn& z2R#-pbk2r%SvW069DCf2a%?F+C(UVU#nO_7SB%SNwHh2%N6TasjdhD|`Sc z)8uvu1(=)wve~qt(Z?1n1@2nDPOYtmGzCI6E*D#SXvRCx(2%TV$phtz;b zrgBxC2X)EPF*pcVx7XESeXQ&~|4&jnx2K~i95=8hpPw@*sd6vJiy5twr}X>y;3y}- zU(ETHw2+8Y4~|(zJ#zu6`D;19k{U-*(^>QAhqnJ}e&xPGNe!dP&*EQ<^5eKmQ|wpn z8_CLz`7jS{f{i>2G_p!uJcgD&5zXgs|ROAZ=0Q4BxH=XE>dXnSkF_fzL z4F;K9N!h!bDhW8zWsj2nDRP%>8zmm2pj!oiI_@7JS zSMI@-^kNHVOyXCtFD3CS=iy4K-0Lg!$^MIS{)CzhoRX*g3L{SqNif|5lKl$5a$lsR zPbJYO*MBgH|3T^yc$9Qr5`z*yGL;~Fgp9iW2*3AIQtF@TCrO-e0)_N*@2!lTPMk+Q3ZH_U z0cOG}`j>2xOY&QJnpA&@j81&h_(mj!35=Xk_-7>H6>;iRiCs0R*eXLq