Merge pull request #310 from PanQiWei/fix_to()_metod_bug
fix model type changed after calling .to() method
This commit is contained in:
commit
1e938e6bad
1 changed files with 2 additions and 1 deletions
|
@ -433,7 +433,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
return torch.device(device)
|
return torch.device(device)
|
||||||
|
|
||||||
def to(self, device: Union[str, torch.device]):
|
def to(self, device: Union[str, torch.device]):
|
||||||
return self.model.to(device)
|
self.model.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.model(*args, **kwargs)
|
return self.model(*args, **kwargs)
|
||||||
|
|
Loading…
Add table
Reference in a new issue