make_sure_not_tensor_in_meta_device before load checkpoint
This commit is contained in:
parent 63f1b4e073
commit c31b370228
1 changed file with 2 additions and 2 deletions
@@ -606,6 +606,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 max_memory=max_memory,
                 no_split_module_classes=[cls.layer_type]
             )
+            if low_cpu_mem_usage:
+                make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
             accelerate.utils.modeling.load_checkpoint_in_model(
                 model,
                 checkpoint=model_save_name,
@@ -613,8 +615,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 offload_state_dict=True,
                 offload_buffers=True
             )
-            if low_cpu_mem_usage:
-                make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
             if full_cpu_offload and "cpu" in list(device_map.values()):
                 model = simple_dispatch_model(model, device_map)
             else:
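For context: with low_cpu_mem_usage, the model skeleton is built with placeholder tensors on PyTorch's meta device, which have shape and dtype but no storage. The commit moves the make_sure_not_tensor_in_meta_device call ahead of accelerate's load_checkpoint_in_model, presumably so that every tensor already has real storage by the time checkpoint weights are copied in. The sketch below illustrates that general idea only; materialize_meta_tensors is a hypothetical stand-in, not AutoGPTQ's actual helper (which additionally takes use_triton, desc_act and group_size, likely to locate its quantized linear layers).

import torch
import torch.nn as nn

def materialize_meta_tensors(model: nn.Module) -> None:
    # Walk every submodule and replace any meta-device parameter or buffer
    # with an empty CPU tensor of the same shape and dtype, so a subsequent
    # checkpoint load has real storage to copy weights into.
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if param.is_meta:
                module.register_parameter(
                    name,
                    nn.Parameter(
                        torch.empty_like(param, device="cpu"),
                        requires_grad=param.requires_grad,
                    ),
                )
        for name, buf in list(module.named_buffers(recurse=False)):
            if buf.is_meta:
                module.register_buffer(name, torch.empty_like(buf, device="cpu"))

Running such a check only after the load, as the old code did, leaves load_checkpoint_in_model to encounter any remaining meta tensors on its own, which is presumably what the reordering avoids.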