update setup.py
This commit is contained in:
parent
419160b733
commit
0bb84053a2
4 changed files with 39 additions and 33 deletions
|
@ -1,10 +0,0 @@
|
|||
from setuptools import setup, Extension
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
setup(
|
||||
name='quant_cuda',
|
||||
ext_modules=[cpp_extension.CUDAExtension(
|
||||
'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
|
||||
)],
|
||||
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
||||
)
|
62
setup.py
62
setup.py
|
@ -1,9 +1,11 @@
|
|||
from os.path import abspath, dirname, join
|
||||
from setuptools import setup, find_packages, Extension
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
from torch.utils import cpp_extension
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
project_root = dirname(abspath(__file__))
|
||||
|
||||
requirements = [
|
||||
"datasets",
|
||||
|
@ -11,29 +13,43 @@ requirements = [
|
|||
"rouge",
|
||||
"torch>=1.13.0",
|
||||
"safetensors",
|
||||
"transformers>=4.26.1"
|
||||
"transformers>=4.26.1",
|
||||
"triton>=2.0.0"
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
"llama": ["transformers>=4.28.0"]
|
||||
}
|
||||
|
||||
extensions = [
|
||||
cpp_extension.CUDAExtension(
|
||||
"quant_cuda",
|
||||
[
|
||||
join(project_root, "auto_gptq/quantization/quant_cuda.cpp"),
|
||||
join(project_root, "auto_gptq/quantization/quant_cuda_kernel.cu")
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
setup(
|
||||
name="auto_gptq",
|
||||
packages=find_packages(),
|
||||
version="v0.0.4-dev",
|
||||
install_requires=requirements,
|
||||
extras_require=extras_require,
|
||||
ext_modules=extensions,
|
||||
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
||||
)
|
||||
if TORCH_AVAILABLE:
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
extensions = [
|
||||
cpp_extension.CUDAExtension(
|
||||
"quant_cuda",
|
||||
[
|
||||
"quant_cuda/quant_cuda.cpp",
|
||||
"quant_cuda/quant_cuda_kernel.cu"
|
||||
]
|
||||
)
|
||||
]
|
||||
setup(
|
||||
name="auto_gptq",
|
||||
packages=find_packages(),
|
||||
version="v0.0.4-dev",
|
||||
install_requires=requirements,
|
||||
extras_require=extras_require,
|
||||
include_dirs=["quant_cuda"],
|
||||
ext_modules=extensions,
|
||||
cmdclass={'build_ext': cpp_extension.BuildExtension}
|
||||
)
|
||||
else:
|
||||
setup(
|
||||
name="auto_gptq",
|
||||
packages=find_packages(),
|
||||
version="v0.0.4-dev",
|
||||
install_requires=requirements,
|
||||
extras_require=extras_require,
|
||||
include_dirs=["quant_cuda"]
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue