install check qigen

qwopqwop200 2023-08-31 14:37:39 +09:00 committed by GitHub
parent 71d56c76d0
commit 45a1ee4d84
4 changed files with 53 additions and 34 deletions


@@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
 from ..utils.import_utils import (
-    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE
+    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
 )

 logger = getLogger(__name__)
@@ -727,13 +727,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
-        if use_qigen:
-            logger.warning("QIgen is active. Ignores all settings related to cuda.")
-            inject_fused_attention = False
-            inject_fused_mlp = False
-            use_triton = False
-            disable_exllama = True
-
+        if use_qigen and not QIGEN_AVAILABLE:
+            logger.warning("Qigen is not installed, reset use_qigen to False.")
+            use_qigen = False
         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
             use_triton = False
@@ -754,7 +750,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 "2. You are using pytorch without CUDA support.\n"
                 "3. CUDA and nvcc are not installed in your device."
             )

+        if use_qigen and QIGEN_AVAILABLE:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
+
         # == step1: prepare configs and file names == #
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
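Net effect of this file's change: use_qigen is now validated against the new QIGEN_AVAILABLE flag before any CUDA-related options are overridden, so a missing cQIGen install degrades to the default kernels instead of failing later. Below is a minimal sketch of the resulting control flow; the standalone helper name resolve_backend_options is hypothetical and only restates the guards visible in the diff, it is not part of the library:

import logging

logger = logging.getLogger(__name__)

def resolve_backend_options(use_qigen, use_triton, qigen_available, triton_available,
                            inject_fused_attention=True, inject_fused_mlp=True,
                            disable_exllama=False):
    # Fall back when the QIGen extension is requested but not installed.
    if use_qigen and not qigen_available:
        logger.warning("Qigen is not installed, reset use_qigen to False.")
        use_qigen = False
    # Same pattern already used for Triton.
    if use_triton and not triton_available:
        logger.warning("Triton is not installed, reset use_triton to False.")
        use_triton = False
    # Only when QIGen is both requested and available are CUDA paths switched off.
    if use_qigen and qigen_available:
        inject_fused_attention = False
        inject_fused_mlp = False
        use_triton = False
        disable_exllama = True
    return use_qigen, use_triton, inject_fused_attention, inject_fused_mlp, disable_exllama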


@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
-import cQIGen as qinfer

 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
@@ -105,28 +104,6 @@ def make_quant(
         use_qigen=use_qigen
     )

-
-def process_zeros_scales(zeros, scales, bits, out_features):
-    if zeros.dtype != torch.float32:
-        new_zeros = torch.zeros_like(scales).float().contiguous()
-        if bits == 4:
-            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 2:
-            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 3:
-            logger.info("Unpacking zeros for 3 bits")
-        new_scales = scales.contiguous()
-    else:
-        if scales.shape[1] != out_features:
-            new_scales = scales.transpose(0,1).contiguous()
-        else:
-            new_scales = scales.contiguous()
-        if zeros.shape[1] != out_features:
-            new_zeros = zeros.transpose(0,1).contiguous()
-        else:
-            new_zeros = zeros.contiguous()
-    return new_zeros, new_scales

 def preprocess_checkpoint_qigen(
     module,
     names,
@@ -135,12 +112,40 @@ def preprocess_checkpoint_qigen(
     checkpoint,
     name='',
 ):
+    try:
+        import cQIGen as qinfer
+    except ImportError:
+        logger.error('cQIGen not installed.')
+        raise
+
     QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
     if isinstance(module, QuantLinear):
         in_features = module.infeatures
         out_features = module.outfeatures
+        zeros = checkpoint[name + '.qzeros']
+        scales = checkpoint[name + '.scales'].float()
+        if zeros.dtype != torch.float32:
+            new_zeros = torch.zeros_like(scales).float().contiguous()
+            if bits == 4:
+                qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 2:
+                qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 3:
+                logger.info("Unpacking zeros for 3 bits")
+            new_scales = scales.contiguous()
+        else:
+            if scales.shape[1] != out_features:
+                new_scales = scales.transpose(0,1).contiguous()
+            else:
+                new_scales = scales.contiguous()
+            if zeros.shape[1] != out_features:
+                new_zeros = zeros.transpose(0,1).contiguous()
+            else:
+                new_zeros = zeros.contiguous()
-        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
         del checkpoint[name + '.qzeros']
         del checkpoint[name + '.g_idx']
         if name + '.bias' in checkpoint:
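The former process_zeros_scales helper is inlined here, and cQIGen is now imported lazily inside preprocess_checkpoint_qigen, so merely importing the module no longer requires the native extension. For checkpoints whose zeros are already float32, the normalization only transposes scales and zeros so their trailing dimension matches out_features. A minimal sketch of that branch with made-up shapes (pure PyTorch, no cQIGen needed; the tensor sizes are illustrative only):

import torch

out_features = 8
# Hypothetical float32 checkpoint tensors stored as (out_features, n_groups).
scales = torch.rand(out_features, 2)
zeros = torch.rand(out_features, 2)

# Same normalization as the float32 branch above: make dim 1 equal out_features.
new_scales = scales if scales.shape[1] == out_features else scales.transpose(0, 1).contiguous()
new_zeros = zeros if zeros.shape[1] == out_features else zeros.transpose(0, 1).contiguous()

assert new_scales.shape[1] == out_features and new_zeros.shape[1] == out_features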


@@ -4,7 +4,6 @@ from torch import nn
 from tqdm import tqdm
 import gc
-import cQIGen as qinfer
 import math
 import numpy as np
 from gekko import GEKKO
@@ -12,6 +11,11 @@ from logging import getLogger

 logger = getLogger(__name__)

+try:
+    import cQIGen as qinfer
+except ImportError:
+    logger.error('cQIGen not installed.')
+    raise

 def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
     m = GEKKO() # create GEKKO model


@@ -25,6 +25,13 @@ try:
 except:
     EXLLAMA_KERNELS_AVAILABLE = False

+try:
+    import cQIGen as qinfer
+    QIGEN_AVAILABLE = True
+except:
+    QIGEN_AVAILABLE = False
+

 logger = getLogger(__name__)
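With QIGEN_AVAILABLE exported from import_utils, callers can request the QIGen CPU kernels without probing for the native extension themselves; if cQIGen is missing, loading logs a warning and falls back to the default path. A hedged usage sketch, assuming the existing AutoGPTQForCausalLM.from_quantized entry point that this commit guards; the model path is a placeholder:

from auto_gptq import AutoGPTQForCausalLM

# Placeholder path to an already-quantized checkpoint directory.
model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",  # hypothetical local path
    use_qigen=True,             # reset to False with a warning if cQIGen is not installed
    device="cpu",
)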