From c31b3702283d89fbf5db5aedd2c7710a19a918f3 Mon Sep 17 00:00:00 2001
From: PanQiWei <594557445@qq.com>
Date: Wed, 24 May 2023 11:32:45 +0800
Subject: [PATCH] make_sure_not_tensor_in_meta_device before load checkpoint

---
 auto_gptq/modeling/_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index b369839..deecbc5 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -606,6 +606,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 max_memory=max_memory,
                 no_split_module_classes=[cls.layer_type]
             )
+        if low_cpu_mem_usage:
+            make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
         accelerate.utils.modeling.load_checkpoint_in_model(
             model,
             checkpoint=model_save_name,
@@ -613,8 +615,6 @@
             offload_state_dict=True,
             offload_buffers=True
         )
-        if low_cpu_mem_usage:
-            make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
         if full_cpu_offload and "cpu" in list(device_map.values()):
             model = simple_dispatch_model(model, device_map)
         else:
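
Note: the sketch below is editorial illustration, not part of the patch. It shows why tensors must be materialized off the "meta" device before a checkpoint load when low_cpu_mem_usage builds the model skeleton on meta: meta tensors carry only shape and dtype with no backing storage, so any tensor the checkpoint does not overwrite is unusable afterwards. The Linear layer and state dict are hypothetical stand-ins for the quantized modules that make_sure_not_tensor_in_meta_device handles; it assumes a recent PyTorch (>= 1.10, where nn modules accept a device keyword).

import torch

# With low_cpu_mem_usage, modules are first built on the "meta" device:
# parameters have shape and dtype but no backing storage.
layer = torch.nn.Linear(4, 4, device="meta")
assert layer.weight.is_meta

# A meta tensor has no data, so it cannot be copied out or moved directly:
try:
    layer.weight.to("cpu")
except (NotImplementedError, RuntimeError) as err:
    print(f".to() cannot materialize a meta tensor: {err}")

# Allocating real (uninitialized) storage on an actual device first --
# conceptually what make_sure_not_tensor_in_meta_device ensures for tensors
# the checkpoint will not overwrite -- makes the subsequent load succeed:
layer = layer.to_empty(device="cpu")
layer.load_state_dict({"weight": torch.eye(4), "bias": torch.zeros(4)})
assert not layer.weight.is_meta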