qwopqwop200 2023-09-06 16:39:22 +09:00 committed by GitHub
parent 1793227283
commit f752336cda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 154 additions and 161 deletions
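Every hunk below makes the same change: the kernels used to add one to each unpacked zero point before subtracting it from the quantized weight, and now they use the packed value as-is. A minimal host-side sketch of the before/after dequantization for a single 4-bit weight (illustrative helpers, not the kernel code):

#include <cstdint>
#include <cstdio>

// Before this commit: the unpacked zero field got an extra + 1 before use.
static float dequant_old(uint32_t q, uint32_t zero_field, float scale) {
    return scale * (float(q) - float(zero_field + 1));
}

// After this commit: the unpacked zero field is used directly.
static float dequant_new(uint32_t q, uint32_t zero_field, float scale) {
    return scale * (float(q) - float(zero_field));
}

int main() {
    // q = 5, packed zero field = 8, scale = 0.1: the results differ by exactly one scale step.
    std::printf("old = %f, new = %f\n", dequant_old(5, 8, 0.1f), dequant_new(5, 8, 0.1f));
    return 0;
}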


@@ -30,9 +30,9 @@
// }
// #endif
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
unsigned int old = *address_as_ui;
@@ -113,6 +113,7 @@ __global__ void VecQuant4MatMulKernel(
int zero_width
);
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
@@ -312,7 +313,7 @@ __global__ void VecQuant2MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
@@ -446,12 +447,12 @@ __global__ void VecQuant3MatMulKernel(
scalar_t zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = scalar_t((z_tmp) + 1);
zero = scalar_t(z_tmp);
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = scalar_t((z_tmp) + 1);
zero = scalar_t(z_tmp);
} else {
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
}
if (k_mod == 10) {
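The three-way branch above exists because 3-bit fields do not tile a 32-bit word: ten zero points fit with two bits left over, so the columns at z_mod == 10 and z_mod == 21 have their zero split across adjacent words and must be stitched together before use. A small host-side check of the first straddle case (values chosen for illustration):

#include <cassert>
#include <cstdint>

int main() {
    // For z_mod == 10, the low two bits of the zero sit in bits 30..31 of one
    // word and the top bit in bit 0 of the next word.
    uint32_t zeros[2] = { 0x80000000u, 0x1u };               // bits 30..31 = 0b10, next word bit 0 = 1
    uint32_t z_tmp = (zeros[0] >> 30) | ((zeros[1] << 2) & 0x4);
    assert(z_tmp == 0b110);                                  // stitched 3-bit zero, used without + 1
    return 0;
}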
@@ -545,7 +546,7 @@ __global__ void VecQuant4MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
@@ -632,7 +633,7 @@ __global__ void VecQuant8MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
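In the act-order kernels above, the group index comes from g_idx and the zero point for a column is a bit-field inside a packed uint32_t row, with z_w selecting the word and z_mod the shift. A scalar sketch of the 4-bit lookup as it reads after this change (the index arithmetic is an assumption for illustration: eight 4-bit zero points per word):

#include <cstdint>

static uint32_t zero_4bit(const uint32_t* zeros, int g, int zero_width, int w) {
    int z_w   = w / 8;          // packed word that holds column w's zero point
    int z_mod = (w % 8) * 4;    // bit offset of that nibble within the word
    return (zeros[g * zero_width + z_w] >> z_mod) & 0xF;   // no "+ 1" any more
}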
@@ -723,7 +724,7 @@ __global__ void VecQuant2MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
@@ -846,12 +847,12 @@ __global__ void VecQuant3MatMulKernel_old(
scalar_t zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = scale * scalar_t((z_tmp) + 1);
zero = scale * scalar_t(z_tmp);
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = scale * scalar_t((z_tmp) + 1);
zero = scale * scalar_t(z_tmp);
} else {
zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
}
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
@@ -977,7 +978,7 @@ __global__ void VecQuant4MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
@@ -1064,7 +1065,7 @@ __global__ void VecQuant8MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
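The _old kernels fold the scale into the zero once per group (zero = scale * z) and then accumulate scale * q - zero per element, so dropping the + 1 only changes the value of z, not the algebra: scale * q - scale * z equals scale * (q - z). A quick host check of that identity over the 2-bit range (exact here because the scale is a power of two):

#include <cassert>

int main() {
    float scale = 0.25f;
    for (int q = 0; q < 4; ++q) {
        for (int z = 0; z < 4; ++z) {
            float folded = scale * float(q) - scale * float(z);   // what the _old kernels compute
            float direct = scale * (float(q) - float(z));
            assert(folded == direct);
        }
    }
    return 0;
}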
@@ -1159,7 +1160,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
@@ -1287,12 +1288,12 @@ __global__ void VecQuant3MatMulKernelFaster_old(
half2 zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
zero = __float2half2_rn(-(scale_f * z_tmp));
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
zero = __float2half2_rn(-(scale_f * z_tmp));
} else {
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
}
std::memset(&res2, 0, sizeof(half2));
@@ -1410,13 +1411,8 @@ __global__ void VecQuant4MatMulKernelFaster_old(
while (k < blockwidth2) {
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
//std::memset(&res2, 0, sizeof(half2));
//res2 = __float2half2_rn((float)0.);
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
@@ -1426,9 +1422,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
i += width;
k += 4;
res += __low2float(res2) + __high2float(res2);
}
atomicAdd(&mul[b * width + w], res);
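The Faster_old kernels keep the zero pre-negated (zero = -(scale_f * z)) so each element costs a single fused multiply-add: fma(q, scale, -scale * z) equals scale * (q - z), which then feeds the outer FMA against the activation. A float stand-in for one lane of the __hfma2 chain (names are illustrative, not the device code):

#include <cmath>

static float fma_lane(float acc, float q, float scale, float neg_scaled_zero, float x) {
    // Inner FMA dequantizes the weight, outer FMA multiplies by the activation and accumulates.
    return std::fma(std::fma(q, scale, neg_scaled_zero), x, acc);
}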


@@ -313,7 +313,7 @@ __global__ void VecQuant2MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
@@ -447,12 +447,12 @@ __global__ void VecQuant3MatMulKernel(
scalar_t zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = scalar_t((z_tmp) + 1);
zero = scalar_t(z_tmp);
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = scalar_t((z_tmp) + 1);
zero = scalar_t(z_tmp);
} else {
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
}
if (k_mod == 10) {
@@ -546,7 +546,7 @@ __global__ void VecQuant4MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
@@ -633,7 +633,7 @@ __global__ void VecQuant8MatMulKernel(
g = as_int(g_idx[g_h + k]);
scalar_t scale = scales[g * width + w];
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
@@ -724,7 +724,7 @@ __global__ void VecQuant2MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
@@ -847,12 +847,12 @@ __global__ void VecQuant3MatMulKernel_old(
scalar_t zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = scale * scalar_t((z_tmp) + 1);
zero = scale * scalar_t(z_tmp);
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = scale * scalar_t((z_tmp) + 1);
zero = scale * scalar_t(z_tmp);
} else {
zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
}
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
@@ -978,7 +978,7 @@ __global__ void VecQuant4MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
@@ -1065,7 +1065,7 @@ __global__ void VecQuant8MatMulKernel_old(
int g = (g_h + k) / groupsize;
scalar_t scale = scales[g * width + w];
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
@@ -1160,7 +1160,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);
@@ -1288,12 +1288,12 @@ __global__ void VecQuant3MatMulKernelFaster_old(
half2 zero;
if (z_mod == 10) {
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
zero = __float2half2_rn(-(scale_f * z_tmp));
} else if (z_mod == 21){
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
zero = __float2half2_rn(-(scale_f * z_tmp));
} else {
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
}
std::memset(&res2, 0, sizeof(half2));
@@ -1412,7 +1412,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
int g = (g_h + (k * 2)) / groupsize;
float scale_f = scales[g * width + w];
half2 scale = __float2half2_rn(scale_f);
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
std::memset(&res2, 0, sizeof(half2));
tmp = as_unsigned(mat[i]);


@@ -87,7 +87,7 @@ __global__ void q4_matmul_kernel
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = w_zeros_.item(group, w_column);
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -95,7 +95,7 @@ __global__ void q4_matmul_kernel
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = w_zeros_.item(group, w_column);
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -112,7 +112,7 @@ __global__ void q4_matmul_kernel
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = w_zeros_.item(group, w_column);
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
@@ -121,7 +121,7 @@ __global__ void q4_matmul_kernel
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = w_zeros_.item(group, w_column);
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
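In the exllama matmul path, w_zero is now passed to the dot_product_8* helpers exactly as stored, and the helpers subtract it from every unpacked 4-bit weight before scaling. A simplified scalar stand-in for that contract (not the real half2 implementation; names and layout are assumptions):

#include <cstdint>

static float dot_product_sketch(const float* x, const uint32_t* w_packed, int n_words,
                                float w_scale, uint32_t w_zero) {
    float acc = 0.0f;
    for (int i = 0; i < n_words; ++i) {
        uint32_t word = w_packed[i];
        for (int j = 0; j < 8; ++j) {                        // eight 4-bit weights per word
            uint32_t q = (word >> (4 * j)) & 0xF;
            acc += x[i * 8 + j] * w_scale * (float(q) - float(w_zero));
        }
    }
    return acc;
}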


@@ -197,7 +197,7 @@ __global__ void reconstruct_kernel
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = w_zeros_.item(group, column);
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
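reconstruct_kernel expands the packed matrix back to fp16 with the same scale * (q - zero) formula, now taking the stored zero point directly. A host round-trip sketch under the new convention (the quantize helper is an assumption for illustration; packing code is not part of this diff):

#include <algorithm>
#include <cmath>
#include <cstdint>

static uint32_t quantize4(float w, float scale, uint32_t zero) {
    int q = int(std::lround(w / scale)) + int(zero);         // assumed packing-side convention
    return uint32_t(std::clamp(q, 0, 15));
}

static float reconstruct4(uint32_t q, float scale, uint32_t zero) {
    return scale * (float(q) - float(zero));                 // matches w_zero without the former + 1
}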


@@ -94,7 +94,7 @@ __device__ __forceinline__ half2 dot_product_8
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const uint32_t v_zero,
const int count
)
{
@@ -145,7 +145,7 @@ __device__ __forceinline__ half dot_product_8_h
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const uint32_t v_zero,
const int count
)
{
@@ -192,7 +192,7 @@ __device__ __forceinline__ half2 dot_product_8_x_map
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const uint32_t v_zero,
const int count,
const uint32_t* x_map
)
@@ -254,7 +254,7 @@ __device__ __forceinline__ half dot_product_8_x_map_h
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const uint32_t v_zero,
const int count,
const uint32_t* x_map
)


@@ -1162,7 +1162,7 @@ def unpack_zeros(bits):
res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)"
packed = 32//bits
mask = (2**bits)-1
res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n"
res += "{\n"
res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
if bits == 4:
res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n"
@@ -1179,15 +1179,14 @@ def unpack_zeros(bits):
res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n"
res += "__m256i z0 = _mm256_srlv_epi32(z, shift);\n"
res += "__m256i z1 = _mm256_and_si256(z0, mask);\n"
res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n"
res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n"
res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n"
res += "__m256 z2 = _mm256_cvtepi32_ps(z1);\n"
res += "_mm256_storeu_ps(&ov[i*m +j], z2);\n"
elif bits == 2:
res += f"for (int j = 0; j < m; j+={packed})"
res += "{\n"
res += f"for (int k = 0; k < {packed}; k++)"
res += "{\n"
res += f"ov[i*m + j+k] = (((zv[j/{packed}] >> ({bits}*k)) & {mask})+1);\n"
res += f"ov[i*m + j+k] = ((zv[j/{packed}] >> ({bits}*k)) & {mask});\n"
res += "}\n"
# res += "for(int j = 0; j < m; j+=16){\n"
# res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"