remove full_cpu_offload argument and unify model dispatch strategy
parent 379f24c2a5
commit 10347fdd7b

3 changed files with 11 additions and 8 deletions
@@ -496,7 +496,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         max_memory: Optional[dict] = None,
         device: Optional[Union[str, int]] = None,
         low_cpu_mem_usage: bool = False,
-        full_cpu_offload: bool = True,
         use_triton: bool = False,
         inject_fused_attention: bool = True,
         inject_fused_mlp: bool = True,
@@ -617,10 +616,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             offload_state_dict=True,
             offload_buffers=True
         )
-        if full_cpu_offload and "cpu" in list(device_map.values()):
-            model = simple_dispatch_model(model, device_map)
-        else:
-            model = accelerate.dispatch_model(model, device_map=device_map)
+        model = simple_dispatch_model(model, device_map)

         # == step4: set seqlen == #
         model_config = model.config.to_dict()
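This hunk is the core of the change: loading no longer branches between accelerate's dispatcher and the project's own helper, so simple_dispatch_model must now handle every device_map shape on its own (see the sketch after the next hunk). Condensed, with the removed branch kept as a comment for contrast:

# before: dispatch strategy depended on a user-facing flag
#     if full_cpu_offload and "cpu" in list(device_map.values()):
#         model = simple_dispatch_model(model, device_map)
#     else:
#         model = accelerate.dispatch_model(model, device_map=device_map)
#
# after: a single strategy, whatever the device_map contains
model = simple_dispatch_model(model, device_map)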
@@ -128,6 +128,14 @@ def check_and_get_model_type(model_dir, trust_remote_code=False):


 def simple_dispatch_model(model, device_map):
+    from accelerate.hooks import add_hook_to_module, AlignDevicesHook
+
+    if "" in device_map:
+        d = device_map[""]
+        model = model.to(torch.device(d))
+        model.hf_device_map = device_map
+        return model
+
     tied_params = accelerate.utils.modeling.find_tied_parameters(model)
     if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
         main_device = "cpu"
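To make the unified helper's behavior concrete, here is a small self-contained sketch of which branch each device_map shape hits after this hunk. The example maps are illustrative, not from the repository:

# Illustrative device maps, following the accelerate convention where the
# key "" stands for the whole model:
examples = [
    {"": "cuda:0"},                                  # whole model, one device
    {"model.embed": "cpu", "lm_head": "disk"},       # CPU/disk offload only
    {"model.layers.0": 0, "model.layers.1": "cpu"},  # mixed devices
]

for device_map in examples:
    values = set(device_map.values())
    if "" in device_map:
        branch = "early return: plain model.to(device)"
    elif values == {"cpu"} or values == {"cpu", "disk"}:
        branch = 'main_device = "cpu"'
    else:
        branch = "per-module AlignDevicesHook dispatch"
    print(device_map, "->", branch)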
@@ -147,7 +155,8 @@ def simple_dispatch_model(model, device_map):
         m = get_module_by_name_suffix(model, n)
         if d != "cpu":
             d = torch.device(d)
-        accelerate.hooks.attach_align_device_hook(m, execution_device=d)
+        hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
+        add_hook_to_module(m, hook)
     accelerate.utils.modeling.retie_parameters(model, tied_params)
     model.hf_device_map = device_map

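For reference, a self-contained toy showing what the new hook-based dispatch does: AlignDevicesHook moves the module's tensors to the execution device, sends inputs there at forward time, and with io_same_device=True returns outputs on the device the input arrived on. This is a minimal sketch against accelerate's public API, not AutoGPTQ code:

import torch
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

layer = torch.nn.Linear(4, 4)

# Pin this module's execution to one device, mirroring the diff above.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hook = AlignDevicesHook(device, io_same_device=True, place_submodules=True)
add_hook_to_module(layer, hook)

x = torch.randn(2, 4)  # CPU input
y = layer(x)           # computed on `device`, output moved back to the input's device
print(y.device)        # -> cpu, because io_same_device=True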
@@ -56,7 +56,6 @@ class AutoGPTQForCausalLM:
         max_memory: Optional[dict] = None,
         device: Optional[Union[str, int]] = None,
         low_cpu_mem_usage: bool = False,
-        full_cpu_offload: bool = True,
         use_triton: bool = False,
         inject_fused_attention: bool = True,
         inject_fused_mlp: bool = True,
@@ -77,7 +76,6 @@ class AutoGPTQForCausalLM:
             max_memory=max_memory,
             device=device,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            full_cpu_offload=full_cpu_offload,
             use_triton=use_triton,
             inject_fused_attention=inject_fused_attention,
             inject_fused_mlp=inject_fused_mlp,
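For callers of the high-level wrapper, the migration is simply to drop the argument; passing full_cpu_offload=... would now raise a TypeError. A minimal usage sketch, assuming this is the AutoGPTQ package (auto_gptq) and using a placeholder model path:

from auto_gptq import AutoGPTQForCausalLM

# full_cpu_offload is gone from the signature; the dispatch strategy now
# follows solely from the device_map that is passed in or inferred.
model = AutoGPTQForCausalLM.from_quantized(
    "path/to/quantized-model",  # placeholder path
    device="cuda:0",
    use_triton=False,
)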