update setup.py

This commit is contained in:
PanQiWei 2023-04-25 18:50:21 +08:00
parent 419160b733
commit 0bb84053a2
4 changed files with 39 additions and 33 deletions

View file

@ -1,10 +0,0 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(
name='quant_cuda',
ext_modules=[cpp_extension.CUDAExtension(
'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
)],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)

View file

@ -1,9 +1,11 @@
from os.path import abspath, dirname, join from setuptools import setup, find_packages
from setuptools import setup, find_packages, Extension
from torch.utils import cpp_extension try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
project_root = dirname(abspath(__file__))
requirements = [ requirements = [
"datasets", "datasets",
@ -11,29 +13,43 @@ requirements = [
"rouge", "rouge",
"torch>=1.13.0", "torch>=1.13.0",
"safetensors", "safetensors",
"transformers>=4.26.1" "transformers>=4.26.1",
"triton>=2.0.0"
] ]
extras_require = { extras_require = {
"llama": ["transformers>=4.28.0"] "llama": ["transformers>=4.28.0"]
} }
extensions = [
if TORCH_AVAILABLE:
from torch.utils import cpp_extension
extensions = [
cpp_extension.CUDAExtension( cpp_extension.CUDAExtension(
"quant_cuda", "quant_cuda",
[ [
join(project_root, "auto_gptq/quantization/quant_cuda.cpp"), "quant_cuda/quant_cuda.cpp",
join(project_root, "auto_gptq/quantization/quant_cuda_kernel.cu") "quant_cuda/quant_cuda_kernel.cu"
] ]
) )
] ]
setup(
setup(
name="auto_gptq", name="auto_gptq",
packages=find_packages(), packages=find_packages(),
version="v0.0.4-dev", version="v0.0.4-dev",
install_requires=requirements, install_requires=requirements,
extras_require=extras_require, extras_require=extras_require,
include_dirs=["quant_cuda"],
ext_modules=extensions, ext_modules=extensions,
cmdclass={'build_ext': cpp_extension.BuildExtension} cmdclass={'build_ext': cpp_extension.BuildExtension}
) )
else:
setup(
name="auto_gptq",
packages=find_packages(),
version="v0.0.4-dev",
install_requires=requirements,
extras_require=extras_require,
include_dirs=["quant_cuda"]
)