support dispatching layers to different devices when loading a pretrained model before quantization

PanQiWei 2023-04-27 02:24:08 +08:00
parent 950f203260
commit a2abff983e
11 changed files with 68 additions and 21 deletions
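In practice, the max_memory_per_gpu argument introduced here lets the un-quantized weights be sharded across all visible GPUs (through accelerate's device-map inference) instead of being loaded onto a single device before quantization. A minimal usage sketch, assuming the OPTGPTQForCausalLM class touched below; the import paths, model id, and quantization settings are illustrative, not part of this commit:

from auto_gptq import BaseQuantizeConfig
from auto_gptq.modeling import OPTGPTQForCausalLM

quantize_config = BaseQuantizeConfig(bits=4, group_size=128)  # illustrative settings

# max_memory_per_gpu is an integer GiB budget applied to every visible GPU;
# when set, accelerate builds a device map and dispatches layers before quantization.
model = OPTGPTQForCausalLM.from_pretrained(
    "facebook/opt-6.7b",             # placeholder model id
    quantize_config=quantize_config,
    max_memory_per_gpu=20,
)
model.quantize(examples)             # examples: calibration data prepared by the caller (not shown)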

View file

@@ -4,4 +4,5 @@ from .bloom import *
from .gpt_neox import *
from .gptj import *
from .llama import *
from .moss import *
from .opt import *

View file

@@ -3,8 +3,9 @@ import os
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
import accelerate
import torch
import torch.nn as nn
import transformers
@@ -55,6 +56,7 @@ class BaseQuantizeConfig(PushToHubMixin):
class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
layer_type: str = None
layers_block_name: str = None
outside_layer_modules: List[str] = None
inside_layer_modules: List[List[str]] = None
@@ -138,8 +140,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
num_examples = len(examples)
layers = get_module_by_name(self.model, self.layers_block_name)
layers[0] = layers[0].to(CUDA)
self._move_outside_layer_modules(CUDA)
layers[0] = layers[0].to(CUDA_0)
self._move_outside_layer_modules(CUDA_0)
# get inputs for first layer
layers[0] = LayerHijacker(layers[0])
@@ -147,7 +149,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
for k, v in example.items():
if len(v.shape) == 1:
v = v.unsqueeze(0)
example[k] = v.to(CUDA)
example[k] = v.to(CUDA_0)
try:
self.model(**example)
except ValueError:
@@ -166,7 +168,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
quantizers = {}
for i in range(len(layers)):
logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
layer = layers[i].to(CUDA)
layer = layers[i].to(CUDA_0)
full = find_layers(layer)
for names in self.inside_layer_modules:
@@ -191,16 +193,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(num_examples):
layer_input = layer_inputs[j].to(CUDA)
layer_attention_mask = attention_masks[j].to(CUDA)
layer_input = layer_inputs[j].to(CUDA_0)
layer_attention_mask = attention_masks[j].to(CUDA_0)
additional_layer_inputs = {
"attention_mask": layer_attention_mask
}
if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA_0)) is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in layer_input_kwargs[j].items():
if isinstance(v, torch.Tensor):
additional_layer_inputs[k] = v.to(CUDA)
additional_layer_inputs[k] = v.to(CUDA_0)
else:
additional_layer_inputs[k] = v
layer(layer_input, **additional_layer_inputs)[0][0].cpu()
@@ -220,16 +222,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
gptq[name].free()
for j in range(num_examples):
layer_input = layer_inputs[j].to(CUDA)
layer_attention_mask = attention_masks[j].to(CUDA)
layer_input = layer_inputs[j].to(CUDA_0)
layer_attention_mask = attention_masks[j].to(CUDA_0)
additional_layer_inputs = {
"attention_mask": layer_attention_mask
}
if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA_0)) is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in layer_input_kwargs[j].items():
if isinstance(v, torch.Tensor):
additional_layer_inputs[k] = v.to(CUDA)
additional_layer_inputs[k] = v.to(CUDA_0)
else:
additional_layer_inputs[k] = v
layer_output = layer(layer_input, **additional_layer_inputs)[0][0].cpu()
@@ -304,11 +306,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
bf16: bool = False,
max_memory_per_gpu: Optional[int] = None,
**model_init_kwargs
):
"""load un-quantized pretrained model to cpu"""
if not torch.cuda.is_available():
raise EnvironmentError("Loading a pretrained model for quantization requires CUDA to be available.")
def skip(*args, **kwargs):
pass
@@ -321,10 +326,41 @@
raise TypeError(f"{config.model_type} isn't supported yet.")
# enforce some values despite user specified
model_init_kwargs["device_map"] = None
model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
model_init_kwargs["low_cpu_mem_usage"] = False
model_init_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
model_init_kwargs["trust_remote_code"] = True
max_memory = None
if "max_memory" in model_init_kwargs:
max_memory = model_init_kwargs.pop("max_memory")
if max_memory_per_gpu is not None:
max_memory = {
cuda_id: f"{max_memory_per_gpu}GIB" for cuda_id in range(torch.cuda.device_count())
}
if max_memory:
with accelerate.init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model.tie_weights()
max_memory = accelerate.utils.get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"],
low_zero=False
)
model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"]
)
model_init_kwargs["low_cpu_mem_usage"] = True
del model
else:
model_init_kwargs["device_map"] = None
model_init_kwargs["low_cpu_mem_usage"] = False
torch.cuda.empty_cache()
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
model_config = model.config.to_dict()
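A note on the hunk above: callers can alternatively pass an explicit max_memory mapping through model_init_kwargs, since it is popped before the device map is built (max_memory_per_gpu, when given, overrides it). A hedged sketch reusing the earlier setup; integer keys are GPU indices and the size strings follow accelerate's convention, the same one the new code builds via f"{max_memory_per_gpu}GIB":

model = OPTGPTQForCausalLM.from_pretrained(
    "facebook/opt-6.7b",                     # placeholder model id
    quantize_config=quantize_config,
    max_memory={0: "20GIB", 1: "20GIB"},     # per-device budget forwarded to accelerate
)
# Internally this path roughly does:
#   accelerate.utils.get_balanced_memory(...)  -> rebalances the budget, keeping cls.layer_type blocks whole
#   accelerate.infer_auto_device_map(...)      -> assigns each module to a device
#   AutoModelForCausalLM.from_pretrained(..., device_map=..., low_cpu_mem_usage=True)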

View file

@@ -4,10 +4,10 @@ from torch import device
from transformers import __version__ as transformers_version
CPU = device("cpu")
CUDA = device("cuda:0")
CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt", "moss"]
if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS.append("llama")
__all__ = ["CPU", "CUDA", "SUPPORTED_MODELS"]
__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]

View file

@@ -3,7 +3,7 @@ from logging import getLogger
import torch.nn as nn
from transformers import AutoConfig
from ._const import SUPPORTED_MODELS, CUDA
from ._const import SUPPORTED_MODELS, CUDA_0
logger = getLogger(__name__)
@@ -67,7 +67,7 @@ def pack_model(model, quantizers, bits, group_size, use_triton=False, autotune_w
logger.warning(
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the hole model."
)
autotune_warmup_linear(model.to(CUDA), seqlen=model.seqlen)
autotune_warmup_linear(model.to(CUDA_0), seqlen=model.seqlen)
def check_and_get_model_type(model_dir):

View file

@@ -2,6 +2,7 @@ from ._base import *
class BloomGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "BloomBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *
class GPTNeoXGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPTNeoXLayer"
layers_block_name = "gpt_neox.layers"
outside_layer_modules = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"]
inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *
class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "GPTJBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *
class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "LlamaDecoderLayer"
layers_block_name = "model.layers"
outside_layer_modules = ["model.embed_tokens", "model.norm"]
inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *
class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "MossBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.drop", "transformer.ln_f"]
inside_layer_modules = [
@@ -10,3 +11,6 @@ class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_in"],
["mlp.fc_out"]
]
__all__ = ["MOSSGPTQForCausalLM"]
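Each of the per-model hunks in this commit adds an outside_layer_modules attribute listing the modules that live outside the repeating decoder block (embeddings, final layer norms, and the like); _move_outside_layer_modules moves these to CUDA_0 alongside layer 0 when capturing inputs, and layer_type doubles as accelerate's no_split_module_classes entry. Put together, a subclass for a hypothetical new architecture would look roughly like this (the class and module names below are illustrative, not taken from this commit):

class NewModelGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "NewModelDecoderLayer"    # class name of one decoder block; kept whole when dispatching
    layers_block_name = "model.layers"     # attribute path to the stack of decoder blocks
    outside_layer_modules = ["model.embed_tokens", "model.final_norm"]
    inside_layer_modules = [               # quantized group by group inside each block
        ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj"],
        ["mlp.down_proj"],
    ]

__all__ = ["NewModelGPTQForCausalLM"]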

View file

@@ -2,6 +2,7 @@ from ._base import *
class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "OPTDecoderLayer"
layers_block_name = "model.decoder.layers"
outside_layer_modules = [
"model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",

View file

@@ -10,6 +10,7 @@ except ImportError:
version = "v0.1.0-dev"
requirements = [
"accelerate",
"datasets",
"numpy",
"rouge",