remove argument 'save_dir' in method from_quantized
This commit is contained in:
parent
722a621aaa
commit
ff1f100ded
3 changed files with 4 additions and 17 deletions
|
@ -676,8 +676,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_quantized(
|
def from_quantized(
|
||||||
cls,
|
cls,
|
||||||
model_name_or_path: Optional[str] = None,
|
model_name_or_path: Optional[str],
|
||||||
save_dir: Optional[str] = None,
|
|
||||||
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
|
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
|
||||||
max_memory: Optional[dict] = None,
|
max_memory: Optional[dict] = None,
|
||||||
device: Optional[Union[str, int]] = None,
|
device: Optional[Union[str, int]] = None,
|
||||||
|
@ -726,14 +725,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
use_triton = False
|
use_triton = False
|
||||||
|
|
||||||
# == step1: prepare configs and file names == #
|
# == step1: prepare configs and file names == #
|
||||||
if model_name_or_path and save_dir:
|
|
||||||
logger.warning("save_dir will be ignored because model_name_or_path is explicit specified.")
|
|
||||||
if not model_name_or_path and save_dir:
|
|
||||||
model_name_or_path = save_dir
|
|
||||||
warnings.warn("save_dir is deprecated and will be removed in version 0.3.0", PendingDeprecationWarning, stacklevel=2)
|
|
||||||
if not model_name_or_path and not save_dir:
|
|
||||||
raise ValueError("at least one of model_name_or_path or save_dir should be specified.")
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
||||||
|
|
||||||
if config.model_type not in SUPPORTED_MODELS:
|
if config.model_type not in SUPPORTED_MODELS:
|
||||||
|
|
|
@ -64,8 +64,7 @@ class AutoGPTQForCausalLM:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_quantized(
|
def from_quantized(
|
||||||
cls,
|
cls,
|
||||||
model_name_or_path: Optional[str] = None,
|
model_name_or_path: Optional[str],
|
||||||
save_dir: Optional[str] = None,
|
|
||||||
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
|
device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
|
||||||
max_memory: Optional[dict] = None,
|
max_memory: Optional[dict] = None,
|
||||||
device: Optional[Union[str, int]] = None,
|
device: Optional[Union[str, int]] = None,
|
||||||
|
@ -82,9 +81,7 @@ class AutoGPTQForCausalLM:
|
||||||
trainable: bool = False,
|
trainable: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BaseGPTQForCausalLM:
|
) -> BaseGPTQForCausalLM:
|
||||||
model_type = check_and_get_model_type(
|
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
|
||||||
save_dir or model_name_or_path, trust_remote_code
|
|
||||||
)
|
|
||||||
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
||||||
# A static list of kwargs needed for huggingface_hub
|
# A static list of kwargs needed for huggingface_hub
|
||||||
huggingface_kwargs = [
|
huggingface_kwargs = [
|
||||||
|
@ -107,7 +104,6 @@ class AutoGPTQForCausalLM:
|
||||||
}
|
}
|
||||||
return quant_func(
|
return quant_func(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
save_dir=save_dir,
|
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
device=device,
|
device=device,
|
||||||
|
|
|
@ -165,7 +165,7 @@ def load_model_tokenizer(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = AutoGPTQForCausalLM.from_quantized(
|
model = AutoGPTQForCausalLM.from_quantized(
|
||||||
save_dir=model_name_or_path,
|
model_name_or_path,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
use_triton=use_triton,
|
use_triton=use_triton,
|
||||||
|
|
Loading…
Add table
Reference in a new issue