diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index 417df473..1254ff5d 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -245,3 +245,20 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) return Exllamav3HF(pretrained_model_name_or_path) + + def unload(self): + """Properly unload the ExllamaV3 model and free GPU memory.""" + if hasattr(self, 'ex_model') and self.ex_model is not None: + self.ex_model.unload() + self.ex_model = None + + if hasattr(self, 'ex_cache') and self.ex_cache is not None: + self.ex_cache = None + + # Clean up any additional ExllamaV3 resources + if hasattr(self, 'past_seq'): + self.past_seq = None + if hasattr(self, 'past_seq_negative'): + self.past_seq_negative = None + if hasattr(self, 'ex_cache_negative'): + self.ex_cache_negative = None diff --git a/modules/models.py b/modules/models.py index 4218d58c..d329ae3c 100644 --- a/modules/models.py +++ b/modules/models.py @@ -116,10 +116,13 @@ def unload_model(keep_model_name=False): return is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer') + if shared.args.loader == 'ExLlamav3_HF': + shared.model.unload() shared.model = shared.tokenizer = None shared.lora_names = [] shared.model_dirty_from_training = False + if not is_llamacpp: from modules.torch_utils import clear_torch_cache clear_torch_cache()