correct typo of function name
This commit is contained in:
parent
10347fdd7b
commit
c89bb6450c
2 changed files with 3 additions and 3 deletions
|
@ -607,7 +607,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
make_sure_not_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
||||
make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size)
|
||||
|
||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||
model,
|
||||
|
|
|
@ -163,7 +163,7 @@ def simple_dispatch_model(model, device_map):
|
|||
return model
|
||||
|
||||
|
||||
def make_sure_not_tensor_in_meta_device(model, use_triton, desc_act, group_size):
|
||||
def make_sure_no_tensor_in_meta_device(model, use_triton, desc_act, group_size):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size)
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
|
||||
|
@ -180,5 +180,5 @@ __all__ = [
|
|||
"pack_model",
|
||||
"check_and_get_model_type",
|
||||
"simple_dispatch_model",
|
||||
"make_sure_not_tensor_in_meta_device"
|
||||
"make_sure_no_tensor_in_meta_device"
|
||||
]
|
||||
|
|
Loading…
Add table
Reference in a new issue