fix device mismatch when directly using the model for inference after quantization
parent 892eeb40e0
commit a69a73a22c

4 changed files with 11 additions and 7 deletions
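The changes below make a model that was quantized in-process usable for inference right away: the accelerate device map is captured before quantization and re-applied afterwards, and the CUDA kernel's output buffer is allocated on the input's device instead of the default GPU. A minimal sketch of the workflow this targets, assuming the usual AutoGPTQ-style API (AutoGPTQForCausalLM, BaseQuantizeConfig); the exact arguments are illustrative, not guaranteed by this commit:

    # Sketch only: quantize in-process, then generate with the same object.
    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    model_id = "facebook/opt-125m"  # hypothetical small model for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoGPTQForCausalLM.from_pretrained(model_id, BaseQuantizeConfig(bits=4, group_size=128))

    examples = [tokenizer("auto-gptq is an easy-to-use quantization package.", return_tensors="pt")]
    model.quantize(examples)

    # Before this fix, calling the freshly quantized model here could raise a
    # device-mismatch error; inputs, weights, and kernel outputs now land on
    # consistent devices.
    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))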
@@ -1,6 +1,7 @@
 from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
 from .auto import *
 from .bloom import *
+from .gpt2 import *
 from .gpt_neox import *
 from .gptj import *
 from .llama import *
@@ -99,8 +99,9 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if self.quantized:
             raise EnvironmentError("can't execute quantize because the model is quantized.")

-        if self.hf_device_map:
-            for name, device in self.hf_device_map.items():
+        device_map = self.hf_device_map
+        if device_map:
+            for name, device in device_map.items():
                 if device == "cpu":
                     module = get_module_by_name(self.model, name)
                     remove_hook_from_module(module, recurse=True)
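This hunk snapshots self.hf_device_map into a local device_map (reused later when the model is re-dispatched) and strips accelerate hooks only from modules that were offloaded to CPU before quantization begins. A rough standalone sketch, with get_submodule standing in for the repo's get_module_by_name helper:

    from accelerate.hooks import remove_hook_from_module

    device_map = getattr(model, "hf_device_map", None)  # e.g. {"transformer.h.0": 0, "lm_head": "cpu"}
    if device_map:
        for name, device in device_map.items():
            if device == "cpu":
                # CPU-offloaded modules lose their hooks so their weights can be
                # moved freely while that layer is being quantized.
                module = model.get_submodule(name)
                remove_hook_from_module(module, recurse=True)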
@@ -295,13 +296,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             autotune_warmup=autotune_warmup_after_quantized,
             force_layer_back_to_cpu=force_layer_back_to_cpu
         )

-        if self.hf_device_map:
-            remove_hook_from_submodules(self.model)
+        if device_map:
+            self.model = remove_hook_from_module(self.model, recurse=True)
+            self.model = accelerate.dispatch_model(self.model, device_map, offload_buffers=True)
         self.model.config.use_cache = forward_pass_use_cache

         self._quantized = True

         torch.cuda.empty_cache()

     @property
     def device(self):
         return self.model.device
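Where the old code only called remove_hook_from_submodules(self.model), the new code removes every hook and then re-dispatches the model from the saved device_map, so a forward pass on the still-loaded quantized model sees consistent device placement. Outside the class, the same two calls look roughly like this (model and device_map are placeholders):

    import accelerate
    from accelerate.hooks import remove_hook_from_module

    # Drop every stale hook, then rebuild placement from the device map
    # captured before quantization.
    model = remove_hook_from_module(model, recurse=True)
    model = accelerate.dispatch_model(model, device_map, offload_buffers=True)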
@@ -4,7 +4,7 @@ from ._base import *
 class GPT2GPTQForCausalLM(BaseGPTQForCausalLM):
     layer_type = "GPT2Block"
     layers_block_name = "transformer.h"
-    outside_layer_modules = ["transformer.wte", "transformer.wpe","transformer.ln_f"]
+    outside_layer_modules = ["transformer.wte", "transformer.wpe", "transformer.ln_f"]
     inside_layer_modules = [
         ["attn.c_attn"],
         ["attn.c_proj"],
@@ -178,7 +178,7 @@ class QuantLinear(nn.Module):
         if self.quant_cuda_available and (
             self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold
         ):
-            out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32)
+            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
             if self.bits == 2:
                 quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
             elif self.bits == 3:
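In QuantLinear, the kernel output buffer now follows the input's device rather than the hard-coded default CUDA device, which avoids a mismatch whenever the layer lives anywhere other than cuda:0. A toy illustration, assuming at least two visible GPUs:

    import torch

    x = torch.randn(4, 768, device="cuda:1")          # activations placed on the second GPU
    bad = torch.zeros(4, 1024, device="cuda")         # resolves to cuda:0, mismatches x
    good = torch.zeros(4, 1024, device=x.device)      # follows the input, stays on cuda:1
    assert good.device == x.device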