expose api to set exllama max length

parent 3cd79c826e
commit 04730ac66c

8 changed files with 101 additions and 11 deletions
@@ -2,3 +2,4 @@ __version__ = "0.4.1"
 from .modeling import BaseQuantizeConfig
 from .modeling import AutoGPTQForCausalLM
 from .utils.peft_utils import get_gptq_peft_model
+from .utils.exllama_utils import exllama_set_max_input_length

@@ -982,5 +982,4 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         except:
             return getattr(self.model, item)
 
-
 __all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]

@@ -25,4 +25,6 @@ SUPPORTED_MODELS = [
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
 
-__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]
+EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
+
+__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]

@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Union
+from typing import Union, Optional
 
 import accelerate
 import torch

@@ -7,7 +7,7 @@ import torch.nn as nn
 from transformers import AutoConfig
 import transformers
 
-from ._const import SUPPORTED_MODELS, CPU, CUDA_0
+from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
 
 

@@ -187,7 +187,10 @@ def simple_dispatch_model(model, device_map):
     return model
 
 
-def autogptq_post_init(model, use_act_order: bool):
+def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
+    """
+    The max_input_length argument is specific to the exllama backend, which requires initializing a temp_state buffer whose size depends on the maximum expected input length.
+    """
     device_to_buffers_size = {}
 
     model_uses_exllama = False

@@ -229,9 +232,13 @@ def autogptq_post_init(model, use_act_order: bool):
         device_to_buffers = {}
 
         if use_act_order:
-            # TODO: initialize this properly
-            max_input_len = 2048
+            if max_input_length is None:
+                max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
+            else:
+                max_input_len = max_input_length
         else:
+            if max_input_length is not None:
+                logger.info("Using the exllama backend without act-order: the max_input_length parameter is not needed and will be ignored.")
             max_input_len = 1
 
         for device, buffers_size in device_to_buffers_size.items():

@@ -239,7 +246,9 @@ def autogptq_post_init(model, use_act_order: bool):
             # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
             device_to_buffers[device] = {
                 "temp_state": torch.zeros((max_input_len, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
-                "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device)
+                "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
+                "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
+                "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
             }
 
             # Buffers need to be persistent to avoid any bug.

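Since the hunks above size temp_state from max_input_len, a quick back-of-the-envelope sketch shows how the buffer grows with the requested length. This is an editor's illustration, not part of the diff; the dimension value is hypothetical (the real one is collected from the model's QuantLinear layers):

import torch

max_input_len = 4096
max_inner_outer_dim = 11008  # hypothetical dimension, for illustration only

# Same shape and dtype as the temp_state buffer allocated in autogptq_post_init above.
temp_state = torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16)
print(f"temp_state size: {temp_state.numel() * temp_state.element_size() / 1024**2:.1f} MiB")
# prints roughly 86.0 MiB for this example configuration; memory cost scales linearly with max_input_len
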
auto_gptq/utils/exllama_utils.py (new file, 48 lines added)
@@ -0,0 +1,48 @@
+import gc
+import torch
+
+def exllama_set_max_input_length(model, max_input_length: int):
+    """
+    This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
+
+    When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. If the
+    default (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
+    """
+
+    # The import is done here to avoid a global import; lazy loading would arguably be cleaner.
+    from exllama_kernels import prepare_buffers, cleanup_buffers_cuda
+
+    if not model.quantize_config.desc_act:
+        raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")
+
+    device_to_buffers_size = {}
+    for device, buffers in model.device_to_buffers.items():
+        device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}
+
+    # For an unknown reason, calling just `del model.device_to_buffers` raises an AttributeError.
+    for key in list(model.device_to_buffers.keys()):
+        del model.device_to_buffers[key]
+    model.device_to_buffers = None
+    del model.device_to_buffers
+
+    gc.collect()
+    torch.cuda.empty_cache()
+    cleanup_buffers_cuda()
+
+    device_to_buffers = {}
+    for device, buffers_size in device_to_buffers_size.items():
+        # The temp_state buffer is required to reorder X in the act-order case.
+        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+        device_to_buffers[device] = {
+            "temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
+            "temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
+            "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
+            "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
+        }
+
+        prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])
+
+    # Buffers need to be persistent to avoid any bug.
+    model.device_to_buffers = device_to_buffers
+
+    return model

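For reference, a minimal usage sketch of the new helper. It reuses the checkpoint and loading options from the test added later in this commit and is illustrative rather than part of the diff:

from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

# Checkpoint and options mirror the new test below; any act-order (desc_act=True) GPTQ model
# served by the exllama backend would work the same way.
model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g",
    revision="actorder",
    model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
    device="cuda:0",
    use_safetensors=True,
    disable_exllama=False,
)

# Prompts longer than EXLLAMA_DEFAULT_MAX_INPUT_LENGTH (2048 tokens) would otherwise trigger the
# "temp_state buffer is too small" error; resize the buffers without reloading the whole model.
model = exllama_set_max_input_length(model, max_input_length=4096)
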
@@ -240,7 +240,7 @@ void q4_matmul_recons_cuda
     const half* x_mapped = x;
     if (w->cuda_x_map)
     {
-        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
+        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
         column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
         x_mapped = buffers->temp_state;
     }

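The replacement message spells out the Python-side fix. A hedged sketch of that recovery path, mirroring the new test; `model` and `inputs` are assumed to be an act-order exllama model and tokenized inputs as in the usage sketch above:

from auto_gptq import exllama_set_max_input_length

try:
    output = model.generate(**inputs, max_new_tokens=3)
except RuntimeError as err:
    if "temp_state buffer is too small" in str(err):
        # Re-allocate the exllama buffers for a longer maximum input length and retry.
        model = exllama_set_max_input_length(model, 4096)
        output = model.generate(**inputs, max_new_tokens=3)
    else:
        raise
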
@@ -251,4 +251,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("cleanup", &cleanup, "cleanup");
     m.def("make_q4", &make_q4, "make_q4");
     m.def("q4_matmul", &q4_matmul, "q4_matmul");
+    m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
 }

@@ -6,8 +6,9 @@ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
 from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
 
 from exllama_kernels import prepare_buffers, set_tuning_params
-from auto_gptq import AutoGPTQForCausalLM
+from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length
 from auto_gptq.modeling._utils import autogptq_post_init
+from auto_gptq.modeling._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 
 from transformers import AutoTokenizer
 

@@ -188,6 +189,35 @@ class TestsQ4Exllama(unittest.TestCase):
 
         self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))
 
+    def test_exllama_buffer_size(self):
+        prompt = "I am in Paris and" * 450
+        device = torch.device("cuda:0")
+
+        model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
+        revision = "actorder"
+        model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
+
+        model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        self.assertTrue(inp["input_ids"].shape[1] > EXLLAMA_DEFAULT_MAX_INPUT_LENGTH)  # 2048 is the default max_input_length
+
+        with self.assertRaises(RuntimeError) as cm:
+            res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
+        self.assertTrue("temp_state buffer is too small" in str(cm.exception))
+
+        model_q = exllama_set_max_input_length(model_q, 4096)
+
+        res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
+
+        model_q = exllama_set_max_input_length(model_q, 1034)
+
+        with self.assertRaises(RuntimeError) as cm:
+            res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
+        self.assertTrue("temp_state buffer is too small" in str(cm.exception))
+
     def test_generation_no_act_order(self):
         prompt = "I am in Paris and"
         device = torch.device("cuda:0")

@@ -196,7 +226,7 @@ class TestsQ4Exllama(unittest.TestCase):
         reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
 
         model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
-        model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order"
+        model_basename = "model"
         model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 