duplicate code remove
This commit is contained in:
parent
f23a06f911
commit
dafdd6189a
1 changed files with 0 additions and 22 deletions
|
@ -127,28 +127,6 @@ def process_zeros_scales(zeros, scales, bits, out_features):
|
||||||
|
|
||||||
return new_zeros, new_scales
|
return new_zeros, new_scales
|
||||||
|
|
||||||
def process_zeros_scales(zeros, scales, bits, out_features):
|
|
||||||
if zeros.dtype != torch.float32:
|
|
||||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
|
||||||
if bits == 4:
|
|
||||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
|
||||||
elif bits == 2:
|
|
||||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
|
||||||
elif bits == 3:
|
|
||||||
logger.info("Unpacking zeros for 3 bits")
|
|
||||||
new_scales = scales.contiguous()
|
|
||||||
else:
|
|
||||||
if scales.shape[1] != out_features:
|
|
||||||
new_scales = scales.transpose(0,1).contiguous()
|
|
||||||
else:
|
|
||||||
new_scales = scales.contiguous()
|
|
||||||
if zeros.shape[1] != out_features:
|
|
||||||
new_zeros = zeros.transpose(0,1).contiguous()
|
|
||||||
else:
|
|
||||||
new_zeros = zeros.contiguous()
|
|
||||||
|
|
||||||
return new_zeros, new_scales
|
|
||||||
|
|
||||||
def preprocess_checkpoint_qigen(
|
def preprocess_checkpoint_qigen(
|
||||||
module,
|
module,
|
||||||
names,
|
names,
|
||||||
|
|
Loading…
Add table
Reference in a new issue