diff --git a/auto_gptq/quantization/gptq.py b/auto_gptq/quantization/gptq.py index ef72f10..989a8f4 100644 --- a/auto_gptq/quantization/gptq.py +++ b/auto_gptq/quantization/gptq.py @@ -23,7 +23,7 @@ class GPTQ: W = layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): + if isinstance(self.layer, transformers.pytorch_utils.Conv1D): W = W.t() self.rows = W.shape[0] self.columns = W.shape[1]