Merge pull request #310 from PanQiWei/fix_to()_metod_bug

fix model type changed after calling .to() method
This commit is contained in:
潘其威(William) 2023-08-31 19:04:02 +08:00 committed by GitHub
commit 1e938e6bad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -433,7 +433,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
return torch.device(device) return torch.device(device)
def to(self, device: Union[str, torch.device]): def to(self, device: Union[str, torch.device]):
return self.model.to(device) self.model.to(device)
return self
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.model(*args, **kwargs) return self.model(*args, **kwargs)