duplicate code remove
This commit is contained in:
parent
f23a06f911
commit
dafdd6189a
1 changed files with 0 additions and 22 deletions
|
@ -127,28 +127,6 @@ def process_zeros_scales(zeros, scales, bits, out_features):
|
|||
|
||||
return new_zeros, new_scales
|
||||
|
||||
def process_zeros_scales(zeros, scales, bits, out_features):
|
||||
if zeros.dtype != torch.float32:
|
||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
||||
if bits == 4:
|
||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 2:
|
||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 3:
|
||||
logger.info("Unpacking zeros for 3 bits")
|
||||
new_scales = scales.contiguous()
|
||||
else:
|
||||
if scales.shape[1] != out_features:
|
||||
new_scales = scales.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_scales = scales.contiguous()
|
||||
if zeros.shape[1] != out_features:
|
||||
new_zeros = zeros.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_zeros = zeros.contiguous()
|
||||
|
||||
return new_zeros, new_scales
|
||||
|
||||
def preprocess_checkpoint_qigen(
|
||||
module,
|
||||
names,
|
||||
|
|
Loading…
Add table
Reference in a new issue