remove override of _resize_attention_mask for llama and opt
parent 1d91fded6c
commit b490ab004e
2 changed files with 0 additions and 10 deletions
@@ -12,10 +12,5 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
         ["mlp.down_proj"]
     ]
 
-    @staticmethod
-    def _resize_attention_mask(attention_mask):
-        attention_mask = [each.unsqueeze(1) for each in attention_mask]
-        return attention_mask
-
 
 __all__ = ["LlamaGPTQForCausalLM"]
@@ -15,10 +15,5 @@ class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
         ["fc2"]
     ]
 
-    @staticmethod
-    def _resize_attention_mask(attention_mask):
-        attention_mask = [each.unsqueeze(1) for each in attention_mask]
-        return attention_mask
-
 
 __all__ = ["OPTGPTQForCausalLM"]
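The override deleted from both model classes was identical. For reference, here is a minimal sketch (with assumed tensor shapes, not code from this repository) of what that override did: unsqueeze each attention mask at axis 1 so it can broadcast against per-head attention scores. With the override gone, both models fall back to whatever BaseGPTQForCausalLM._resize_attention_mask does by default.

# Minimal sketch of the removed override's behavior; the [batch, seq_len]
# mask shape is an assumption for illustration only.
import torch

masks = [torch.ones(4, 128, dtype=torch.long)]  # one mask, assumed [batch, seq_len]
resized = [m.unsqueeze(1) for m in masks]       # each becomes [batch, 1, seq_len]
assert resized[0].shape == (4, 1, 128)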