make compatible with older transformers version

PanQiWei 2023-05-15 13:26:18 +08:00
parent 262669112b
commit 07e06fa08c


```diff
@@ -1,18 +1,14 @@
 from logging import getLogger
-from os.path import join, isfile
-from typing import Optional, Union
-
-import accelerate
-import torch
-import transformers
-from transformers import AutoConfig, AutoModelForCausalLM
 
 from ._const import *
 from ._utils import *
 from ._base import *
-
-from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
-from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
+from ..utils.import_utils import compare_transformers_version
+if compare_transformers_version("v4.28.0", op="ge"):
+    from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
+    from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
+else:
+    FusedLlamaAttentionForQuantizedModel = None
+    FusedLlamaMLPForQuantizedModel = None
 
 logger = getLogger(__name__)
```
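
The change guards the fused-module imports behind a runtime check of the installed `transformers` version and falls back to `None`, so the module stays importable on releases older than v4.28.0 and downstream code can test the names for availability. The `compare_transformers_version` helper itself lives in `..utils.import_utils` and is not shown in this commit; below is a minimal sketch of what such a helper could look like, with the signature inferred from the call site and the body an assumption rather than the project's actual implementation:

```python
# Hypothetical sketch of a compare_transformers_version helper; only the
# call signature is visible in this diff, so the body is an assumption.
import operator

import transformers
from packaging.version import parse as parse_version


def compare_transformers_version(version: str = "v4.28.0", op: str = "eq") -> bool:
    # Map the operator name ("eq", "ge", "lt", ...) to the stdlib comparison
    # function, then compare the installed version against the target.
    assert op in {"eq", "ne", "lt", "le", "gt", "ge"}
    return getattr(operator, op)(
        parse_version(transformers.__version__),
        parse_version(version.lstrip("v")),
    )
```

With a guard like this, callers can check `FusedLlamaAttentionForQuantizedModel is not None` before enabling fused-module injection instead of crashing at import time on an older `transformers` install.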