Merge pull request #309 from PanQiWei/install-skip-qigen(windows)

skip qigen installation on windows
commit 1339db3045
潘其威 (William), 2023-08-31 19:03:43 +08:00, committed via GitHub
5 changed files with 71 additions and 46 deletions


@@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
 from ..utils.import_utils import (
-    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE
+    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
 )

 logger = getLogger(__name__)
@@ -727,13 +727,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
-        if use_qigen:
-            logger.warning("QIgen is active. Ignores all settings related to cuda.")
-            inject_fused_attention = False
-            inject_fused_mlp = False
-            use_triton = False
-            disable_exllama = True
+        if use_qigen and not QIGEN_AVAILABLE:
+            logger.warning("Qigen is not installed, reset use_qigen to False.")
+            use_qigen = False
         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
             use_triton = False
@@ -754,7 +750,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 "2. You are using pytorch without CUDA support.\n"
                 "3. CUDA and nvcc are not installed in your device."
             )
+        if use_qigen and QIGEN_AVAILABLE:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
+
         # == step1: prepare configs and file names == #
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
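The loading path now separates two checks: first downgrade `use_qigen` when the extension is missing, then, only if QIGen is both requested and importable, turn off every CUDA-oriented option. A minimal sketch of this availability-flag pattern, using a hypothetical `fast_kernels` extension in place of cQIGen:

```python
from logging import getLogger

logger = getLogger(__name__)

# Probe the optional compiled extension once, at import time
# (this mirrors how QIGEN_AVAILABLE is defined in import_utils).
try:
    import fast_kernels  # hypothetical optional backend
    FAST_KERNELS_AVAILABLE = True
except ImportError:
    FAST_KERNELS_AVAILABLE = False


def resolve_backend_options(use_fast_kernels=False, use_triton=False, disable_exllama=False):
    # Downgrade the request with a warning instead of raising,
    # the same policy from_quantized applies above.
    if use_fast_kernels and not FAST_KERNELS_AVAILABLE:
        logger.warning("fast_kernels is not installed, reset use_fast_kernels to False.")
        use_fast_kernels = False
    if use_fast_kernels:
        # A CPU backend is active, so GPU-oriented options are overridden.
        use_triton = False
        disable_exllama = True
    return use_fast_kernels, use_triton, disable_exllama
```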


@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
-import cQIGen as qinfer

 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
@@ -105,28 +104,6 @@ def make_quant(
         use_qigen=use_qigen
     )

-def process_zeros_scales(zeros, scales, bits, out_features):
-    if zeros.dtype != torch.float32:
-        new_zeros = torch.zeros_like(scales).float().contiguous()
-        if bits == 4:
-            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 2:
-            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 3:
-            logger.info("Unpacking zeros for 3 bits")
-        new_scales = scales.contiguous()
-    else:
-        if scales.shape[1] != out_features:
-            new_scales = scales.transpose(0,1).contiguous()
-        else:
-            new_scales = scales.contiguous()
-        if zeros.shape[1] != out_features:
-            new_zeros = zeros.transpose(0,1).contiguous()
-        else:
-            new_zeros = zeros.contiguous()
-    return new_zeros, new_scales
-

 def preprocess_checkpoint_qigen(
     module,
     names,
@@ -135,12 +112,40 @@ def preprocess_checkpoint_qigen(
     checkpoint,
     name='',
 ):
+    try:
+        import cQIGen as qinfer
+    except ImportError:
+        logger.error('cQIGen not installed.')
+        raise
+
     QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)

     if isinstance(module, QuantLinear):
         in_features = module.infeatures
         out_features = module.outfeatures
+        zeros = checkpoint[name + '.qzeros']
+        scales = checkpoint[name + '.scales'].float()
+        if zeros.dtype != torch.float32:
+            new_zeros = torch.zeros_like(scales).float().contiguous()
+            if bits == 4:
+                qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 2:
+                qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 3:
+                logger.info("Unpacking zeros for 3 bits")
+            new_scales = scales.contiguous()
+        else:
+            if scales.shape[1] != out_features:
+                new_scales = scales.transpose(0,1).contiguous()
+            else:
+                new_scales = scales.contiguous()
+            if zeros.shape[1] != out_features:
+                new_zeros = zeros.transpose(0,1).contiguous()
+            else:
+                new_zeros = zeros.contiguous()
-        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
         del checkpoint[name + '.qzeros']
         del checkpoint[name + '.g_idx']
         if name + '.bias' in checkpoint:
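The inlined branch above does two jobs: packed integer `qzeros` are unpacked by the compiled `qinfer.unpack_zeros{2,4}` kernels, while tensors that are already float are only re-laid-out so that `out_features` sits on the last dimension. A pure-PyTorch sketch of that layout step (illustrative only; the shape convention is inferred from the transpose logic above, not from QIGen's docs):

```python
import torch

def normalize_layout(t: torch.Tensor, out_features: int) -> torch.Tensor:
    # Put out_features on dim 1 and make the tensor contiguous, as the
    # float branch above does; presumably the C++ kernels index the
    # buffers row-major as [n_groups, out_features].
    if t.shape[1] != out_features:
        t = t.transpose(0, 1)
    return t.contiguous()

# e.g. scales saved as [out_features, n_groups] get transposed:
scales = torch.rand(4096, 32)
assert normalize_layout(scales, 4096).shape == (32, 4096)
```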


@@ -4,7 +4,6 @@ from torch import nn
 from tqdm import tqdm
 import gc
-import cQIGen as qinfer
 import math
 import numpy as np
 from gekko import GEKKO
@@ -12,6 +11,11 @@ from logging import getLogger

 logger = getLogger(__name__)

+try:
+    import cQIGen as qinfer
+except ImportError:
+    logger.error('cQIGen not installed.')
+    raise
+
 def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
     m = GEKKO() # create GEKKO model
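`mem_model` formulates the choice of kernel tile sizes as a small constrained optimization problem in GEKKO. The real objective and constraints are specific to QIGen's cache model; the toy program below only shows the GEKKO idioms the function relies on (integer variables, an inequality constraint, and the APOPT mixed-integer solver):

```python
from gekko import GEKKO

m = GEKKO(remote=False)                          # solve locally, no web service
mu = m.Var(value=1, lb=1, ub=64, integer=True)   # toy tile height
tu = m.Var(value=1, lb=1, ub=64, integer=True)   # toy tile width
m.Equation(mu * tu * 4 <= 32 * 1024)             # toy L1-size constraint (bytes)
m.Maximize(mu * tu)                              # toy objective: biggest tile that fits
m.options.SOLVER = 1                             # APOPT, handles integer variables
m.solve(disp=False)
print(mu.value[0], tu.value[0])
```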


@@ -25,6 +25,13 @@ try:
 except:
     EXLLAMA_KERNELS_AVAILABLE = False

+try:
+    import cQIGen as qinfer
+    QIGEN_AVAILABLE = True
+except:
+    QIGEN_AVAILABLE = False
+
 logger = getLogger(__name__)
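With this, every optional backend is probed exactly once, in one module, and callers branch on the flags instead of wrapping their own imports in try/except. A simplified sketch of how a consumer might pick a backend from these flags (the real `dynamically_import_QuantLinear` takes more parameters and returns a QuantLinear class; this ordering is illustrative, not the library's exact precedence):

```python
def pick_backend(use_triton: bool, use_qigen: bool, disable_exllama: bool) -> str:
    # Prefer whatever the caller asked for, if its extension imported.
    if use_qigen and QIGEN_AVAILABLE:
        return "qigen"
    if use_triton and TRITON_AVAILABLE:
        return "triton"
    if EXLLAMA_KERNELS_AVAILABLE and not disable_exllama:
        return "exllama"
    if AUTOGPTQ_CUDA_AVAILABLE:
        return "cuda"
    return "torch"  # assumed pure-PyTorch fallback
```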


@@ -4,6 +4,7 @@ from pathlib import Path
 from setuptools import setup, Extension, find_packages
 import subprocess
 import math
+import platform

 os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"
@@ -94,10 +95,11 @@ include_dirs = ["autogptq_cuda"]
 additional_setup_kwargs = dict()

 if BUILD_CUDA_EXT:
     from torch.utils import cpp_extension
-    p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
-    subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
+    if platform.system() != 'Windows':
+        p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
+        subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])

     if not ROCM_VERSION:
         from distutils.sysconfig import get_python_lib
         conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
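Guarding this block is necessary because the core-count probe shells out to `cat /proc/cpuinfo`, which exists only on Linux, and the `split(" ")[2]` indexing is tied to that file's exact formatting. A more portable sketch, if one wanted the codegen step to run elsewhere (note `os.cpu_count()` reports logical cores, not the physical "cpu cores" field the pipeline parses):

```python
import os

def usable_cores() -> int:
    try:
        # Linux only: respects CPU affinity masks (e.g. inside containers).
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # Portable fallback; counts logical cores.
        return os.cpu_count() or 1
```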
@@ -120,16 +122,20 @@ if BUILD_CUDA_EXT:
                 "autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
                 "autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
             ]
         ),
-        cpp_extension.CppExtension(
-            "cQIGen",
-            [
-                'autogptq_extension/qigen/backend.cpp'
-            ],
-            extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
-        )
     ]
+
+    if platform.system() != 'Windows':
+        extensions.append(
+            cpp_extension.CppExtension(
+                "cQIGen",
+                [
+                    'autogptq_extension/qigen/backend.cpp'
+                ],
+                extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
+            )
+        )

 if os.name == "nt":
     # On Windows, fix an error LNK2001: unresolved external symbol cublasHgemm bug in the compilation
     cuda_path = os.environ.get("CUDA_PATH", None)
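Skipping the extension on Windows is the pragmatic fix here: every entry in `extra_compile_args` is GCC/Clang syntax that MSVC rejects outright. A hypothetical alternative, had the PR ported the build instead, would be per-toolchain flag selection along these lines (MSVC has no `-march=native` analogue, so this is only an approximation, and the generate.py step above would still need its own fix):

```python
import platform

# Hypothetical per-toolchain flags; not part of this PR.
if platform.system() == "Windows":
    qigen_compile_args = ["/O2", "/openmp", "/arch:AVX2", "/std:c++17"]
else:
    qigen_compile_args = [
        "-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math",
        "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp",
    ]
```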