install check qigen
This commit is contained in: parent 71d56c76d0, commit 45a1ee4d84
4 changed files with 53 additions and 34 deletions
@@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModul
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
 from ..utils.import_utils import (
-    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE
+    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
 )

 logger = getLogger(__name__)
@@ -727,13 +727,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
-        if use_qigen:
-            logger.warning("QIgen is active. Ignores all settings related to cuda.")
-            inject_fused_attention = False
-            inject_fused_mlp = False
-            use_triton = False
-            disable_exllama = True
-
+        if use_qigen and not QIGEN_AVAILABLE:
+            logger.warning("Qigen is not installed, reset use_qigen to False.")
+            use_qigen = False
         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
             use_triton = False
@@ -754,7 +750,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 "2. You are using pytorch without CUDA support.\n"
                 "3. CUDA and nvcc are not installed in your device."
             )

+        if use_qigen and QIGEN_AVAILABLE:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
+
         # == step1: prepare configs and file names == #
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)

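The hunks above split the old single check into a two-stage resolution inside the quantized-model loading path: a requested but uninstalled QIGen backend is first downgraded with a warning, and only a QIGen that is both requested and importable then switches off the CUDA-oriented options. A minimal sketch of that ordering, using hypothetical names (LoadOptions, resolve_backends) that are not part of AutoGPTQ:

# Hypothetical sketch, not AutoGPTQ API: shows the option-resolution order
# that the two hunks above establish in the loading path.
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class LoadOptions:
    use_qigen: bool = False
    use_triton: bool = False
    inject_fused_attention: bool = True
    inject_fused_mlp: bool = True
    disable_exllama: bool = False

def resolve_backends(opts: LoadOptions, qigen_available: bool, triton_available: bool) -> LoadOptions:
    # Stage 1: requests for backends that are not installed are reset with a warning.
    if opts.use_qigen and not qigen_available:
        logger.warning("Qigen is not installed, reset use_qigen to False.")
        opts.use_qigen = False
    if opts.use_triton and not triton_available:
        logger.warning("Triton is not installed, reset use_triton to False.")
        opts.use_triton = False
    # Stage 2: a usable QIGen install takes over and disables the CUDA-side
    # features, matching the block this commit moves below the CUDA checks.
    if opts.use_qigen and qigen_available:
        opts.inject_fused_attention = False
        opts.inject_fused_mlp = False
        opts.use_triton = False
        opts.disable_exllama = True
    return opts
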
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
-import cQIGen as qinfer

 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
@@ -105,28 +104,6 @@ def make_quant(
             use_qigen=use_qigen
         )

-def process_zeros_scales(zeros, scales, bits, out_features):
-    if zeros.dtype != torch.float32:
-        new_zeros = torch.zeros_like(scales).float().contiguous()
-        if bits == 4:
-            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 2:
-            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 3:
-            logger.info("Unpacking zeros for 3 bits")
-        new_scales = scales.contiguous()
-    else:
-        if scales.shape[1] != out_features:
-            new_scales = scales.transpose(0,1).contiguous()
-        else:
-            new_scales = scales.contiguous()
-        if zeros.shape[1] != out_features:
-            new_zeros = zeros.transpose(0,1).contiguous()
-        else:
-            new_zeros = zeros.contiguous()
-
-    return new_zeros, new_scales
-
 def preprocess_checkpoint_qigen(
     module,
     names,
@@ -135,12 +112,40 @@ def preprocess_checkpoint_qigen(
     checkpoint,
     name='',
 ):
+    try:
+        import cQIGen as qinfer
+    except ImportError:
+        logger.error('cQIGen not installed.')
+        raise
+
     QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
     if isinstance(module, QuantLinear):
         in_features = module.infeatures
         out_features = module.outfeatures

+        zeros = checkpoint[name + '.qzeros']
+        scales = checkpoint[name + '.scales'].float()
+
+        if zeros.dtype != torch.float32:
+            new_zeros = torch.zeros_like(scales).float().contiguous()
+            if bits == 4:
+                qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 2:
+                qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 3:
+                logger.info("Unpacking zeros for 3 bits")
+            new_scales = scales.contiguous()
+        else:
+            if scales.shape[1] != out_features:
+                new_scales = scales.transpose(0,1).contiguous()
+            else:
+                new_scales = scales.contiguous()
+            if zeros.shape[1] != out_features:
+                new_zeros = zeros.transpose(0,1).contiguous()
+            else:
+                new_zeros = zeros.contiguous()
+
-        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
         del checkpoint[name + '.qzeros']
         del checkpoint[name + '.g_idx']
         if name + '.bias' in checkpoint:

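Two things change in this file: process_zeros_scales is inlined into preprocess_checkpoint_qigen, and the cQIGen import moves inside that function, so the module can be imported without the compiled extension and only fails at the point where QIGen is actually needed. The float32 branch of the inlined logic is pure layout normalization; a rough standalone illustration follows, with a made-up helper name and without the bits-specific unpacking paths that require cQIGen:

# Illustration only (hypothetical helper, not in the repo): the float32 branch,
# which just puts out_features on the last axis and makes the tensors
# contiguous, as QIGen's kernels expect. The bits == 2/3/4 unpacking paths
# call into the compiled cQIGen extension and are omitted here.
import torch

def normalize_qparams(zeros: torch.Tensor, scales: torch.Tensor, out_features: int):
    # Tensors stored as (out_features, n_groups) get transposed to
    # (n_groups, out_features) before being made contiguous.
    if scales.shape[1] != out_features:
        scales = scales.transpose(0, 1)
    if zeros.shape[1] != out_features:
        zeros = zeros.transpose(0, 1)
    return zeros.contiguous(), scales.contiguous()
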
@@ -4,7 +4,6 @@ from torch import nn
 from tqdm import tqdm
 import gc

-import cQIGen as qinfer
 import math
 import numpy as np
 from gekko import GEKKO
@@ -12,6 +11,11 @@ from logging import getLogger

 logger = getLogger(__name__)

+try:
+    import cQIGen as qinfer
+except ImportError:
+    logger.error('cQIGen not installed.')
+    raise

 def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
     m = GEKKO() # create GEKKO model

@@ -25,6 +25,13 @@ try:
 except:
     EXLLAMA_KERNELS_AVAILABLE = False

+try:
+    import cQIGen as qinfer
+
+    QIGEN_AVAILABLE = True
+except:
+    QIGEN_AVAILABLE = False
+
 logger = getLogger(__name__)


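With QIGEN_AVAILABLE exported next to the other capability flags, callers can probe it the same way the loader now does. A hedged usage sketch: the flag import path follows the package layout shown above, and passing use_qigen to from_quantized is assumed from the variable the first file's hunks operate on; check the signature of your installed release before relying on it.

# Usage sketch under the assumptions stated above; the model id is only an example.
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.import_utils import QIGEN_AVAILABLE

model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/Llama-2-7B-GPTQ",   # any GPTQ checkpoint; substitute your own
    use_qigen=QIGEN_AVAILABLE,    # request the CPU QIGen kernels only if cQIGen compiled
    use_triton=False,
)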