correct typo of function name
This commit is contained in:
parent
10347fdd7b
commit
c89bb6450c
2 changed files with 3 additions and 3 deletions
|
@ -607,7 +607,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
||||||
|
|
||||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -163,7 +163,7 @@ def simple_dispatch_model(model, device_map):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def make_sure_not_tensor_in_meta_device(model, use_triton, desc_act, group_size):
|
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size):
|
||||||
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size)
|
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size)
|
||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
|
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
|
||||||
|
@ -180,5 +180,5 @@ __all__ = [
|
||||||
"pack_model",
|
"pack_model",
|
||||||
"check_and_get_model_type",
|
"check_and_get_model_type",
|
||||||
"simple_dispatch_model",
|
"simple_dispatch_model",
|
||||||
"make_sure_not_tensor_in_meta_device"
|
"make_sure_no_tensor_in_meta_device"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Reference in a new issue