update setup.py
This commit is contained in:
parent
419160b733
commit
0bb84053a2
4 changed files with 39 additions and 33 deletions
|
@ -1,10 +0,0 @@
|
||||||
from setuptools import setup, Extension
|
|
||||||
from torch.utils import cpp_extension
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='quant_cuda',
|
|
||||||
ext_modules=[cpp_extension.CUDAExtension(
|
|
||||||
'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
|
|
||||||
)],
|
|
||||||
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
|
||||||
)
|
|
62
setup.py
62
setup.py
|
@ -1,9 +1,11 @@
|
||||||
from os.path import abspath, dirname, join
|
from setuptools import setup, find_packages
|
||||||
from setuptools import setup, find_packages, Extension
|
|
||||||
|
|
||||||
from torch.utils import cpp_extension
|
try:
|
||||||
|
import torch
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TORCH_AVAILABLE = False
|
||||||
|
|
||||||
project_root = dirname(abspath(__file__))
|
|
||||||
|
|
||||||
requirements = [
|
requirements = [
|
||||||
"datasets",
|
"datasets",
|
||||||
|
@ -11,29 +13,43 @@ requirements = [
|
||||||
"rouge",
|
"rouge",
|
||||||
"torch>=1.13.0",
|
"torch>=1.13.0",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"transformers>=4.26.1"
|
"transformers>=4.26.1",
|
||||||
|
"triton>=2.0.0"
|
||||||
]
|
]
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"llama": ["transformers>=4.28.0"]
|
"llama": ["transformers>=4.28.0"]
|
||||||
}
|
}
|
||||||
|
|
||||||
extensions = [
|
|
||||||
cpp_extension.CUDAExtension(
|
|
||||||
"quant_cuda",
|
|
||||||
[
|
|
||||||
join(project_root, "auto_gptq/quantization/quant_cuda.cpp"),
|
|
||||||
join(project_root, "auto_gptq/quantization/quant_cuda_kernel.cu")
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
setup(
|
if TORCH_AVAILABLE:
|
||||||
name="auto_gptq",
|
from torch.utils import cpp_extension
|
||||||
packages=find_packages(),
|
|
||||||
version="v0.0.4-dev",
|
extensions = [
|
||||||
install_requires=requirements,
|
cpp_extension.CUDAExtension(
|
||||||
extras_require=extras_require,
|
"quant_cuda",
|
||||||
ext_modules=extensions,
|
[
|
||||||
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
"quant_cuda/quant_cuda.cpp",
|
||||||
)
|
"quant_cuda/quant_cuda_kernel.cu"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
setup(
|
||||||
|
name="auto_gptq",
|
||||||
|
packages=find_packages(),
|
||||||
|
version="v0.0.4-dev",
|
||||||
|
install_requires=requirements,
|
||||||
|
extras_require=extras_require,
|
||||||
|
include_dirs=["quant_cuda"],
|
||||||
|
ext_modules=extensions,
|
||||||
|
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setup(
|
||||||
|
name="auto_gptq",
|
||||||
|
packages=find_packages(),
|
||||||
|
version="v0.0.4-dev",
|
||||||
|
install_requires=requirements,
|
||||||
|
extras_require=extras_require,
|
||||||
|
include_dirs=["quant_cuda"]
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue