make_sure_not_tensor_in_meta_device before load checkpoint

PanQiWei 2023-05-24 11:32:45 +08:00
parent 63f1b4e073
commit c31b370228


@@ -606,6 +606,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
max_memory=max_memory,
no_split_module_classes=[cls.layer_type]
)
+        if low_cpu_mem_usage:
+            make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
accelerate.utils.modeling.load_checkpoint_in_model(
model,
checkpoint=model_save_name,
@@ -613,8 +615,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
offload_state_dict=True,
offload_buffers=True
)
-        if low_cpu_mem_usage:
-            make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
if full_cpu_offload and "cpu" in list(device_map.values()):
model = simple_dispatch_model(model, device_map)
else:
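
Why the reordering matters: tensors on PyTorch's meta device have no real storage, so nothing can be copied into them. Calling make_sure_not_tensor_in_meta_device before accelerate.utils.modeling.load_checkpoint_in_model presumably guarantees that every target tensor is backed by real memory by the time the checkpoint weights are applied, rather than patching up leftover meta tensors afterwards. The sketch below illustrates the general idea under that assumption; materialize_meta_tensors and all of its details are hypothetical and are not AutoGPTQ's actual helper.

    import torch
    import torch.nn as nn

    def materialize_meta_tensors(module: nn.Module, device: str = "cpu") -> None:
        # Hypothetical approximation of make_sure_not_tensor_in_meta_device:
        # walk the module tree and give real (uninitialized) storage to any
        # parameter or buffer still on the "meta" device, so that a later
        # load_checkpoint_in_model() has concrete tensors to fill.
        for qualified_name, param in list(module.named_parameters()):
            if param.is_meta:
                owner_path, _, leaf = qualified_name.rpartition(".")
                owner = module.get_submodule(owner_path) if owner_path else module
                # Swap the meta parameter for an empty tensor on a real device.
                owner._parameters[leaf] = nn.Parameter(
                    torch.empty(param.shape, dtype=param.dtype, device=device),
                    requires_grad=False,
                )
        for qualified_name, buf in list(module.named_buffers()):
            if buf.is_meta:
                owner_path, _, leaf = qualified_name.rpartition(".")
                owner = module.get_submodule(owner_path) if owner_path else module
                owner._buffers[leaf] = torch.empty(buf.shape, dtype=buf.dtype, device=device)

Invoked immediately before load_checkpoint_in_model(model, checkpoint=model_save_name, ...), a helper like this mirrors the ordering the commit introduces.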