Call make_sure_not_tensor_in_meta_device before loading the checkpoint
This commit is contained in:
parent
63f1b4e073
commit
c31b370228
1 changed file with 2 additions and 2 deletions
|
@@ -606,6 +606,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
no_split_module_classes=[cls.layer_type]
|
no_split_module_classes=[cls.layer_type]
|
||||||
)
|
)
|
||||||
|
if low_cpu_mem_usage:
|
||||||
|
make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
||||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||||
model,
|
model,
|
||||||
checkpoint=model_save_name,
|
checkpoint=model_save_name,
|
||||||
|
@@ -613,8 +615,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
offload_state_dict=True,
|
offload_state_dict=True,
|
||||||
offload_buffers=True
|
offload_buffers=True
|
||||||
)
|
)
|
||||||
if low_cpu_mem_usage:
|
|
||||||
make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
|
||||||
if full_cpu_offload and "cpu" in list(device_map.values()):
|
if full_cpu_offload and "cpu" in list(device_map.values()):
|
||||||
model = simple_dispatch_model(model, device_map)
|
model = simple_dispatch_model(model, device_map)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Reference in a new issue