expose api to set exllama max length

Felix Marty 2023-08-24 11:22:15 +00:00
parent 3cd79c826e
commit 04730ac66c
8 changed files with 101 additions and 11 deletions
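
For context, a minimal usage sketch of the API this commit exposes, pieced together from the diff and the test added below; the checkpoint name and loading arguments are taken from that test and are purely illustrative.

from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

# Illustrative checkpoint quantized with act-order (desc_act=True); the exllama backend
# is used when disable_exllama=False and the CUDA kernels are available.
model = AutoGPTQForCausalLM.from_quantized(
    "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g",
    revision="actorder",
    model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
    device="cuda:0",
    use_safetensors=True,
    disable_exllama=False,
)

# The exllama temp_state buffer defaults to EXLLAMA_DEFAULT_MAX_INPUT_LENGTH (2048) tokens;
# extend it in place when longer prompts are expected, without reloading the model.
model = exllama_set_max_input_length(model, max_input_length=4096)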

View file

@@ -2,3 +2,4 @@ __version__ = "0.4.1"
from .modeling import BaseQuantizeConfig
from .modeling import AutoGPTQForCausalLM
from .utils.peft_utils import get_gptq_peft_model
from .utils.exllama_utils import exllama_set_max_input_length

View file

@@ -982,5 +982,4 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
except:
return getattr(self.model, item)
__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]

View file

@@ -25,4 +25,6 @@ SUPPORTED_MODELS = [
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]

View file

@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Union
from typing import Union, Optional
import accelerate
import torch
@@ -7,7 +7,7 @@ import torch.nn as nn
from transformers import AutoConfig
import transformers
from ._const import SUPPORTED_MODELS, CPU, CUDA_0
from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from ..utils.import_utils import dynamically_import_QuantLinear
@@ -187,7 +187,10 @@ def simple_dispatch_model(model, device_map):
return model
def autogptq_post_init(model, use_act_order: bool):
def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
"""
The max_input_length argument is specific to the exllama backend, which requires initializing a temp_state buffer.
"""
device_to_buffers_size = {}
model_uses_exllama = False
@@ -229,9 +232,13 @@ def autogptq_post_init(model, use_act_order: bool):
device_to_buffers = {}
if use_act_order:
# TODO: initialize this properly
max_input_len = 2048
if max_input_length is None:
max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
else:
max_input_len = max_input_length
else:
if max_input_length is not None:
logger.info("Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored.")
max_input_len = 1
for device, buffers_size in device_to_buffers_size.items():
@@ -239,7 +246,9 @@ def autogptq_post_init(model, use_act_order: bool):
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_len, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device)
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
# Buffers need to be persistent to avoid any bug.
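
As an aside, a standalone sketch of the buffer-sizing logic in the hunk above, assuming placeholder dimensions and a CPU device so it runs without a GPU; allocate_exllama_buffers is a hypothetical helper name, not part of the library.

import torch

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048  # mirrors the new constant in _const.py

def allocate_exllama_buffers(device, max_inner_outer_dim, max_dq_buffer_size,
                             use_act_order, max_input_length=None):
    # With act-order, temp_state must hold one float16 row per input token, so its height
    # is the maximum expected input length; without act-order a single row suffices.
    if use_act_order:
        max_input_len = max_input_length if max_input_length is not None else EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
    else:
        max_input_len = 1
    return {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
        "max_dq_buffer_size": max_dq_buffer_size,
        "max_inner_outer_dim": max_inner_outer_dim,
    }

# Placeholder sizes; in the library they are computed from the quantized layers on each device.
buffers = allocate_exllama_buffers("cpu", max_inner_outer_dim=4096, max_dq_buffer_size=114688,
                                   use_act_order=True, max_input_length=4096)
print(buffers["temp_state"].shape)  # torch.Size([4096, 4096])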

View file

@@ -0,0 +1,48 @@
import gc
import torch
def exllama_set_max_input_length(model, max_input_length: int):
"""
This method does not necessarily require `model` to inherit from BaseGPTQForCausalLM.
When using the exllama backend with act-order, it is necessary to initialize a buffer that depends on the maximum expected input length. If the
default (EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) is too short, this method can be called to extend the buffer size without reloading the whole model.
"""
# The import is done here to avoid a global import. Arguably this is quite ugly; proper lazy loading would be better.
from exllama_kernels import prepare_buffers, cleanup_buffers_cuda
if not model.quantize_config.desc_act:
raise ValueError("The method exllama_set_max_input_length should be called only when using the exllama backend **with act-order**.")
device_to_buffers_size = {}
for device, buffers in model.device_to_buffers.items():
device_to_buffers_size[device] = {"max_dq_buffer_size": buffers["max_dq_buffer_size"], "max_inner_outer_dim": buffers["max_inner_outer_dim"]}
# For an unknown reason calling just `del model.device_to_buffers` raises an AttributeError.
for key in list(model.device_to_buffers.keys()):
del model.device_to_buffers[key]
model.device_to_buffers = None
del model.device_to_buffers
gc.collect()
torch.cuda.empty_cache()
cleanup_buffers_cuda()
device_to_buffers = {}
for device, buffers_size in device_to_buffers_size.items():
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
device_to_buffers[device] = {
"temp_state": torch.zeros((max_input_length, buffers_size["max_inner_outer_dim"]), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, buffers_size["max_dq_buffer_size"]), dtype=torch.float16, device=device),
"max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
"max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
}
prepare_buffers(device, device_to_buffers[device]["temp_state"], device_to_buffers[device]["temp_dq"])
# Buffers need to be persistent to avoid any bug.
model.device_to_buffers = device_to_buffers
return model
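
A hedged sketch of the recovery path this helper enables, mirroring the test added at the end of this commit; the checkpoint, prompt, and generation arguments are illustrative.

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"  # illustrative act-order checkpoint
model = AutoGPTQForCausalLM.from_quantized(
    model_id,
    revision="actorder",
    model_basename="vicuna-13B-1.1-GPTQ-4bit-128g.latest",
    device="cuda:0",
    use_safetensors=True,
    disable_exllama=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A prompt longer than the default 2048-token temp_state buffer.
inputs = tokenizer("I am in Paris and " * 450, return_tensors="pt").to("cuda:0")

try:
    model.generate(**inputs, max_new_tokens=3)
except RuntimeError as error:
    if "temp_state buffer is too small" not in str(error):
        raise
    # Re-allocate the exllama buffers in place instead of reloading the whole model.
    model = exllama_set_max_input_length(model, max_input_length=4096)
    model.generate(**inputs, max_new_tokens=3)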

View file

@@ -240,7 +240,7 @@ void q4_matmul_recons_cuda
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}

View file

@@ -251,4 +251,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
}

View file

@@ -6,8 +6,9 @@ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from exllama_kernels import prepare_buffers, set_tuning_params
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.modeling._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from transformers import AutoTokenizer
@@ -188,6 +189,35 @@ class TestsQ4Exllama(unittest.TestCase):
self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))
def test_exllama_buffer_size(self):
prompt = "I am in Paris and" * 450
device = torch.device("cuda:0")
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
revision = "actorder"
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer(prompt, return_tensors="pt").to(device)
self.assertTrue(inp["input_ids"].shape[1] > EXLLAMA_DEFAULT_MAX_INPUT_LENGTH) # 2048 is the default max_input_length
with self.assertRaises(RuntimeError) as cm:
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))
model_q = exllama_set_max_input_length(model_q, 4096)
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
model_q = exllama_set_max_input_length(model_q, 1034)
with self.assertRaises(RuntimeError) as cm:
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))
def test_generation_no_act_order(self):
prompt = "I am in Paris and"
device = torch.device("cuda:0")
@@ -196,7 +226,7 @@ class TestsQ4Exllama(unittest.TestCase):
reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order"
model_basename = "model"
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
tokenizer = AutoTokenizer.from_pretrained(model_id)