update setup.py

This commit is contained in:
PanQiWei 2023-04-25 18:50:21 +08:00
parent 419160b733
commit 0bb84053a2
4 changed files with 39 additions and 33 deletions

View file

@ -1,10 +0,0 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(
name='quant_cuda',
ext_modules=[cpp_extension.CUDAExtension(
'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
)],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)

View file

@ -1,9 +1,11 @@
from os.path import abspath, dirname, join
from setuptools import setup, find_packages, Extension
from setuptools import setup, find_packages
from torch.utils import cpp_extension
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
project_root = dirname(abspath(__file__))
requirements = [
"datasets",
@ -11,29 +13,43 @@ requirements = [
"rouge",
"torch>=1.13.0",
"safetensors",
"transformers>=4.26.1"
"transformers>=4.26.1",
"triton>=2.0.0"
]
extras_require = {
"llama": ["transformers>=4.28.0"]
}
extensions = [
cpp_extension.CUDAExtension(
"quant_cuda",
[
join(project_root, "auto_gptq/quantization/quant_cuda.cpp"),
join(project_root, "auto_gptq/quantization/quant_cuda_kernel.cu")
]
)
]
setup(
name="auto_gptq",
packages=find_packages(),
version="v0.0.4-dev",
install_requires=requirements,
extras_require=extras_require,
ext_modules=extensions,
cmdclass={'build_ext': cpp_extension.BuildExtension}
)
if TORCH_AVAILABLE:
from torch.utils import cpp_extension
extensions = [
cpp_extension.CUDAExtension(
"quant_cuda",
[
"quant_cuda/quant_cuda.cpp",
"quant_cuda/quant_cuda_kernel.cu"
]
)
]
setup(
name="auto_gptq",
packages=find_packages(),
version="v0.0.4-dev",
install_requires=requirements,
extras_require=extras_require,
include_dirs=["quant_cuda"],
ext_modules=extensions,
cmdclass={'build_ext': cpp_extension.BuildExtension}
)
else:
setup(
name="auto_gptq",
packages=find_packages(),
version="v0.0.4-dev",
install_requires=requirements,
extras_require=extras_require,
include_dirs=["quant_cuda"]
)