fix model type changed after calling .to() method

This commit is contained in:
PanQiWei 2023-08-31 18:39:03 +08:00
parent 604c96144f
commit c7021f0f44

View file

@ -433,7 +433,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
return torch.device(device)
def to(self, device: Union[str, torch.device]):
return self.model.to(device)
self.model.to(device)
return self
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)