fix model type changed after calling .to() method
This commit is contained in:
parent
604c96144f
commit
c7021f0f44
1 changed files with 2 additions and 1 deletions
|
@ -433,7 +433,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
return torch.device(device)
|
||||
|
||||
def to(self, device: Union[str, torch.device]):
|
||||
return self.model.to(device)
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
|
Loading…
Add table
Reference in a new issue