diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index b369839..deecbc5 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -606,6 +606,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 max_memory=max_memory,
                 no_split_module_classes=[cls.layer_type]
             )
+            if low_cpu_mem_usage:
+                make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
             accelerate.utils.modeling.load_checkpoint_in_model(
                 model,
                 checkpoint=model_save_name,
@@ -613,8 +615,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 offload_state_dict=True,
                 offload_buffers=True
             )
-            if low_cpu_mem_usage:
-                make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
             if full_cpu_offload and "cpu" in list(device_map.values()):
                 model = simple_dispatch_model(model, device_map)
             else: