fix device mismatch when directly using the model for inference after quantization

PanQiWei 2023-04-28 16:41:46 +08:00
parent 892eeb40e0
commit a69a73a22c
4 changed files with 11 additions and 7 deletions

View file

@@ -1,6 +1,7 @@
 from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
 from .auto import *
 from .bloom import *
+from .gpt2 import *
 from .gpt_neox import *
 from .gptj import *
 from .llama import *
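
The one-line addition re-exports the new GPT2 wrapper at package level. A quick check of what this enables (assuming auto-gptq is installed at this revision):

    from auto_gptq.modeling import GPT2GPTQForCausalLM

    # The star import above makes the class part of auto_gptq.modeling's namespace.
    print(GPT2GPTQForCausalLM.layer_type)  # "GPT2Block"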

View file

@@ -99,8 +99,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if self.quantized:
             raise EnvironmentError("can't execute quantize because the model is quantized.")
-        if self.hf_device_map:
-            for name, device in self.hf_device_map.items():
+        device_map = self.hf_device_map
+        if device_map:
+            for name, device in device_map.items():
                 if device == "cpu":
                     module = get_module_by_name(self.model, name)
                     remove_hook_from_module(module, recurse=True)
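
In this first hunk, `hf_device_map` is snapshotted into a local `device_map` so the same mapping can be reused after quantization (see the second hunk), and accelerate's offload hooks are stripped from CPU-mapped modules before quantization touches them. The `get_module_by_name` helper resolves a device-map key to its submodule; a minimal sketch of such a helper, assuming the usual dotted-path convention (the library's real implementation may differ):

    import torch.nn as nn

    def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
        # Walk a dotted path such as "transformer.h.0" down to the submodule,
        # matching the keys accelerate stores in hf_device_map.
        module = model
        for attr in name.split("."):
            module = getattr(module, attr)
        return module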
@@ -295,13 +296,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             autotune_warmup=autotune_warmup_after_quantized,
             force_layer_back_to_cpu=force_layer_back_to_cpu
         )
-        if self.hf_device_map:
-            remove_hook_from_submodules(self.model)
+        if device_map:
+            self.model = remove_hook_from_module(self.model, recurse=True)
+            self.model = accelerate.dispatch_model(self.model, device_map, offload_buffers=True)
         self.model.config.use_cache = forward_pass_use_cache
         self._quantized = True
         torch.cuda.empty_cache()

     @property
     def device(self):
         return self.model.device
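
Taken together, the two hunks snapshot the device map, strip the stale accelerate hooks from the freshly quantized model with `remove_hook_from_module`, and re-dispatch it with `accelerate.dispatch_model` so every submodule is hooked to the right device again; the `device` property then reports where inputs should go. A sketch of the now-working quantize-then-generate path (model id, config values, and prompts are illustrative, not from this commit):

    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    from transformers import AutoTokenizer

    model_id = "gpt2"  # illustrative; any architecture supported by auto-gptq works
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoGPTQForCausalLM.from_pretrained(
        model_id, quantize_config=BaseQuantizeConfig(bits=4, group_size=128)
    )
    model.quantize([tokenizer("auto-gptq is an easy-to-use quantization package.")])

    # Inference directly after quantize(), without a save/reload round trip --
    # the path this commit fixes. `model.device` comes from the property above.
    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))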

View file

@@ -4,7 +4,7 @@ from ._base import *
 class GPT2GPTQForCausalLM(BaseGPTQForCausalLM):
     layer_type = "GPT2Block"
     layers_block_name = "transformer.h"
-    outside_layer_modules = ["transformer.wte", "transformer.wpe","transformer.ln_f"]
+    outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f"]
     inside_layer_modules = [
         ["attn.c_attn"],
         ["attn.c_proj"],

View file

@@ -178,7 +178,7 @@ class QuantLinear(nn.Module):
         if self.quant_cuda_available and (
             self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
         ):
-            out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32)
+            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
             if self.bits == 2:
                 quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
             elif self.bits == 3:
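
This last hunk is the kernel-side half of the fix: `device='cuda'` allocates on the current default CUDA device (typically cuda:0), but under a device map the layer, and hence its input `x`, may live on another GPU, so the matmul kernel would receive tensors on two different devices. Allocating the output buffer next to the input avoids that. The rule in isolation (function name is illustrative):

    import torch

    def alloc_out(x: torch.Tensor, outfeatures: int) -> torch.Tensor:
        # device=x.device keeps the output on the same GPU as the input;
        # a hard-coded device="cuda" resolves to the current default device
        # and breaks layers that accelerate dispatched to cuda:1, cuda:2, ...
        return torch.zeros((x.shape[0], outfeatures), device=x.device, dtype=torch.float32)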