diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 952b73b8..6bb422ea 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -3,7 +3,6 @@ import traceback
 from pathlib import Path
 
 import torch
-
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
@@ -16,6 +15,7 @@ from exllamav2 import (
     ExLlamaV2Tokenizer
 )
 from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
+
 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index d6c3bf6e..eb801940 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -4,15 +4,6 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
 import torch
-from torch.nn import CrossEntropyLoss
-from transformers import (
-    GenerationConfig,
-    GenerationMixin,
-    PretrainedConfig,
-    PreTrainedModel
-)
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
@@ -23,6 +14,15 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
+from torch.nn import CrossEntropyLoss
+from transformers import (
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedModel
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
 from modules import shared
 from modules.logging_colors import logger