Support dispatching layers to different devices when loading a pretrained model before quantization

PanQiWei 2023-04-27 02:24:08 +08:00
parent 950f203260
commit a2abff983e
11 changed files with 68 additions and 21 deletions

View file

@@ -4,4 +4,5 @@ from .bloom import *
 from .gpt_neox import *
 from .gptj import *
 from .llama import *
+from .moss import *
 from .opt import *

View file

@@ -3,8 +3,9 @@ import os
 from dataclasses import dataclass, field, fields
 from logging import getLogger
 from os.path import join
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

+import accelerate
 import torch
 import torch.nn as nn
 import transformers
@@ -55,6 +56,7 @@ class BaseQuantizeConfig(PushToHubMixin):

 class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
+    layer_type: str = None
     layers_block_name: str = None
     outside_layer_modules: List[str] = None
     inside_layer_modules: List[List[str]] = None
@@ -138,8 +140,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         num_examples = len(examples)
         layers = get_module_by_name(self.model, self.layers_block_name)

-        layers[0] = layers[0].to(CUDA)
-        self._move_outside_layer_modules(CUDA)
+        layers[0] = layers[0].to(CUDA_0)
+        self._move_outside_layer_modules(CUDA_0)

         # get inputs for first layer
         layers[0] = LayerHijacker(layers[0])
@@ -147,7 +149,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             for k, v in example.items():
                 if len(v.shape) == 1:
                     v = v.unsqueeze(0)
-                example[k] = v.to(CUDA)
+                example[k] = v.to(CUDA_0)
             try:
                 self.model(**example)
             except ValueError:
@@ -166,7 +168,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         quantizers = {}
         for i in range(len(layers)):
             logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
-            layer = layers[i].to(CUDA)
+            layer = layers[i].to(CUDA_0)

             full = find_layers(layer)
             for names in self.inside_layer_modules:
@@ -191,16 +193,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             for name in subset:
                 handles.append(subset[name].register_forward_hook(add_batch(name)))
             for j in range(num_examples):
-                layer_input = layer_inputs[j].to(CUDA)
-                layer_attention_mask = attention_masks[j].to(CUDA)
+                layer_input = layer_inputs[j].to(CUDA_0)
+                layer_attention_mask = attention_masks[j].to(CUDA_0)
                 additional_layer_inputs = {
                     "attention_mask": layer_attention_mask
                 }
-                if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
+                if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA_0)) is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
                 for k, v in layer_input_kwargs[j].items():
                     if isinstance(v, torch.Tensor):
-                        additional_layer_inputs[k] = v.to(CUDA)
+                        additional_layer_inputs[k] = v.to(CUDA_0)
                     else:
                         additional_layer_inputs[k] = v
                 layer(layer_input, **additional_layer_inputs)[0][0].cpu()
@@ -220,16 +222,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 gptq[name].free()

             for j in range(num_examples):
-                layer_input = layer_inputs[j].to(CUDA)
-                layer_attention_mask = attention_masks[j].to(CUDA)
+                layer_input = layer_inputs[j].to(CUDA_0)
+                layer_attention_mask = attention_masks[j].to(CUDA_0)
                 additional_layer_inputs = {
                     "attention_mask": layer_attention_mask
                 }
-                if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
+                if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA_0)) is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
                 for k, v in layer_input_kwargs[j].items():
                     if isinstance(v, torch.Tensor):
-                        additional_layer_inputs[k] = v.to(CUDA)
+                        additional_layer_inputs[k] = v.to(CUDA_0)
                     else:
                         additional_layer_inputs[k] = v
                 layer_output = layer(layer_input, **additional_layer_inputs)[0][0].cpu()
@@ -304,11 +306,14 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         cls,
         pretrained_model_name_or_path: str,
         quantize_config: BaseQuantizeConfig,
-        bf16: bool = False,
+        max_memory_per_gpu: Optional[int] = None,
         **model_init_kwargs
     ):
         """load un-quantized pretrained model to cpu"""
+        if not torch.cuda.is_available():
+            raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
+
         def skip(*args, **kwargs):
             pass
@@ -321,10 +326,41 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             raise TypeError(f"{config.model_type} isn't supported yet.")

         # enforce some values despite user specified
-        model_init_kwargs["device_map"] = None
-        model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
-        model_init_kwargs["low_cpu_mem_usage"] = False
+        model_init_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         model_init_kwargs["trust_remote_code"] = True

+        max_memory = None
+        if "max_memory" in model_init_kwargs:
+            max_memory = model_init_kwargs.pop("max_memory")
+        if max_memory_per_gpu is not None:
+            max_memory = {
+                cuda_id: f"{max_memory_per_gpu}GIB" for cuda_id in range(torch.cuda.device_count())
+            }
+
+        if max_memory:
+            with accelerate.init_empty_weights():
+                model = AutoModelForCausalLM.from_config(config)
+                model.tie_weights()
+            max_memory = accelerate.utils.get_balanced_memory(
+                model,
+                max_memory=max_memory,
+                no_split_module_classes=[cls.layer_type],
+                dtype=model_init_kwargs["torch_dtype"],
+                low_zero=False
+            )
+            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
+                model,
+                max_memory=max_memory,
+                no_split_module_classes=[cls.layer_type],
+                dtype=model_init_kwargs["torch_dtype"]
+            )
+            model_init_kwargs["low_cpu_mem_usage"] = True
+            del model
+        else:
+            model_init_kwargs["device_map"] = None
+            model_init_kwargs["low_cpu_mem_usage"] = False
+
+        torch.cuda.empty_cache()
+
         model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)

         model_config = model.config.to_dict()
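A minimal usage sketch of the new loading path follows. The import path, checkpoint name, and quantize-config fields are assumptions for illustration; only the from_pretrained() arguments (quantize_config, the new max_memory_per_gpu, and an accelerate-style max_memory map passed through model_init_kwargs) come from the diff above.

# Illustrative only: import path, checkpoint and config fields are assumptions.
from auto_gptq.modeling import BaseQuantizeConfig, OPTGPTQForCausalLM

quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

# Cap every visible GPU at roughly 20 GiB; accelerate balances the layers that
# fit across GPUs and never splits a single decoder block (cls.layer_type).
model = OPTGPTQForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    quantize_config,
    max_memory_per_gpu=20,
)

# Or pass an explicit accelerate-style max_memory map; it is popped from
# model_init_kwargs before the model is instantiated.
model = OPTGPTQForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    quantize_config,
    max_memory={0: "20GIB", "cpu": "64GIB"},
)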

View file

@@ -4,10 +4,10 @@ from torch import device
 from transformers import __version__ as transformers_version

 CPU = device("cpu")
-CUDA = device("cuda:0")
+CUDA_0 = device("cuda:0")

 SUPPORTED_MODELS = ["bloom", "gptj", "gpt_neox", "opt", "moss"]
 if parse_version(transformers_version) >= parse_version("v4.28.0"):
     SUPPORTED_MODELS.append("llama")

-__all__ = ["CPU", "CUDA", "SUPPORTED_MODELS"]
+__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS"]

View file

@@ -3,7 +3,7 @@ from logging import getLogger
 import torch.nn as nn
 from transformers import AutoConfig

-from ._const import SUPPORTED_MODELS, CUDA
+from ._const import SUPPORTED_MODELS, CUDA_0

 logger = getLogger(__name__)
@@ -67,7 +67,7 @@ def pack_model(model, quantizers, bits, group_size, use_triton=False, autotune_w
         logger.warning(
             "using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the hole model."
         )
-        autotune_warmup_linear(model.to(CUDA), seqlen=model.seqlen)
+        autotune_warmup_linear(model.to(CUDA_0), seqlen=model.seqlen)


 def check_and_get_model_type(model_dir):

View file

@@ -2,6 +2,7 @@ from ._base import *

 class BloomGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "BloomBlock"
     layers_block_name = "transformer.h"
     outside_layer_modules = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
     inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *

 class GPTNeoXGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "GPTNeoXLayer"
     layers_block_name = "gpt_neox.layers"
     outside_layer_modules = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"]
     inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *

 class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "GPTJBlock"
     layers_block_name = "transformer.h"
     outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
     inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *

 class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "LlamaDecoderLayer"
     layers_block_name = "model.layers"
     outside_layer_modules = ["model.embed_tokens", "model.norm"]
     inside_layer_modules = [

View file

@@ -2,6 +2,7 @@ from ._base import *

 class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MossBlock"
     layers_block_name = "transformer.h"
     outside_layer_modules = ["transformer.wte", "transformer.drop", "transformer.ln_f"]
     inside_layer_modules = [
@@ -10,3 +11,6 @@ class MOSSGPTQForCausalLM(BaseGPTQForCausalLM):
         ["mlp.fc_in"],
         ["mlp.fc_out"]
     ]
+
+
+__all__ = ["MOSSGPTQForCausalLM"]

View file

@@ -2,6 +2,7 @@ from ._base import *

 class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "OPTDecoderLayer"
     layers_block_name = "model.decoder.layers"
     outside_layer_modules = [
         "model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",

View file

@@ -10,6 +10,7 @@ except ImportError:
     version = "v0.1.0-dev"

 requirements = [
+    "accelerate",
     "datasets",
     "numpy",
     "rouge",