Compare commits
3 commits
main
...
weights_sh
Author | SHA1 | Date | |
---|---|---|---|
|
22af50bab0 | ||
|
fc1184e7bc | ||
|
bf70350153 |
15 changed files with 104 additions and 3048 deletions
|
@ -13,5 +13,3 @@ from .codegen import *
|
|||
from .baichuan import *
|
||||
from .internlm import *
|
||||
from .qwen import *
|
||||
from .mistral import *
|
||||
from .mpt import *
|
||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
|||
import json
|
||||
import warnings
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field, fields
|
||||
from logging import getLogger
|
||||
from os.path import join, isfile, isdir
|
||||
|
@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
|
|||
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
|
||||
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
|
||||
from transformers.utils.generic import ContextManagers
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from transformers.modeling_utils import no_init_weights, shard_checkpoint
|
||||
|
||||
from ._const import *
|
||||
from ._utils import *
|
||||
|
@ -467,6 +468,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
private: Optional[bool] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
create_pr: Optional[bool] = False,
|
||||
max_shard_size: str = "10GB",
|
||||
model_base_name: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Upload the model to the Hugging Face Hub.
|
||||
|
@ -504,7 +507,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
if save_dir is not None:
|
||||
logger.info(f"Saving model to {save_dir}")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata, max_shard_size, model_base_name)
|
||||
|
||||
repo_url = create_repo(
|
||||
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"
|
||||
|
@ -527,20 +530,55 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
repo_type="model",
|
||||
)
|
||||
|
||||
def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
|
||||
def save_quantized(
|
||||
self,
|
||||
save_dir: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
max_shard_size: str = "10GB",
|
||||
model_base_name: Optional[str] = None
|
||||
):
|
||||
"""save quantized model and configs to local disk"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not self.quantized:
|
||||
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
|
||||
|
||||
self.model.to(CPU)
|
||||
if model_base_name is None:
|
||||
model_base_name = (
|
||||
self.quantize_config.model_file_base_name or
|
||||
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
)
|
||||
|
||||
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
if use_safetensors:
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
if use_safetensors:
|
||||
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
else:
|
||||
model_save_name = model_base_name + ".bin"
|
||||
|
||||
# Shard checkpoint
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_dir):
|
||||
full_filename = join(save_dir, filename)
|
||||
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
||||
|
||||
if (
|
||||
filename.startswith(model_base_name)
|
||||
and isfile(full_filename)
|
||||
and filename not in shards.keys()
|
||||
and reg.fullmatch(filename_no_suffix) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
# Save the model
|
||||
for shard_file, shard in shards.items():
|
||||
if use_safetensors:
|
||||
if safetensors_metadata is None:
|
||||
safetensors_metadata = {}
|
||||
elif not isinstance(safetensors_metadata, dict):
|
||||
|
@ -556,13 +594,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
new_key = str(key)
|
||||
new_value = str(value)
|
||||
except Exception as e:
|
||||
raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
|
||||
raise TypeError(
|
||||
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
|
||||
if new_key in new_safetensors_metadata:
|
||||
logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
|
||||
logger.warning(
|
||||
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
|
||||
new_safetensors_metadata[new_key] = new_value
|
||||
safetensors_metadata = new_safetensors_metadata
|
||||
if converted_keys:
|
||||
logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
|
||||
logger.debug(
|
||||
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
|
||||
|
||||
# Format is required to enable Accelerate to load the metadata
|
||||
# otherwise it raises an OSError
|
||||
|
@ -576,10 +617,17 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
|
||||
safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
|
||||
|
||||
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
|
||||
safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
|
||||
else:
|
||||
model_save_name = model_base_name + ".bin"
|
||||
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
|
||||
torch.save(shard, join(save_dir, shard_file))
|
||||
|
||||
if index is not None:
|
||||
index_save_name = model_save_name + ".index.json"
|
||||
index_save_path = join(save_dir, index_save_name)
|
||||
# Save the index as well
|
||||
with open(index_save_path, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
self.model.config.save_pretrained(save_dir)
|
||||
self.quantize_config.save_pretrained(save_dir)
|
||||
|
@ -589,7 +637,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
|
||||
"""alias of save_quantized"""
|
||||
logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
|
||||
self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
|
|
@ -21,15 +21,11 @@ SUPPORTED_MODELS = [
|
|||
"baichuan",
|
||||
"internlm",
|
||||
"qwen",
|
||||
"mpt",
|
||||
]
|
||||
if compare_transformers_version("v4.28.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
if compare_transformers_version("v4.33.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("falcon")
|
||||
if compare_transformers_version("v4.34.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("mistral")
|
||||
|
||||
|
||||
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
|
||||
|
||||
|
|
|
@ -16,8 +16,6 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
|||
from .baichuan import BaiChuanGPTQForCausalLM
|
||||
from .internlm import InternLMGPTQForCausalLM
|
||||
from .qwen import QwenGPTQForCausalLM
|
||||
from .mistral import MistralGPTQForCausalLM
|
||||
from .mpt import MPTGPTQForCausalLM
|
||||
|
||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||
"bloom": BloomGPTQForCausalLM,
|
||||
|
@ -35,8 +33,6 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"baichuan": BaiChuanGPTQForCausalLM,
|
||||
"internlm": InternLMGPTQForCausalLM,
|
||||
"qwen": QwenGPTQForCausalLM,
|
||||
"mistral": MistralGPTQForCausalLM,
|
||||
"mpt": MPTGPTQForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
from ._base import *
|
||||
|
||||
|
||||
class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
layer_type = "MistralDecoderLayer"
|
||||
layers_block_name = "model.layers"
|
||||
outside_layer_modules = ["model.embed_tokens", "model.norm"]
|
||||
inside_layer_modules = [
|
||||
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
|
||||
["self_attn.o_proj"],
|
||||
["mlp.up_proj", "mlp.gate_proj"],
|
||||
["mlp.down_proj"],
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["MistralGPTQForCausalLM"]
|
|
@ -1,18 +0,0 @@
|
|||
from auto_gptq.modeling import BaseGPTQForCausalLM
|
||||
|
||||
|
||||
class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
layer_type = "MPTBlock"
|
||||
layers_block_name = "transformer.blocks"
|
||||
outside_layer_modules = [
|
||||
"transformer.wte", "transformer.norm_f"
|
||||
]
|
||||
|
||||
inside_layer_modules = [
|
||||
["attn.Wqkv"],
|
||||
["attn.out_proj"],
|
||||
["ffn.up_proj"],
|
||||
["ffn.down_proj"]
|
||||
]
|
||||
|
||||
__all__ = ["MPTGPTQForCausalLM"]
|
|
@ -219,7 +219,7 @@ class QuantLinear(nn.Module):
|
|||
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||
self.wf.unsqueeze(0)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(self.scales.shape)
|
||||
|
@ -228,7 +228,7 @@ class QuantLinear(nn.Module):
|
|||
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||
self.wf.unsqueeze(-1)
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
|
||||
torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
|
||||
elif self.bits == 3:
|
||||
zeros = self.qzeros.reshape(
|
||||
self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
|
||||
|
@ -267,10 +267,10 @@ class QuantLinear(nn.Module):
|
|||
g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
|
||||
weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
|
||||
weights = torch.cat(weights,dim=1)
|
||||
out = torch.matmul(x.to(weights.dtype), weights)
|
||||
out = torch.matmul(x.half(), weights)
|
||||
out = out.half().reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["QuantLinear"]
|
||||
|
|
|
@ -229,7 +229,7 @@ class QuantLinear(nn.Module):
|
|||
|
||||
if self.bits in [2,4,8]:
|
||||
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
@ -238,7 +238,7 @@ class QuantLinear(nn.Module):
|
|||
scales = scales.reshape(-1, 1, scales.shape[-1])
|
||||
|
||||
weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
|
||||
torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight)
|
||||
weight = weight.reshape(-1, self.group_size, weight.shape[2])
|
||||
elif self.bits == 3:
|
||||
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
|
||||
|
@ -266,10 +266,10 @@ class QuantLinear(nn.Module):
|
|||
weight = (scales * (weight - zeros))
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
out = torch.matmul(x.to(weight.dtype), weight)
|
||||
out = torch.matmul(x.half(), weight)
|
||||
out = out.half().reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["QuantLinear"]
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,480 +0,0 @@
|
|||
#include<omp.h>
|
||||
#include<immintrin.h>
|
||||
#include<fstream>
|
||||
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))
|
||||
inline
|
||||
void q2gemm_gs(const float* __restrict__ input,
|
||||
const int* __restrict__ W,
|
||||
const float* __restrict__ scales,
|
||||
const float* __restrict__ zeros,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ sums,
|
||||
float* __restrict__ output,
|
||||
const int n,
|
||||
const int m,
|
||||
const int t,
|
||||
const int nb,
|
||||
const int mb,
|
||||
const int tb,
|
||||
int ogtt,
|
||||
const int gs,
|
||||
const int cutoff){
|
||||
#pragma omp parallel num_threads(8)
|
||||
{
|
||||
int tid;
|
||||
const int mu = 16;
|
||||
const int nu = 1;
|
||||
const int tu = 32;
|
||||
const int on = n / nb;
|
||||
const int om = m / mb;
|
||||
const __m256i mask = _mm256_set1_epi32(3);
|
||||
tid = omp_get_thread_num();
|
||||
int tt = ogtt;
|
||||
if(tid >= cutoff){
|
||||
tt -= tb;
|
||||
}
|
||||
const int base_output = tid >= cutoff ?
|
||||
(tid-cutoff)*tt + (tt+tb)*cutoff:
|
||||
tid*tt;
|
||||
const int base_W = tid >= cutoff ?
|
||||
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/16:
|
||||
tid*tt*m/16;
|
||||
for(int j = 0; j < tt; j+=tb){
|
||||
for(int i = 0; i < on; i++) {
|
||||
for(int k = 0; k < om; k++) {
|
||||
for(int i1 = 0; i1 < nb; i1+=nu) {
|
||||
int j1 = 0;
|
||||
for(; j1 < tb-tu+1; j1+=tu) {
|
||||
for(int k1 = 0; k1 < mb; k1+=gs) {
|
||||
__m256 acc0_0 = _mm256_setzero_ps();
|
||||
__m256 acc0_8 = _mm256_setzero_ps();
|
||||
__m256 acc0_16 = _mm256_setzero_ps();
|
||||
__m256 acc0_24 = _mm256_setzero_ps();
|
||||
for(int k2 = k1; k2 < k1+gs; k2+=16)
|
||||
{
|
||||
__m256i w0 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+0]);
|
||||
__m256i w8 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+8]);
|
||||
__m256i w16 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+16]);
|
||||
__m256i w24 = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/16 + k*mb*tb/16 + k2*tb/16 + j1+24]);
|
||||
__m256 v0_15 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+15)*nb + i1+0]);
|
||||
__m256 v0_14 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+14)*nb + i1+0]);
|
||||
__m256 v0_13 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+13)*nb + i1+0]);
|
||||
__m256 v0_12 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+12)*nb + i1+0]);
|
||||
__m256 v0_11 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+11)*nb + i1+0]);
|
||||
__m256 v0_10 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+10)*nb + i1+0]);
|
||||
__m256 v0_9 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+9)*nb + i1+0]);
|
||||
__m256 v0_8 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+8)*nb + i1+0]);
|
||||
__m256i ws0_8 = _mm256_srli_epi32(w0, 16);
|
||||
__m256i ws8_8 = _mm256_srli_epi32(w8, 16);
|
||||
__m256i ws16_8 = _mm256_srli_epi32(w16, 16);
|
||||
__m256i ws24_8 = _mm256_srli_epi32(w24, 16);
|
||||
__m256i wsa0_8= _mm256_and_si256(ws0_8, mask);
|
||||
__m256i wsa8_8= _mm256_and_si256(ws8_8, mask);
|
||||
__m256i wsa16_8= _mm256_and_si256(ws16_8, mask);
|
||||
__m256i wsa24_8= _mm256_and_si256(ws24_8, mask);
|
||||
__m256 l0_8 = _mm256_cvtepi32_ps(wsa0_8);
|
||||
__m256 l8_8 = _mm256_cvtepi32_ps(wsa8_8);
|
||||
__m256 l16_8 = _mm256_cvtepi32_ps(wsa16_8);
|
||||
__m256 l24_8 = _mm256_cvtepi32_ps(wsa24_8);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_8, l0_8, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_8, l8_8, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_8, l16_8, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_8, l24_8, acc0_24);
|
||||
__m256i ws0_9 = _mm256_srli_epi32(w0, 18);
|
||||
__m256i ws8_9 = _mm256_srli_epi32(w8, 18);
|
||||
__m256i ws16_9 = _mm256_srli_epi32(w16, 18);
|
||||
__m256i ws24_9 = _mm256_srli_epi32(w24, 18);
|
||||
__m256i wsa0_9= _mm256_and_si256(ws0_9, mask);
|
||||
__m256i wsa8_9= _mm256_and_si256(ws8_9, mask);
|
||||
__m256i wsa16_9= _mm256_and_si256(ws16_9, mask);
|
||||
__m256i wsa24_9= _mm256_and_si256(ws24_9, mask);
|
||||
__m256 l0_9 = _mm256_cvtepi32_ps(wsa0_9);
|
||||
__m256 l8_9 = _mm256_cvtepi32_ps(wsa8_9);
|
||||
__m256 l16_9 = _mm256_cvtepi32_ps(wsa16_9);
|
||||
__m256 l24_9 = _mm256_cvtepi32_ps(wsa24_9);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_9, l0_9, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_9, l8_9, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_9, l16_9, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_9, l24_9, acc0_24);
|
||||
__m256i ws0_10 = _mm256_srli_epi32(w0, 20);
|
||||
__m256i ws8_10 = _mm256_srli_epi32(w8, 20);
|
||||
__m256i ws16_10 = _mm256_srli_epi32(w16, 20);
|
||||
__m256i ws24_10 = _mm256_srli_epi32(w24, 20);
|
||||
__m256i wsa0_10= _mm256_and_si256(ws0_10, mask);
|
||||
__m256i wsa8_10= _mm256_and_si256(ws8_10, mask);
|
||||
__m256i wsa16_10= _mm256_and_si256(ws16_10, mask);
|
||||
__m256i wsa24_10= _mm256_and_si256(ws24_10, mask);
|
||||
__m256 l0_10 = _mm256_cvtepi32_ps(wsa0_10);
|
||||
__m256 l8_10 = _mm256_cvtepi32_ps(wsa8_10);
|
||||
__m256 l16_10 = _mm256_cvtepi32_ps(wsa16_10);
|
||||
__m256 l24_10 = _mm256_cvtepi32_ps(wsa24_10);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_10, l0_10, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_10, l8_10, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_10, l16_10, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_10, l24_10, acc0_24);
|
||||
__m256i ws0_11 = _mm256_srli_epi32(w0, 22);
|
||||
__m256i ws8_11 = _mm256_srli_epi32(w8, 22);
|
||||
__m256i ws16_11 = _mm256_srli_epi32(w16, 22);
|
||||
__m256i ws24_11 = _mm256_srli_epi32(w24, 22);
|
||||
__m256i wsa0_11= _mm256_and_si256(ws0_11, mask);
|
||||
__m256i wsa8_11= _mm256_and_si256(ws8_11, mask);
|
||||
__m256i wsa16_11= _mm256_and_si256(ws16_11, mask);
|
||||
__m256i wsa24_11= _mm256_and_si256(ws24_11, mask);
|
||||
__m256 l0_11 = _mm256_cvtepi32_ps(wsa0_11);
|
||||
__m256 l8_11 = _mm256_cvtepi32_ps(wsa8_11);
|
||||
__m256 l16_11 = _mm256_cvtepi32_ps(wsa16_11);
|
||||
__m256 l24_11 = _mm256_cvtepi32_ps(wsa24_11);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_11, l0_11, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_11, l8_11, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_11, l16_11, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_11, l24_11, acc0_24);
|
||||
__m256i ws0_12 = _mm256_srli_epi32(w0, 24);
|
||||
__m256i ws8_12 = _mm256_srli_epi32(w8, 24);
|
||||
__m256i ws16_12 = _mm256_srli_epi32(w16, 24);
|
||||
__m256i ws24_12 = _mm256_srli_epi32(w24, 24);
|
||||
__m256i wsa0_12= _mm256_and_si256(ws0_12, mask);
|
||||
__m256i wsa8_12= _mm256_and_si256(ws8_12, mask);
|
||||
__m256i wsa16_12= _mm256_and_si256(ws16_12, mask);
|
||||
__m256i wsa24_12= _mm256_and_si256(ws24_12, mask);
|
||||
__m256 l0_12 = _mm256_cvtepi32_ps(wsa0_12);
|
||||
__m256 l8_12 = _mm256_cvtepi32_ps(wsa8_12);
|
||||
__m256 l16_12 = _mm256_cvtepi32_ps(wsa16_12);
|
||||
__m256 l24_12 = _mm256_cvtepi32_ps(wsa24_12);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_12, l0_12, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_12, l8_12, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_12, l16_12, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_12, l24_12, acc0_24);
|
||||
__m256i ws0_13 = _mm256_srli_epi32(w0, 26);
|
||||
__m256i ws8_13 = _mm256_srli_epi32(w8, 26);
|
||||
__m256i ws16_13 = _mm256_srli_epi32(w16, 26);
|
||||
__m256i ws24_13 = _mm256_srli_epi32(w24, 26);
|
||||
__m256i wsa0_13= _mm256_and_si256(ws0_13, mask);
|
||||
__m256i wsa8_13= _mm256_and_si256(ws8_13, mask);
|
||||
__m256i wsa16_13= _mm256_and_si256(ws16_13, mask);
|
||||
__m256i wsa24_13= _mm256_and_si256(ws24_13, mask);
|
||||
__m256 l0_13 = _mm256_cvtepi32_ps(wsa0_13);
|
||||
__m256 l8_13 = _mm256_cvtepi32_ps(wsa8_13);
|
||||
__m256 l16_13 = _mm256_cvtepi32_ps(wsa16_13);
|
||||
__m256 l24_13 = _mm256_cvtepi32_ps(wsa24_13);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_13, l0_13, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_13, l8_13, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_13, l16_13, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_13, l24_13, acc0_24);
|
||||
__m256i ws0_14 = _mm256_srli_epi32(w0, 28);
|
||||
__m256i ws8_14 = _mm256_srli_epi32(w8, 28);
|
||||
__m256i ws16_14 = _mm256_srli_epi32(w16, 28);
|
||||
__m256i ws24_14 = _mm256_srli_epi32(w24, 28);
|
||||
__m256i wsa0_14= _mm256_and_si256(ws0_14, mask);
|
||||
__m256i wsa8_14= _mm256_and_si256(ws8_14, mask);
|
||||
__m256i wsa16_14= _mm256_and_si256(ws16_14, mask);
|
||||
__m256i wsa24_14= _mm256_and_si256(ws24_14, mask);
|
||||
__m256 l0_14 = _mm256_cvtepi32_ps(wsa0_14);
|
||||
__m256 l8_14 = _mm256_cvtepi32_ps(wsa8_14);
|
||||
__m256 l16_14 = _mm256_cvtepi32_ps(wsa16_14);
|
||||
__m256 l24_14 = _mm256_cvtepi32_ps(wsa24_14);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_14, l0_14, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_14, l8_14, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_14, l16_14, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_14, l24_14, acc0_24);
|
||||
__m256i ws0_15 = _mm256_srli_epi32(w0, 30);
|
||||
__m256i ws8_15 = _mm256_srli_epi32(w8, 30);
|
||||
__m256i ws16_15 = _mm256_srli_epi32(w16, 30);
|
||||
__m256i ws24_15 = _mm256_srli_epi32(w24, 30);
|
||||
__m256i wsa0_15= _mm256_and_si256(ws0_15, mask);
|
||||
__m256i wsa8_15= _mm256_and_si256(ws8_15, mask);
|
||||
__m256i wsa16_15= _mm256_and_si256(ws16_15, mask);
|
||||
__m256i wsa24_15= _mm256_and_si256(ws24_15, mask);
|
||||
__m256 l0_15 = _mm256_cvtepi32_ps(wsa0_15);
|
||||
__m256 l8_15 = _mm256_cvtepi32_ps(wsa8_15);
|
||||
__m256 l16_15 = _mm256_cvtepi32_ps(wsa16_15);
|
||||
__m256 l24_15 = _mm256_cvtepi32_ps(wsa24_15);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_15, l0_15, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_15, l8_15, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_15, l16_15, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_15, l24_15, acc0_24);
|
||||
__m256 v0_7 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+7)*nb + i1+0]);
|
||||
__m256 v0_6 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+6)*nb + i1+0]);
|
||||
__m256 v0_5 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+5)*nb + i1+0]);
|
||||
__m256 v0_4 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+4)*nb + i1+0]);
|
||||
__m256 v0_3 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+3)*nb + i1+0]);
|
||||
__m256 v0_2 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+2)*nb + i1+0]);
|
||||
__m256 v0_1 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+1)*nb + i1+0]);
|
||||
__m256 v0_0 = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+0)*nb + i1+0]);
|
||||
__m256i ws0_0 = _mm256_srli_epi32(w0, 0);
|
||||
__m256i ws8_0 = _mm256_srli_epi32(w8, 0);
|
||||
__m256i ws16_0 = _mm256_srli_epi32(w16, 0);
|
||||
__m256i ws24_0 = _mm256_srli_epi32(w24, 0);
|
||||
__m256i wsa0_0= _mm256_and_si256(ws0_0, mask);
|
||||
__m256i wsa8_0= _mm256_and_si256(ws8_0, mask);
|
||||
__m256i wsa16_0= _mm256_and_si256(ws16_0, mask);
|
||||
__m256i wsa24_0= _mm256_and_si256(ws24_0, mask);
|
||||
__m256 l0_0 = _mm256_cvtepi32_ps(wsa0_0);
|
||||
__m256 l8_0 = _mm256_cvtepi32_ps(wsa8_0);
|
||||
__m256 l16_0 = _mm256_cvtepi32_ps(wsa16_0);
|
||||
__m256 l24_0 = _mm256_cvtepi32_ps(wsa24_0);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_0, l0_0, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_0, l8_0, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_0, l16_0, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_0, l24_0, acc0_24);
|
||||
__m256i ws0_1 = _mm256_srli_epi32(w0, 2);
|
||||
__m256i ws8_1 = _mm256_srli_epi32(w8, 2);
|
||||
__m256i ws16_1 = _mm256_srli_epi32(w16, 2);
|
||||
__m256i ws24_1 = _mm256_srli_epi32(w24, 2);
|
||||
__m256i wsa0_1= _mm256_and_si256(ws0_1, mask);
|
||||
__m256i wsa8_1= _mm256_and_si256(ws8_1, mask);
|
||||
__m256i wsa16_1= _mm256_and_si256(ws16_1, mask);
|
||||
__m256i wsa24_1= _mm256_and_si256(ws24_1, mask);
|
||||
__m256 l0_1 = _mm256_cvtepi32_ps(wsa0_1);
|
||||
__m256 l8_1 = _mm256_cvtepi32_ps(wsa8_1);
|
||||
__m256 l16_1 = _mm256_cvtepi32_ps(wsa16_1);
|
||||
__m256 l24_1 = _mm256_cvtepi32_ps(wsa24_1);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_1, l0_1, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_1, l8_1, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_1, l16_1, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_1, l24_1, acc0_24);
|
||||
__m256i ws0_2 = _mm256_srli_epi32(w0, 4);
|
||||
__m256i ws8_2 = _mm256_srli_epi32(w8, 4);
|
||||
__m256i ws16_2 = _mm256_srli_epi32(w16, 4);
|
||||
__m256i ws24_2 = _mm256_srli_epi32(w24, 4);
|
||||
__m256i wsa0_2= _mm256_and_si256(ws0_2, mask);
|
||||
__m256i wsa8_2= _mm256_and_si256(ws8_2, mask);
|
||||
__m256i wsa16_2= _mm256_and_si256(ws16_2, mask);
|
||||
__m256i wsa24_2= _mm256_and_si256(ws24_2, mask);
|
||||
__m256 l0_2 = _mm256_cvtepi32_ps(wsa0_2);
|
||||
__m256 l8_2 = _mm256_cvtepi32_ps(wsa8_2);
|
||||
__m256 l16_2 = _mm256_cvtepi32_ps(wsa16_2);
|
||||
__m256 l24_2 = _mm256_cvtepi32_ps(wsa24_2);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_2, l0_2, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_2, l8_2, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_2, l16_2, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_2, l24_2, acc0_24);
|
||||
__m256i ws0_3 = _mm256_srli_epi32(w0, 6);
|
||||
__m256i ws8_3 = _mm256_srli_epi32(w8, 6);
|
||||
__m256i ws16_3 = _mm256_srli_epi32(w16, 6);
|
||||
__m256i ws24_3 = _mm256_srli_epi32(w24, 6);
|
||||
__m256i wsa0_3= _mm256_and_si256(ws0_3, mask);
|
||||
__m256i wsa8_3= _mm256_and_si256(ws8_3, mask);
|
||||
__m256i wsa16_3= _mm256_and_si256(ws16_3, mask);
|
||||
__m256i wsa24_3= _mm256_and_si256(ws24_3, mask);
|
||||
__m256 l0_3 = _mm256_cvtepi32_ps(wsa0_3);
|
||||
__m256 l8_3 = _mm256_cvtepi32_ps(wsa8_3);
|
||||
__m256 l16_3 = _mm256_cvtepi32_ps(wsa16_3);
|
||||
__m256 l24_3 = _mm256_cvtepi32_ps(wsa24_3);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_3, l0_3, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_3, l8_3, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_3, l16_3, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_3, l24_3, acc0_24);
|
||||
__m256i ws0_4 = _mm256_srli_epi32(w0, 8);
|
||||
__m256i ws8_4 = _mm256_srli_epi32(w8, 8);
|
||||
__m256i ws16_4 = _mm256_srli_epi32(w16, 8);
|
||||
__m256i ws24_4 = _mm256_srli_epi32(w24, 8);
|
||||
__m256i wsa0_4= _mm256_and_si256(ws0_4, mask);
|
||||
__m256i wsa8_4= _mm256_and_si256(ws8_4, mask);
|
||||
__m256i wsa16_4= _mm256_and_si256(ws16_4, mask);
|
||||
__m256i wsa24_4= _mm256_and_si256(ws24_4, mask);
|
||||
__m256 l0_4 = _mm256_cvtepi32_ps(wsa0_4);
|
||||
__m256 l8_4 = _mm256_cvtepi32_ps(wsa8_4);
|
||||
__m256 l16_4 = _mm256_cvtepi32_ps(wsa16_4);
|
||||
__m256 l24_4 = _mm256_cvtepi32_ps(wsa24_4);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_4, l0_4, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_4, l8_4, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_4, l16_4, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_4, l24_4, acc0_24);
|
||||
__m256i ws0_5 = _mm256_srli_epi32(w0, 10);
|
||||
__m256i ws8_5 = _mm256_srli_epi32(w8, 10);
|
||||
__m256i ws16_5 = _mm256_srli_epi32(w16, 10);
|
||||
__m256i ws24_5 = _mm256_srli_epi32(w24, 10);
|
||||
__m256i wsa0_5= _mm256_and_si256(ws0_5, mask);
|
||||
__m256i wsa8_5= _mm256_and_si256(ws8_5, mask);
|
||||
__m256i wsa16_5= _mm256_and_si256(ws16_5, mask);
|
||||
__m256i wsa24_5= _mm256_and_si256(ws24_5, mask);
|
||||
__m256 l0_5 = _mm256_cvtepi32_ps(wsa0_5);
|
||||
__m256 l8_5 = _mm256_cvtepi32_ps(wsa8_5);
|
||||
__m256 l16_5 = _mm256_cvtepi32_ps(wsa16_5);
|
||||
__m256 l24_5 = _mm256_cvtepi32_ps(wsa24_5);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_5, l0_5, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_5, l8_5, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_5, l16_5, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_5, l24_5, acc0_24);
|
||||
__m256i ws0_6 = _mm256_srli_epi32(w0, 12);
|
||||
__m256i ws8_6 = _mm256_srli_epi32(w8, 12);
|
||||
__m256i ws16_6 = _mm256_srli_epi32(w16, 12);
|
||||
__m256i ws24_6 = _mm256_srli_epi32(w24, 12);
|
||||
__m256i wsa0_6= _mm256_and_si256(ws0_6, mask);
|
||||
__m256i wsa8_6= _mm256_and_si256(ws8_6, mask);
|
||||
__m256i wsa16_6= _mm256_and_si256(ws16_6, mask);
|
||||
__m256i wsa24_6= _mm256_and_si256(ws24_6, mask);
|
||||
__m256 l0_6 = _mm256_cvtepi32_ps(wsa0_6);
|
||||
__m256 l8_6 = _mm256_cvtepi32_ps(wsa8_6);
|
||||
__m256 l16_6 = _mm256_cvtepi32_ps(wsa16_6);
|
||||
__m256 l24_6 = _mm256_cvtepi32_ps(wsa24_6);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_6, l0_6, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_6, l8_6, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_6, l16_6, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_6, l24_6, acc0_24);
|
||||
__m256i ws0_7 = _mm256_srli_epi32(w0, 14);
|
||||
__m256i ws8_7 = _mm256_srli_epi32(w8, 14);
|
||||
__m256i ws16_7 = _mm256_srli_epi32(w16, 14);
|
||||
__m256i ws24_7 = _mm256_srli_epi32(w24, 14);
|
||||
__m256i wsa0_7= _mm256_and_si256(ws0_7, mask);
|
||||
__m256i wsa8_7= _mm256_and_si256(ws8_7, mask);
|
||||
__m256i wsa16_7= _mm256_and_si256(ws16_7, mask);
|
||||
__m256i wsa24_7= _mm256_and_si256(ws24_7, mask);
|
||||
__m256 l0_7 = _mm256_cvtepi32_ps(wsa0_7);
|
||||
__m256 l8_7 = _mm256_cvtepi32_ps(wsa8_7);
|
||||
__m256 l16_7 = _mm256_cvtepi32_ps(wsa16_7);
|
||||
__m256 l24_7 = _mm256_cvtepi32_ps(wsa24_7);
|
||||
acc0_0 = _mm256_fmadd_ps(v0_7, l0_7, acc0_0);
|
||||
acc0_8 = _mm256_fmadd_ps(v0_7, l8_7, acc0_8);
|
||||
acc0_16 = _mm256_fmadd_ps(v0_7, l16_7, acc0_16);
|
||||
acc0_24 = _mm256_fmadd_ps(v0_7, l24_7, acc0_24);
|
||||
}
|
||||
__m256 o0_0 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+0]);
|
||||
__m256 o0_8 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+8]);
|
||||
__m256 o0_16 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+16]);
|
||||
__m256 o0_24 = _mm256_loadu_ps(&output[base_output + j + (i1+0)*t + j1+24]);
|
||||
__m256 s0_0 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+0]);
|
||||
__m256 s0_8 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+8]);
|
||||
__m256 s0_16 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+16]);
|
||||
__m256 s0_24 = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+24]);
|
||||
__m256 f0_0 = _mm256_fmadd_ps(acc0_0, s0_0, o0_0);
|
||||
__m256 f0_8 = _mm256_fmadd_ps(acc0_8, s0_8, o0_8);
|
||||
__m256 f0_16 = _mm256_fmadd_ps(acc0_16, s0_16, o0_16);
|
||||
__m256 f0_24 = _mm256_fmadd_ps(acc0_24, s0_24, o0_24);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+0], f0_0);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+8], f0_8);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+16], f0_16);
|
||||
_mm256_storeu_ps(&output[base_output + j + (i1+0)*t + j1+24], f0_24);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma omp barrier
|
||||
const int ngs = m/gs;
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < tt; j+=32){
|
||||
__m256 acc0 = _mm256_setzero_ps();
|
||||
__m256 acc8 = _mm256_setzero_ps();
|
||||
__m256 acc16 = _mm256_setzero_ps();
|
||||
__m256 acc24 = _mm256_setzero_ps();
|
||||
for (int i1 = 0; i1 < ngs; i1++){
|
||||
__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);
|
||||
__m256 z0 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 0]);
|
||||
__m256 z8 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 8]);
|
||||
__m256 z16 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 16]);
|
||||
__m256 z24 = _mm256_loadu_ps(&zeros[base_output + i1* t + j + 24]);
|
||||
__m256 s0 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 0]);
|
||||
__m256 s8 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 8]);
|
||||
__m256 s16 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 16]);
|
||||
__m256 s24 = _mm256_loadu_ps(&scales[base_output + i1 * t + j + 24]);
|
||||
__m256 zs0 = _mm256_mul_ps(z0, s0);
|
||||
__m256 zs8 = _mm256_mul_ps(z8, s8);
|
||||
__m256 zs16 = _mm256_mul_ps(z16, s16);
|
||||
__m256 zs24 = _mm256_mul_ps(z24, s24);
|
||||
acc0 = _mm256_fmadd_ps(zs0, r, acc0);
|
||||
acc8 = _mm256_fmadd_ps(zs8, r, acc8);
|
||||
acc16 = _mm256_fmadd_ps(zs16, r, acc16);
|
||||
acc24 = _mm256_fmadd_ps(zs24, r, acc24);
|
||||
}
|
||||
__m256 o0 = _mm256_loadu_ps(&output[i*t + base_output + j + 0]);
|
||||
__m256 o8 = _mm256_loadu_ps(&output[i*t + base_output + j + 8]);
|
||||
__m256 o16 = _mm256_loadu_ps(&output[i*t + base_output + j + 16]);
|
||||
__m256 o24 = _mm256_loadu_ps(&output[i*t + base_output + j + 24]);
|
||||
__m256 b0 = _mm256_loadu_ps(&bias[base_output + j + 0]);
|
||||
__m256 b8 = _mm256_loadu_ps(&bias[base_output + j + 8]);
|
||||
__m256 b16 = _mm256_loadu_ps(&bias[base_output + j + 16]);
|
||||
__m256 b24 = _mm256_loadu_ps(&bias[base_output + j + 24]);
|
||||
__m256 o10 = _mm256_add_ps(o0, acc0);
|
||||
__m256 o18 = _mm256_add_ps(o8, acc8);
|
||||
__m256 o116 = _mm256_add_ps(o16, acc16);
|
||||
__m256 o124 = _mm256_add_ps(o24, acc24);
|
||||
__m256 o20 = _mm256_add_ps(o10, b0);
|
||||
__m256 o28 = _mm256_add_ps(o18, b8);
|
||||
__m256 o216 = _mm256_add_ps(o116, b16);
|
||||
__m256 o224 = _mm256_add_ps(o124, b24);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 0], o20);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 8], o28);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 16], o216);
|
||||
_mm256_storeu_ps(&output[i*t + base_output + j + 24], o224);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void qforward(const float* __restrict__ input,
|
||||
const int* __restrict__ W,
|
||||
const float* __restrict__ scales,
|
||||
const float* __restrict__ zeros,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ sums,
|
||||
float* __restrict__ output,
|
||||
int n,
|
||||
int m,
|
||||
int t) {
|
||||
q2gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, 1, 1024, 32, 512, 64, 9);
|
||||
}
|
||||
inline void pack_input(float* A, float* B){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 1;
|
||||
const int M = 4096;
|
||||
const int nb = 1;
|
||||
const int mb = 1024;
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int j = 0; j < M; j+=mb){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void pack_qw_inner(int* A, int* B, int cutoff){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 256;
|
||||
const int M = 4096;
|
||||
const int nb = 64;
|
||||
int mb = 32;
|
||||
for(int j = 0, tid = 0; j < M; j+=mb, tid++){
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void pack_qw(int* A, int* B){
|
||||
pack_qw_inner(A, B, 65);
|
||||
}
|
||||
inline void pack_output(float* A, float* B){
|
||||
// copy the full matrix A in blocked format into B
|
||||
uint64_t idx = 0;
|
||||
const int N = 1;
|
||||
const int M = 4096;
|
||||
const int nb = 1;
|
||||
const int mb = 32;
|
||||
for(int i = 0; i < N; i+=nb){
|
||||
for(int j = 0; j < M; j+=mb){
|
||||
for(int ii = i; ii < mymin(i+nb, N); ii++){
|
||||
for(int jj = j; jj < mymin(j+mb, M); jj++){
|
||||
B[idx] = A[ii*M+jj];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void print_parameters(){
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
outfile << 2 << "," << 1 << "," << 16 << "," << 32 << "," << 8 << "," << 8 << "," << 64 << ",";
|
||||
}
|
Binary file not shown.
|
@ -1,37 +0,0 @@
|
|||
bits,nu,mu,tu,unroll,p,gs,time
|
||||
4,1,16,16,1,8,-1,1.3814e+06
|
||||
4,1,16,16,2,8,-1,1.44087e+06
|
||||
4,1,16,16,4,8,-1,1.56173e+06
|
||||
4,1,16,16,8,8,-1,1.41389e+06
|
||||
3,1,16,16,5,8,-1,2.14748e+09
|
||||
2,1,16,16,1,8,-1,1.09513e+06
|
||||
2,1,16,16,2,8,-1,1.11322e+06
|
||||
2,1,16,16,4,8,-1,1.12031e+06
|
||||
2,1,16,16,8,8,-1,1.19086e+06
|
||||
4,1,16,16,1,8,64,1.69111e+06
|
||||
4,1,16,16,2,8,64,1.60056e+06
|
||||
4,1,16,16,4,8,64,1.41263e+06
|
||||
4,1,16,16,8,8,64,1.74572e+06
|
||||
3,1,16,16,5,8,64,1.48062e+06
|
||||
2,1,16,16,1,8,64,1.51234e+06
|
||||
2,1,16,16,2,8,64,1.68108e+06
|
||||
2,1,16,16,4,8,64,1.7624e+06
|
||||
2,1,16,16,8,8,64,1.69563e+06
|
||||
4,1,16,32,1,8,-1,1.24798e+06
|
||||
4,1,16,32,2,8,-1,1.58421e+06
|
||||
4,1,16,32,4,8,-1,2.10718e+06
|
||||
4,1,16,32,8,8,-1,1.54288e+06
|
||||
3,1,16,32,5,8,-1,2.14748e+09
|
||||
2,1,16,32,1,8,-1,1.55906e+06
|
||||
2,1,16,32,2,8,-1,1.58576e+06
|
||||
2,1,16,32,4,8,-1,1.57993e+06
|
||||
2,1,16,32,8,8,-1,1.80443e+06
|
||||
4,1,16,32,1,8,64,1.58354e+06
|
||||
4,1,16,32,2,8,64,1.63248e+06
|
||||
4,1,16,32,4,8,64,1.91902e+06
|
||||
4,1,16,32,8,8,64,1.9243e+06
|
||||
3,1,16,32,5,8,64,1.33812e+06
|
||||
2,1,16,32,1,8,64,1.77522e+06
|
||||
2,1,16,32,2,8,64,1.54702e+06
|
||||
2,1,16,32,4,8,64,1.78772e+06
|
||||
2,1,16,32,8,8,64,1.49612e+06
|
|
BIN
qigen.tar.xz
BIN
qigen.tar.xz
Binary file not shown.
7
setup.py
7
setup.py
|
@ -80,7 +80,7 @@ requirements = [
|
|||
"gekko",
|
||||
"torch>=1.13.0",
|
||||
"safetensors",
|
||||
"transformers>=4.31.0",
|
||||
"transformers>=4.34.0",
|
||||
"peft",
|
||||
"tqdm",
|
||||
]
|
||||
|
@ -98,10 +98,7 @@ if BUILD_CUDA_EXT:
|
|||
|
||||
if platform.system() != 'Windows':
|
||||
p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
|
||||
ret = subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
|
||||
# if ret != 0:
|
||||
# raise Exception(f"Failed generate with {ret}")
|
||||
# sys.exit(-1)
|
||||
subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
|
||||
|
||||
if not ROCM_VERSION:
|
||||
from distutils.sysconfig import get_python_lib
|
||||
|
|
Loading…
Add table
Reference in a new issue