remove full_cpu_offload argument and unify model dispatch strategy

PanQiWei 2023-05-24 17:41:04 +08:00
parent 379f24c2a5
commit 10347fdd7b
3 changed files with 11 additions and 8 deletions
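With this change, from_quantized no longer accepts full_cpu_offload, and dispatch always goes through simple_dispatch_model. A minimal caller-side sketch (the model path is a placeholder and the exact first-argument name is assumed, not taken from this commit):

    from auto_gptq import AutoGPTQForCausalLM

    # "path/to/quantized-model" is a hypothetical local directory holding a GPTQ-quantized model.
    model = AutoGPTQForCausalLM.from_quantized(
        "path/to/quantized-model",
        device="cuda:0",
        use_triton=False,
        inject_fused_attention=True,
        # full_cpu_offload=True  # this keyword argument is removed by this commit
    )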

@@ -496,7 +496,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         max_memory: Optional[dict] = None,
         device: Optional[Union[str, int]] = None,
         low_cpu_mem_usage: bool = False,
-        full_cpu_offload: bool = True,
         use_triton: bool = False,
         inject_fused_attention: bool = True,
         inject_fused_mlp: bool = True,
@@ -617,10 +616,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             offload_state_dict=True,
             offload_buffers=True
         )
-        if full_cpu_offload and "cpu" in list(device_map.values()):
-            model = simple_dispatch_model(model, device_map)
-        else:
-            model = accelerate.dispatch_model(model, device_map=device_map)
+        model = simple_dispatch_model(model, device_map)
 
         # == step4: set seqlen == #
         model_config = model.config.to_dict()
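Before this change, a device_map containing "cpu" entries went through simple_dispatch_model only when full_cpu_offload was set and otherwise fell back to accelerate.dispatch_model; now every device_map takes the same path. An illustrative sketch of the two device_map shapes involved (module names are made up, not from this repo):

    # Hypothetical device maps; module names are illustrative only.
    gpu_only_map = {"": 0}  # whole model on one GPU, handled by the early-return branch added below
    mixed_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": "cpu",  # partially offloaded to CPU
        "lm_head": "cpu",
    }

    # Both are now dispatched the same way:
    # model = simple_dispatch_model(model, device_map)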

@@ -128,6 +128,14 @@ def check_and_get_model_type(model_dir, trust_remote_code=False):
 def simple_dispatch_model(model, device_map):
+    from accelerate.hooks import add_hook_to_module, AlignDevicesHook
+
+    if "" in device_map:
+        d = device_map[""]
+        model = model.to(torch.device(d))
+        model.hf_device_map = device_map
+        return model
+
     tied_params = accelerate.utils.modeling.find_tied_parameters(model)
     if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
         main_device = "cpu"
@@ -147,7 +155,8 @@ def simple_dispatch_model(model, device_map):
         m = get_module_by_name_suffix(model, n)
         if d != "cpu":
             d = torch.device(d)
-        accelerate.hooks.attach_align_device_hook(m, execution_device=d)
+        hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
+        add_hook_to_module(m, hook)
     accelerate.utils.modeling.retie_parameters(model, tied_params)
     model.hf_device_map = device_map
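The per-module dispatch now constructs AlignDevicesHook directly instead of calling accelerate.hooks.attach_align_device_hook. A standalone sketch of the same hook pattern on a toy module (the Linear layer and device choice are assumptions for illustration, not code from this repo):

    import torch
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

    module = torch.nn.Linear(8, 8)  # stand-in for one submodule listed in device_map
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # io_same_device=True returns outputs on the same device as the inputs;
    # place_submodules=True applies the execution device to child modules as well.
    hook = AlignDevicesHook(execution_device=device, io_same_device=True, place_submodules=True)
    add_hook_to_module(module, hook)

    out = module(torch.randn(2, 8))  # inputs are moved to the execution device for the forward pass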

@@ -56,7 +56,6 @@ class AutoGPTQForCausalLM:
         max_memory: Optional[dict] = None,
         device: Optional[Union[str, int]] = None,
         low_cpu_mem_usage: bool = False,
-        full_cpu_offload: bool = True,
         use_triton: bool = False,
         inject_fused_attention: bool = True,
         inject_fused_mlp: bool = True,
@@ -77,7 +76,6 @@ class AutoGPTQForCausalLM:
             max_memory=max_memory,
             device=device,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            full_cpu_offload=full_cpu_offload,
             use_triton=use_triton,
             inject_fused_attention=inject_fused_attention,
             inject_fused_mlp=inject_fused_mlp,