Compare commits


3 commits

Author       SHA1         Message                                                        Date
student686   22af50bab0   add new args of save_quantized method to push_to_hub method   2023-10-07 13:59:53 +08:00
student686   fc1184e7bc   save_quantized method support shard checkpoint                 2023-10-07 13:48:45 +08:00
student686   bf70350153   bump transformers version to 4.34.0                            2023-10-07 13:47:37 +08:00
15 changed files with 104 additions and 3048 deletions
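The first two commits extend save_quantized with sharded saving and forward the new arguments from push_to_hub. Below is a hedged usage sketch of the new max_shard_size argument, following the usual AutoGPTQ quantize-then-save flow; the model name, calibration text, output directory, and repo id are placeholders and not part of this change.

# Sketch only: exercises the new max_shard_size argument added by these commits.
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_name = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]

model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_name, BaseQuantizeConfig(bits=4, group_size=128))
model.quantize(examples)

# save_quantized now accepts max_shard_size (and model_base_name); checkpoints
# larger than max_shard_size are written as several shard files plus an index.
model.save_quantized("opt-125m-4bit-gptq", use_safetensors=True, max_shard_size="2GB")

# push_to_hub forwards the same arguments when it saves before uploading.
model.push_to_hub("username/opt-125m-4bit-gptq", save_dir="opt-125m-4bit-gptq",
                  use_safetensors=True, max_shard_size="2GB")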


@@ -12,6 +12,4 @@ from .gpt_bigcode import *
 from .codegen import *
 from .baichuan import *
 from .internlm import *
 from .qwen import *
-from .mistral import *
-from .mpt import *


@@ -2,6 +2,7 @@ import copy
 import json
 import warnings
 import os
+import re
 from dataclasses import dataclass, field, fields
 from logging import getLogger
 from os.path import join, isfile, isdir
@@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
 from transformers.utils.generic import ContextManagers
-from transformers.modeling_utils import no_init_weights
+from transformers.modeling_utils import no_init_weights, shard_checkpoint
 from ._const import *
 from ._utils import *
@@ -467,6 +468,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         private: Optional[bool] = None,
         token: Optional[Union[bool, str]] = None,
         create_pr: Optional[bool] = False,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
     ) -> str:
         """
         Upload the model to the Hugging Face Hub.
@@ -504,7 +507,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if save_dir is not None:
             logger.info(f"Saving model to {save_dir}")
-            self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+            self.save_quantized(save_dir, use_safetensors, safetensors_metadata, max_shard_size, model_base_name)

         repo_url = create_repo(
             repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"
@@ -527,59 +530,104 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             repo_type="model",
         )

-    def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+    def save_quantized(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
+    ):
         """save quantized model and configs to local disk"""
         os.makedirs(save_dir, exist_ok=True)

         if not self.quantized:
             raise EnvironmentError("can only save quantized model, please execute .quantize first.")

-        self.model.to(CPU)
+        if model_base_name is None:
+            model_base_name = (
+                self.quantize_config.model_file_base_name or
+                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+            )

-        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        state_dict = self.model.state_dict()
+
         if use_safetensors:
-            model_save_name = model_base_name + ".safetensors"
-            state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            if safetensors_metadata is None:
-                safetensors_metadata = {}
-            elif not isinstance(safetensors_metadata, dict):
-                raise TypeError("safetensors_metadata must be a dictionary.")
-            else:
-                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                new_safetensors_metadata = {}
-                converted_keys = False
-                for key, value in safetensors_metadata.items():
-                    if not isinstance(key, str) or not isinstance(value, str):
-                        converted_keys = True
-                        try:
-                            new_key = str(key)
-                            new_value = str(value)
-                        except Exception as e:
-                            raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                        if new_key in new_safetensors_metadata:
-                            logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                        new_safetensors_metadata[new_key] = new_value
-                safetensors_metadata = new_safetensors_metadata
-                if converted_keys:
-                    logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
-
-            # Format is required to enable Accelerate to load the metadata
-            # otherwise it raises an OSError
-            safetensors_metadata['format'] = "pt"
-
-            # Store the quantization configuration as safetensors metadata
-            from auto_gptq import __version__
-            safetensors_metadata['auto_gptq_version'] = str(__version__)
-            safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
-            safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
-            safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
-            safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
-            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+            model_save_name = model_base_name + ".safetensors"
         else:
             model_save_name = model_base_name + ".bin"
-            torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+        # Shard checkpoint
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_dir):
+            full_filename = join(save_dir, filename)
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(model_base_name)
+                and isfile(full_filename)
+                and filename not in shards.keys()
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+
+        # Save the model
+        for shard_file, shard in shards.items():
+            if use_safetensors:
+                if safetensors_metadata is None:
+                    safetensors_metadata = {}
+                elif not isinstance(safetensors_metadata, dict):
+                    raise TypeError("safetensors_metadata must be a dictionary.")
+                else:
+                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
+                    new_safetensors_metadata = {}
+                    converted_keys = False
+                    for key, value in safetensors_metadata.items():
+                        if not isinstance(key, str) or not isinstance(value, str):
+                            converted_keys = True
+                            try:
+                                new_key = str(key)
+                                new_value = str(value)
+                            except Exception as e:
+                                raise TypeError(
+                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+                            if new_key in new_safetensors_metadata:
+                                logger.warning(
+                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+                            new_safetensors_metadata[new_key] = new_value
+                    safetensors_metadata = new_safetensors_metadata
+                    if converted_keys:
+                        logger.debug(
+                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+
+                # Format is required to enable Accelerate to load the metadata
+                # otherwise it raises an OSError
+                safetensors_metadata['format'] = "pt"
+
+                # Store the quantization configuration as safetensors metadata
+                from auto_gptq import __version__
+                safetensors_metadata['auto_gptq_version'] = str(__version__)
+                safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
+                safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
+                safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
+                safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
+                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
+            else:
+                torch.save(shard, join(save_dir, shard_file))
+
+        if index is not None:
+            index_save_name = model_save_name + ".index.json"
+            index_save_path = join(save_dir, index_save_name)
+            # Save the index as well
+            with open(index_save_path, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)

         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
@@ -589,7 +637,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
         """alias of save_quantized"""
         logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
-        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+        self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)

     @classmethod
     def from_pretrained(
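For context, the sharding added above relies on the shard_checkpoint helper imported from transformers.modeling_utils. A small illustrative sketch of its behaviour follows; the tensor names, sizes, and weights_name are made up for the demo.

# Illustration only: shard_checkpoint splits a state_dict into pieces no larger
# than max_shard_size and returns (shards, index); index is None when one file suffices.
import torch
from transformers.modeling_utils import shard_checkpoint

state_dict = {f"layer.{i}.weight": torch.zeros(1024, 1024) for i in range(8)}  # ~32 MB total

shards, index = shard_checkpoint(
    state_dict,
    max_shard_size="10MB",                    # deliberately small for the demo
    weights_name="gptq_model-4bit-128g.bin",  # hypothetical model_save_name
)

for shard_file, shard in shards.items():
    # e.g. gptq_model-4bit-128g-00001-of-00004.bin, a few tensors per shard
    print(shard_file, len(shard))

# When index is not None it holds "metadata" and "weight_map", which
# save_quantized writes to <model_save_name>.index.json.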


@@ -21,15 +21,11 @@ SUPPORTED_MODELS = [
     "baichuan",
     "internlm",
     "qwen",
-    "mpt",
 ]

 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
 if compare_transformers_version("v4.33.0", op="ge"):
     SUPPORTED_MODELS.append("falcon")
-if compare_transformers_version("v4.34.0", op="ge"):
-    SUPPORTED_MODELS.append("mistral")

 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048


@@ -16,8 +16,6 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
 from .baichuan import BaiChuanGPTQForCausalLM
 from .internlm import InternLMGPTQForCausalLM
 from .qwen import QwenGPTQForCausalLM
-from .mistral import MistralGPTQForCausalLM
-from .mpt import MPTGPTQForCausalLM

 GPTQ_CAUSAL_LM_MODEL_MAP = {
     "bloom": BloomGPTQForCausalLM,
@@ -35,8 +33,6 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "baichuan": BaiChuanGPTQForCausalLM,
     "internlm": InternLMGPTQForCausalLM,
     "qwen": QwenGPTQForCausalLM,
-    "mistral": MistralGPTQForCausalLM,
-    "mpt": MPTGPTQForCausalLM,
 }


@@ -1,16 +0,0 @@
-from ._base import *
-
-class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
-    layer_type = "MistralDecoderLayer"
-    layers_block_name = "model.layers"
-    outside_layer_modules = ["model.embed_tokens", "model.norm"]
-    inside_layer_modules = [
-        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
-        ["self_attn.o_proj"],
-        ["mlp.up_proj", "mlp.gate_proj"],
-        ["mlp.down_proj"],
-    ]
-
-__all__ = ["MistralGPTQForCausalLM"]


@@ -1,18 +0,0 @@
-from auto_gptq.modeling import BaseGPTQForCausalLM
-
-class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
-    layer_type = "MPTBlock"
-    layers_block_name = "transformer.blocks"
-    outside_layer_modules = [
-        "transformer.wte", "transformer.norm_f"
-    ]
-    inside_layer_modules = [
-        ["attn.Wqkv"],
-        ["attn.out_proj"],
-        ["ffn.up_proj"],
-        ["ffn.down_proj"]
-    ]
-
-__all__ = ["MPTGPTQForCausalLM"]


@@ -219,7 +219,7 @@ class QuantLinear(nn.Module):
                 torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
                 self.wf.unsqueeze(0)
             ).to(torch.int16 if self.bits == 8 else torch.int8)
-            zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
+            torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

             zeros = zeros + 1
             zeros = zeros.reshape(self.scales.shape)
@@ -228,7 +228,7 @@ class QuantLinear(nn.Module):
                 torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
                 self.wf.unsqueeze(-1)
             ).to(torch.int16 if self.bits == 8 else torch.int8)
-            weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
+            torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
         elif self.bits == 3:
             zeros = self.qzeros.reshape(
                 self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
@@ -267,10 +267,10 @@ class QuantLinear(nn.Module):
                 g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
                 weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
             weights = torch.cat(weights,dim=1)
-        out = torch.matmul(x.to(weights.dtype), weights)
+        out = torch.matmul(x.half(), weights)
         out = out.half().reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.to(x.dtype)
+        return out

 __all__ = ["QuantLinear"]


@@ -229,7 +229,7 @@ class QuantLinear(nn.Module):
         if self.bits in [2,4,8]:
             zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
-            zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
+            torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

             zeros = zeros + 1
             zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
@@ -238,7 +238,7 @@ class QuantLinear(nn.Module):
             scales = scales.reshape(-1, 1, scales.shape[-1])

             weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
-            weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
+            torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight)
             weight = weight.reshape(-1, self.group_size, weight.shape[2])
         elif self.bits == 3:
             zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
@@ -266,10 +266,10 @@ class QuantLinear(nn.Module):
             weight = (scales * (weight - zeros))
             weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

-        out = torch.matmul(x.to(weight.dtype), weight)
+        out = torch.matmul(x.half(), weight)
         out = out.half().reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.to(x.dtype)
+        return out

 __all__ = ["QuantLinear"]

File diff suppressed because it is too large.


@@ -1,480 +0,0 @@
#include<omp.h>
#include<immintrin.h>
#include<fstream>
#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
inline
void q2gemm_gs(const float* __restrict__ input,
const int* __restrict__ W,
const float* __restrict__ scales,
const float* __restrict__ zeros,
const float* __restrict__ bias,
const float* __restrict__ sums,
float* __restrict__ output,
const int n,
const int m,
const int t,
const int nb,
const int mb,
const int tb,
int ogtt,
const int gs,
const int cutoff){
#pragma omp parallel num_threads(8)
{
int tid;
const int mu = 16;
const int nu = 1;
const int tu = 32;
const int on = n / nb;
const int om = m / mb;
const __m256i mask = _mm256_set1_epi32(3);
tid = omp_get_thread_num();
int tt = ogtt;
if(tid >= cutoff){
tt -= tb;
}
const int base_output = tid >= cutoff ?
(tid-cutoff)*tt + (tt+tb)*cutoff:
tid*tt;
const int base_W = tid >= cutoff ?
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16:
tid*tt*m/16;
for(int j = 0; j < tt; j+=tb){
for(int i = 0; i < on; i++) {
for(int k = 0; k < om; k++) {
for(int i1 = 0; i1 < nb; i1+=nu) {
int j1 = 0;
for(; j1 < tb-tu+1; j1+=tu) {
for(int k1 = 0; k1 < mb; k1+=gs) {
__m256 acc0_0 = _mm256_setzero_ps();
__m256 acc0_8 = _mm256_setzero_ps();
__m256 acc0_16 = _mm256_setzero_ps();
__m256 acc0_24 = _mm256_setzero_ps();
for(int k2 = k1; k2 < k1+gs; k2+=16)
{
__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]);
__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]);
__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]);
__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]);
__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]);
__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]);
__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]);
__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]);
__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]);
__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]);
__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]);
__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]);
__m256i ws0_8 = _mm256_srli_epi32(w0, 16);
__m256i ws8_8 = _mm256_srli_epi32(w8, 16);
__m256i ws16_8 = _mm256_srli_epi32(w16, 16);
__m256i ws24_8 = _mm256_srli_epi32(w24, 16);
__m256i wsa0_8= _mm256_and_si256(ws0_8, mask);
__m256i wsa8_8= _mm256_and_si256(ws8_8, mask);
__m256i wsa16_8= _mm256_and_si256(ws16_8, mask);
__m256i wsa24_8= _mm256_and_si256(ws24_8, mask);
__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8);
__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8);
__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8);
__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8);
acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24);
__m256i ws0_9 = _mm256_srli_epi32(w0, 18);
__m256i ws8_9 = _mm256_srli_epi32(w8, 18);
__m256i ws16_9 = _mm256_srli_epi32(w16, 18);
__m256i ws24_9 = _mm256_srli_epi32(w24, 18);
__m256i wsa0_9= _mm256_and_si256(ws0_9, mask);
__m256i wsa8_9= _mm256_and_si256(ws8_9, mask);
__m256i wsa16_9= _mm256_and_si256(ws16_9, mask);
__m256i wsa24_9= _mm256_and_si256(ws24_9, mask);
__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9);
__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9);
__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9);
__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9);
acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24);
__m256i ws0_10 = _mm256_srli_epi32(w0, 20);
__m256i ws8_10 = _mm256_srli_epi32(w8, 20);
__m256i ws16_10 = _mm256_srli_epi32(w16, 20);
__m256i ws24_10 = _mm256_srli_epi32(w24, 20);
__m256i wsa0_10= _mm256_and_si256(ws0_10, mask);
__m256i wsa8_10= _mm256_and_si256(ws8_10, mask);
__m256i wsa16_10= _mm256_and_si256(ws16_10, mask);
__m256i wsa24_10= _mm256_and_si256(ws24_10, mask);
__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10);
__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10);
__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10);
__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10);
acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24);
__m256i ws0_11 = _mm256_srli_epi32(w0, 22);
__m256i ws8_11 = _mm256_srli_epi32(w8, 22);
__m256i ws16_11 = _mm256_srli_epi32(w16, 22);
__m256i ws24_11 = _mm256_srli_epi32(w24, 22);
__m256i wsa0_11= _mm256_and_si256(ws0_11, mask);
__m256i wsa8_11= _mm256_and_si256(ws8_11, mask);
__m256i wsa16_11= _mm256_and_si256(ws16_11, mask);
__m256i wsa24_11= _mm256_and_si256(ws24_11, mask);
__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11);
__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11);
__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11);
__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11);
acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24);
__m256i ws0_12 = _mm256_srli_epi32(w0, 24);
__m256i ws8_12 = _mm256_srli_epi32(w8, 24);
__m256i ws16_12 = _mm256_srli_epi32(w16, 24);
__m256i ws24_12 = _mm256_srli_epi32(w24, 24);
__m256i wsa0_12= _mm256_and_si256(ws0_12, mask);
__m256i wsa8_12= _mm256_and_si256(ws8_12, mask);
__m256i wsa16_12= _mm256_and_si256(ws16_12, mask);
__m256i wsa24_12= _mm256_and_si256(ws24_12, mask);
__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12);
__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12);
__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12);
__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12);
acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24);
__m256i ws0_13 = _mm256_srli_epi32(w0, 26);
__m256i ws8_13 = _mm256_srli_epi32(w8, 26);
__m256i ws16_13 = _mm256_srli_epi32(w16, 26);
__m256i ws24_13 = _mm256_srli_epi32(w24, 26);
__m256i wsa0_13= _mm256_and_si256(ws0_13, mask);
__m256i wsa8_13= _mm256_and_si256(ws8_13, mask);
__m256i wsa16_13= _mm256_and_si256(ws16_13, mask);
__m256i wsa24_13= _mm256_and_si256(ws24_13, mask);
__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13);
__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13);
__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13);
__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13);
acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24);
__m256i ws0_14 = _mm256_srli_epi32(w0, 28);
__m256i ws8_14 = _mm256_srli_epi32(w8, 28);
__m256i ws16_14 = _mm256_srli_epi32(w16, 28);
__m256i ws24_14 = _mm256_srli_epi32(w24, 28);
__m256i wsa0_14= _mm256_and_si256(ws0_14, mask);
__m256i wsa8_14= _mm256_and_si256(ws8_14, mask);
__m256i wsa16_14= _mm256_and_si256(ws16_14, mask);
__m256i wsa24_14= _mm256_and_si256(ws24_14, mask);
__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14);
__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14);
__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14);
__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14);
acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24);
__m256i ws0_15 = _mm256_srli_epi32(w0, 30);
__m256i ws8_15 = _mm256_srli_epi32(w8, 30);
__m256i ws16_15 = _mm256_srli_epi32(w16, 30);
__m256i ws24_15 = _mm256_srli_epi32(w24, 30);
__m256i wsa0_15= _mm256_and_si256(ws0_15, mask);
__m256i wsa8_15= _mm256_and_si256(ws8_15, mask);
__m256i wsa16_15= _mm256_and_si256(ws16_15, mask);
__m256i wsa24_15= _mm256_and_si256(ws24_15, mask);
__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15);
__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15);
__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15);
__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15);
acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24);
__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]);
__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]);
__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]);
__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]);
__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]);
__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]);
__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]);
__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]);
__m256i ws0_0 = _mm256_srli_epi32(w0, 0);
__m256i ws8_0 = _mm256_srli_epi32(w8, 0);
__m256i ws16_0 = _mm256_srli_epi32(w16, 0);
__m256i ws24_0 = _mm256_srli_epi32(w24, 0);
__m256i wsa0_0= _mm256_and_si256(ws0_0, mask);
__m256i wsa8_0= _mm256_and_si256(ws8_0, mask);
__m256i wsa16_0= _mm256_and_si256(ws16_0, mask);
__m256i wsa24_0= _mm256_and_si256(ws24_0, mask);
__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0);
__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0);
__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0);
__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0);
acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24);
__m256i ws0_1 = _mm256_srli_epi32(w0, 2);
__m256i ws8_1 = _mm256_srli_epi32(w8, 2);
__m256i ws16_1 = _mm256_srli_epi32(w16, 2);
__m256i ws24_1 = _mm256_srli_epi32(w24, 2);
__m256i wsa0_1= _mm256_and_si256(ws0_1, mask);
__m256i wsa8_1= _mm256_and_si256(ws8_1, mask);
__m256i wsa16_1= _mm256_and_si256(ws16_1, mask);
__m256i wsa24_1= _mm256_and_si256(ws24_1, mask);
__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1);
__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1);
__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1);
__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1);
acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24);
__m256i ws0_2 = _mm256_srli_epi32(w0, 4);
__m256i ws8_2 = _mm256_srli_epi32(w8, 4);
__m256i ws16_2 = _mm256_srli_epi32(w16, 4);
__m256i ws24_2 = _mm256_srli_epi32(w24, 4);
__m256i wsa0_2= _mm256_and_si256(ws0_2, mask);
__m256i wsa8_2= _mm256_and_si256(ws8_2, mask);
__m256i wsa16_2= _mm256_and_si256(ws16_2, mask);
__m256i wsa24_2= _mm256_and_si256(ws24_2, mask);
__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2);
__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2);
__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2);
__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2);
acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24);
__m256i ws0_3 = _mm256_srli_epi32(w0, 6);
__m256i ws8_3 = _mm256_srli_epi32(w8, 6);
__m256i ws16_3 = _mm256_srli_epi32(w16, 6);
__m256i ws24_3 = _mm256_srli_epi32(w24, 6);
__m256i wsa0_3= _mm256_and_si256(ws0_3, mask);
__m256i wsa8_3= _mm256_and_si256(ws8_3, mask);
__m256i wsa16_3= _mm256_and_si256(ws16_3, mask);
__m256i wsa24_3= _mm256_and_si256(ws24_3, mask);
__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3);
__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3);
__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3);
__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3);
acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24);
__m256i ws0_4 = _mm256_srli_epi32(w0, 8);
__m256i ws8_4 = _mm256_srli_epi32(w8, 8);
__m256i ws16_4 = _mm256_srli_epi32(w16, 8);
__m256i ws24_4 = _mm256_srli_epi32(w24, 8);
__m256i wsa0_4= _mm256_and_si256(ws0_4, mask);
__m256i wsa8_4= _mm256_and_si256(ws8_4, mask);
__m256i wsa16_4= _mm256_and_si256(ws16_4, mask);
__m256i wsa24_4= _mm256_and_si256(ws24_4, mask);
__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4);
__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4);
__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4);
__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4);
acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24);
__m256i ws0_5 = _mm256_srli_epi32(w0, 10);
__m256i ws8_5 = _mm256_srli_epi32(w8, 10);
__m256i ws16_5 = _mm256_srli_epi32(w16, 10);
__m256i ws24_5 = _mm256_srli_epi32(w24, 10);
__m256i wsa0_5= _mm256_and_si256(ws0_5, mask);
__m256i wsa8_5= _mm256_and_si256(ws8_5, mask);
__m256i wsa16_5= _mm256_and_si256(ws16_5, mask);
__m256i wsa24_5= _mm256_and_si256(ws24_5, mask);
__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5);
__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5);
__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5);
__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5);
acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24);
__m256i ws0_6 = _mm256_srli_epi32(w0, 12);
__m256i ws8_6 = _mm256_srli_epi32(w8, 12);
__m256i ws16_6 = _mm256_srli_epi32(w16, 12);
__m256i ws24_6 = _mm256_srli_epi32(w24, 12);
__m256i wsa0_6= _mm256_and_si256(ws0_6, mask);
__m256i wsa8_6= _mm256_and_si256(ws8_6, mask);
__m256i wsa16_6= _mm256_and_si256(ws16_6, mask);
__m256i wsa24_6= _mm256_and_si256(ws24_6, mask);
__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6);
__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6);
__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6);
__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6);
acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24);
__m256i ws0_7 = _mm256_srli_epi32(w0, 14);
__m256i ws8_7 = _mm256_srli_epi32(w8, 14);
__m256i ws16_7 = _mm256_srli_epi32(w16, 14);
__m256i ws24_7 = _mm256_srli_epi32(w24, 14);
__m256i wsa0_7= _mm256_and_si256(ws0_7, mask);
__m256i wsa8_7= _mm256_and_si256(ws8_7, mask);
__m256i wsa16_7= _mm256_and_si256(ws16_7, mask);
__m256i wsa24_7= _mm256_and_si256(ws24_7, mask);
__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7);
__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7);
__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7);
__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7);
acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0);
acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8);
acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16);
acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24);
}
__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]);
__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]);
__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]);
__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]);
__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]);
__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]);
__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]);
__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]);
__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0);
__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8);
__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16);
__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16);
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24);
}
}
}
}
}
}
#pragma omp barrier
const int ngs = m/gs;
for (int i = 0; i < n; i++) {
for (int j = 0; j < tt; j+=32){
__m256 acc0 = _mm256_setzero_ps();
__m256 acc8 = _mm256_setzero_ps();
__m256 acc16 = _mm256_setzero_ps();
__m256 acc24 = _mm256_setzero_ps();
for (int i1 = 0; i1 < ngs; i1++){
__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]);
__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]);
__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]);
__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]);
__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]);
__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]);
__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]);
__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]);
__m256 zs0 = _mm256_mul_ps(z0, s0);
__m256 zs8 = _mm256_mul_ps(z8, s8);
__m256 zs16 = _mm256_mul_ps(z16, s16);
__m256 zs24 = _mm256_mul_ps(z24, s24);
acc0 = _mm256_fmadd_ps(zs0, r, acc0);
acc8 = _mm256_fmadd_ps(zs8, r, acc8);
acc16 = _mm256_fmadd_ps(zs16, r, acc16);
acc24 = _mm256_fmadd_ps(zs24, r, acc24);
}
__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]);
__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]);
__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]);
__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]);
__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]);
__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]);
__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]);
__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]);
__m256 o10 = _mm256_add_ps(o0, acc0);
__m256 o18 = _mm256_add_ps(o8, acc8);
__m256 o116 = _mm256_add_ps(o16, acc16);
__m256 o124 = _mm256_add_ps(o24, acc24);
__m256 o20 = _mm256_add_ps(o10, b0);
__m256 o28 = _mm256_add_ps(o18, b8);
__m256 o216 = _mm256_add_ps(o116, b16);
__m256 o224 = _mm256_add_ps(o124, b24);
_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20);
_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28);
_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216);
_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224);
}
}
}
}
inline void qforward(const float* __restrict__ input,
const int* __restrict__ W,
const float* __restrict__ scales,
const float* __restrict__ zeros,
const float* __restrict__ bias,
const float* __restrict__ sums,
float* __restrict__ output,
int n,
int m,
int t) {
q2gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, 1, 1024, 32, 512, 64, 9);
}
inline void pack_input(float* A, float* B){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 1;
const int M = 4096;
const int nb = 1;
const int mb = 1024;
for(int i = 0; i < N; i+=nb){
for(int j = 0; j < M; j+=mb){
for(int jj = j; jj < mymin(j+mb, M); jj++){
for(int ii = i; ii < mymin(i+nb, N); ii++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
inline void pack_qw_inner(int* A, int* B, int cutoff){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 256;
const int M = 4096;
const int nb = 64;
int mb = 32;
for(int j = 0, tid = 0; j < M; j+=mb, tid++){
for(int i = 0; i < N; i+=nb){
for(int ii = i; ii < mymin(i+nb, N); ii++){
for(int jj = j; jj < mymin(j+mb, M); jj++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
inline void pack_qw(int* A, int* B){
pack_qw_inner(A, B, 65);
}
inline void pack_output(float* A, float* B){
// copy the full matrix A in blocked format into B
uint64_t idx = 0;
const int N = 1;
const int M = 4096;
const int nb = 1;
const int mb = 32;
for(int i = 0; i < N; i+=nb){
for(int j = 0; j < M; j+=mb){
for(int ii = i; ii < mymin(i+nb, N); ii++){
for(int jj = j; jj < mymin(j+mb, M); jj++){
B[idx] = A[ii*M+jj];
idx++;
}
}
}
}
}
void print_parameters(){
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
outfile << 2 << "," << 1 << "," << 16 << "," << 32 << "," << 8 << "," << 8 << "," << 64 << ",";
}

Binary file not shown.


@@ -1,37 +0,0 @@
bits,nu,mu,tu,unroll,p,gs,time
4,1,16,16,1,8,-1,1.3814e+06
4,1,16,16,2,8,-1,1.44087e+06
4,1,16,16,4,8,-1,1.56173e+06
4,1,16,16,8,8,-1,1.41389e+06
3,1,16,16,5,8,-1,2.14748e+09
2,1,16,16,1,8,-1,1.09513e+06
2,1,16,16,2,8,-1,1.11322e+06
2,1,16,16,4,8,-1,1.12031e+06
2,1,16,16,8,8,-1,1.19086e+06
4,1,16,16,1,8,64,1.69111e+06
4,1,16,16,2,8,64,1.60056e+06
4,1,16,16,4,8,64,1.41263e+06
4,1,16,16,8,8,64,1.74572e+06
3,1,16,16,5,8,64,1.48062e+06
2,1,16,16,1,8,64,1.51234e+06
2,1,16,16,2,8,64,1.68108e+06
2,1,16,16,4,8,64,1.7624e+06
2,1,16,16,8,8,64,1.69563e+06
4,1,16,32,1,8,-1,1.24798e+06
4,1,16,32,2,8,-1,1.58421e+06
4,1,16,32,4,8,-1,2.10718e+06
4,1,16,32,8,8,-1,1.54288e+06
3,1,16,32,5,8,-1,2.14748e+09
2,1,16,32,1,8,-1,1.55906e+06
2,1,16,32,2,8,-1,1.58576e+06
2,1,16,32,4,8,-1,1.57993e+06
2,1,16,32,8,8,-1,1.80443e+06
4,1,16,32,1,8,64,1.58354e+06
4,1,16,32,2,8,64,1.63248e+06
4,1,16,32,4,8,64,1.91902e+06
4,1,16,32,8,8,64,1.9243e+06
3,1,16,32,5,8,64,1.33812e+06
2,1,16,32,1,8,64,1.77522e+06
2,1,16,32,2,8,64,1.54702e+06
2,1,16,32,4,8,64,1.78772e+06
2,1,16,32,8,8,64,1.49612e+06

Binary file not shown.


@@ -80,7 +80,7 @@ requirements = [
     "gekko",
     "torch>=1.13.0",
     "safetensors",
-    "transformers>=4.31.0",
+    "transformers>=4.34.0",
     "peft",
     "tqdm",
 ]
@@ -98,10 +98,7 @@ if BUILD_CUDA_EXT:
     if platform.system() != 'Windows':
         p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
-        ret = subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
-        # if ret != 0:
-        #     raise Exception(f"Failed generate with {ret}")
-        #     sys.exit(-1)
+        subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])

     if not ROCM_VERSION:
         from distutils.sysconfig import get_python_lib