Merge pull request #309 from PanQiWei/install-skip-qigen(windows)

skip qigen installation on windows

Commit 1339db3045: 5 changed files with 71 additions and 46 deletions
@@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModul
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
 from ..utils.import_utils import (
-    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE
+    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
 )

 logger = getLogger(__name__)
@@ -727,13 +727,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
-        if use_qigen:
-            logger.warning("QIgen is active. Ignores all settings related to cuda.")
-            inject_fused_attention = False
-            inject_fused_mlp = False
-            use_triton = False
-            disable_exllama = True
+        if use_qigen and not QIGEN_AVAILABLE:
+            logger.warning("Qigen is not installed, reset use_qigen to False.")
+            use_qigen = False

         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
             use_triton = False
@@ -755,6 +751,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 "3. CUDA and nvcc are not installed in your device."
             )

+        if use_qigen and QIGEN_AVAILABLE:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
+
         # == step1: prepare configs and file names == #
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
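For readers skimming the hunks above: the PR splits the old unconditional use_qigen block in two, so requesting QIGen without the compiled extension now degrades gracefully instead of failing later during model construction. Below is a minimal, self-contained sketch of the resulting flag resolution; resolve_backend_flags is a hypothetical helper name and the default values are assumptions for illustration, not part of from_quantized.

    from logging import getLogger

    logger = getLogger(__name__)

    def resolve_backend_flags(use_qigen, use_triton, qigen_available, triton_available):
        # Assumed defaults, for illustration only.
        inject_fused_attention = True
        inject_fused_mlp = True
        disable_exllama = False

        # New behaviour: asking for QIGen without the extension falls back quietly.
        if use_qigen and not qigen_available:
            logger.warning("Qigen is not installed, reset use_qigen to False.")
            use_qigen = False

        if use_triton and not triton_available:
            logger.warning("Triton is not installed, reset use_triton to False.")
            use_triton = False

        # CUDA-related switches are only forced off when QIGen is actually usable.
        if use_qigen and qigen_available:
            logger.warning("QIgen is active. Ignores all settings related to cuda.")
            inject_fused_attention = False
            inject_fused_mlp = False
            use_triton = False
            disable_exllama = True

        return use_qigen, use_triton, inject_fused_attention, inject_fused_mlp, disable_exllama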
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
-import cQIGen as qinfer

 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
@@ -105,7 +104,28 @@ def make_quant(
             use_qigen=use_qigen
         )

-def process_zeros_scales(zeros, scales, bits, out_features):
+def preprocess_checkpoint_qigen(
+    module,
+    names,
+    bits,
+    group_size,
+    checkpoint,
+    name='',
+):
+    try:
+        import cQIGen as qinfer
+    except ImportError:
+        logger.error('cQIGen not installed.')
+        raise
+
+    QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
+    if isinstance(module, QuantLinear):
+        in_features = module.infeatures
+        out_features = module.outfeatures
+
+    zeros = checkpoint[name + '.qzeros']
+    scales = checkpoint[name + '.scales'].float()
+
     if zeros.dtype != torch.float32:
         new_zeros = torch.zeros_like(scales).float().contiguous()
         if bits == 4:
@@ -125,22 +145,7 @@ def process_zeros_scales(zeros, scales, bits, out_features):
     else:
         new_zeros = zeros.contiguous()

-    return new_zeros, new_scales
-
-def preprocess_checkpoint_qigen(
-    module,
-    names,
-    bits,
-    group_size,
-    checkpoint,
-    name='',
-):
-    QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
-    if isinstance(module, QuantLinear):
-        in_features = module.infeatures
-        out_features = module.outfeatures
-
-    checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+    checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
     del checkpoint[name + '.qzeros']
     del checkpoint[name + '.g_idx']
     if name + '.bias' in checkpoint:
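The hunks above fold the old process_zeros_scales helper into preprocess_checkpoint_qigen and defer the cQIGen import to call time, so this module now imports cleanly even when the extension was never built (as on Windows). Below is a simplified sketch of the in-place checkpoint rewrite that results; rewrite_layer_keys is a hypothetical name and the per-bit-width unpacking elided from the diff is replaced by a placeholder.

    import torch

    def rewrite_layer_keys(checkpoint: dict, name: str) -> None:
        # Pull the packed tensors for one quantized linear layer.
        scales = checkpoint[name + '.scales'].float()
        zeros = checkpoint[name + '.qzeros']
        if zeros.dtype != torch.float32:
            # Placeholder for the real unpacking branches elided from the hunk.
            zeros = torch.zeros_like(scales).float().contiguous()
        else:
            zeros = zeros.contiguous()
        # Store the tensors under the names the QIGen kernels expect and drop
        # the keys they do not use.
        checkpoint[name + '.zeros'], checkpoint[name + '.scales'] = zeros, scales
        del checkpoint[name + '.qzeros']
        if name + '.g_idx' in checkpoint:
            del checkpoint[name + '.g_idx']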
@@ -4,7 +4,6 @@ from torch import nn
 from tqdm import tqdm
 import gc

-import cQIGen as qinfer
 import math
 import numpy as np
 from gekko import GEKKO
@@ -12,6 +11,11 @@ from logging import getLogger

 logger = getLogger(__name__)

+try:
+    import cQIGen as qinfer
+except ImportError:
+    logger.error('cQIGen not installed.')
+    raise

 def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
     m = GEKKO() # create GEKKO model
@@ -25,6 +25,13 @@ try:
 except:
     EXLLAMA_KERNELS_AVAILABLE = False

+try:
+    import cQIGen as qinfer
+
+    QIGEN_AVAILABLE = True
+except:
+    QIGEN_AVAILABLE = False
+
 logger = getLogger(__name__)
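With the QIGEN_AVAILABLE probe above in place, user code can check at runtime whether the installed build ships the QIGen kernels (after this PR it will not on Windows). A short example; the top-level package path auto_gptq.utils.import_utils is assumed from the repository layout rather than stated in the diff.

    # Quick runtime check for QIGen support in the installed auto-gptq build.
    from auto_gptq.utils.import_utils import QIGEN_AVAILABLE  # assumed import path

    print("QIGen kernels available:", QIGEN_AVAILABLE)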
setup.py (12 changed lines)
@@ -4,6 +4,7 @@ from pathlib import Path
 from setuptools import setup, Extension, find_packages
 import subprocess
 import math
+import platform

 os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"
@@ -95,9 +96,10 @@ additional_setup_kwargs = dict()
 if BUILD_CUDA_EXT:
     from torch.utils import cpp_extension

-    p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
-
-    subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
+    if platform.system() != 'Windows':
+        p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
+
+        subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])

     if not ROCM_VERSION:
         from distutils.sysconfig import get_python_lib
         conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
@@ -120,7 +122,11 @@ if BUILD_CUDA_EXT:
                 "autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
                 "autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
             ]
-        ),
+        )
+    ]
+
+    if platform.system() != 'Windows':
+        extensions.append(
         cpp_extension.CppExtension(
             "cQIGen",
             [
@@ -128,7 +134,7 @@ if BUILD_CUDA_EXT:
             ],
             extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
         )
-    ]
+        )

 if os.name == "nt":
     # On Windows, fix an error LNK2001: unresolved external symbol cublasHgemm bug in the compilation
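Taken together, the setup.py hunks gate both the QIGen code-generation step and the cQIGen C++ extension behind a platform check, which is what makes installation skip QIGen on Windows. Below is a condensed sketch of that build logic; the source path, the hard-coded parallelism value, the trimmed compile-flag list, and the setup metadata are illustrative assumptions, and the unrelated CUDA extensions are omitted.

    import platform
    import subprocess

    from setuptools import setup
    from torch.utils import cpp_extension

    extensions = []  # the CUDA extensions from the real setup.py are omitted here

    if platform.system() != 'Windows':
        # Generate the QIGen kernels only on non-Windows hosts, then register the
        # cQIGen extension so a Windows install skips it entirely.
        subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", "8"])  # "8" stands in for the detected core count
        extensions.append(
            cpp_extension.CppExtension(
                "cQIGen",
                ["autogptq_extension/qigen/backend.cpp"],  # assumed source path; elided in the diff
                extra_compile_args=["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-std=c++17", "-fopenmp"],
            )
        )

    setup(
        name="auto-gptq-build-sketch",  # placeholder metadata for illustration
        ext_modules=extensions,
        cmdclass={"build_ext": cpp_extension.BuildExtension},
    )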