diff --git a/auto_gptq/modeling/llama.py b/auto_gptq/modeling/llama.py index 8062d8f..f13d02b 100644 --- a/auto_gptq/modeling/llama.py +++ b/auto_gptq/modeling/llama.py @@ -40,6 +40,7 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM): rope_cache = build_rope_cache( rotary_dim=model_config.hidden_size // num_heads, max_position=model_config.max_position_embeddings, + base=10000, device=model.device, dtype=model.dtype )