Bug fix: quantization demo
This commit is contained in:
parent
e2c7cd4fb3
commit
d986a738e1
1 changed file with 6 additions and 5 deletions
|
@ -28,13 +28,14 @@ def get_wikitext2(nsamples, seed, seqlen, model):
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
torch.random.manual_seed(0)
|
torch.random.manual_seed(0)
|
||||||
|
|
||||||
trainloader = []
|
traindataset = []
|
||||||
for _ in range(nsamples):
|
for _ in range(nsamples):
|
||||||
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
|
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
|
||||||
j = i + seqlen
|
j = i + seqlen
|
||||||
inp = trainenc.input_ids[:, i:j]
|
inp = trainenc.input_ids[:, i:j]
|
||||||
trainloader.append({'input_ids':inp})
|
attention_mask = torch.ones_like(inp)
|
||||||
return trainloader, testenc
|
traindataset.append({'input_ids':inp,'attention_mask': attention_mask})
|
||||||
|
return traindataset, testenc
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def opt_eval(model, testenc, dev, seqlen = 2048):
|
def opt_eval(model, testenc, dev, seqlen = 2048):
|
||||||
|
@ -131,7 +132,7 @@ def opt_eval(model, testenc, dev, seqlen = 2048):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
|
||||||
trainloader,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
|
traindataset,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
|
||||||
|
|
||||||
quantize_config = BaseQuantizeConfig(
|
quantize_config = BaseQuantizeConfig(
|
||||||
bits=4, # quantize model to 4-bit
|
bits=4, # quantize model to 4-bit
|
||||||
|
@ -143,7 +144,7 @@ def main():
|
||||||
|
|
||||||
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
|
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
|
||||||
# with value under torch.LongTensor type.
|
# with value under torch.LongTensor type.
|
||||||
model.quantize(trainloader, use_triton=False)
|
model.quantize(traindataset, use_triton=False)
|
||||||
|
|
||||||
# save quantized model
|
# save quantized model
|
||||||
model.save_quantized(quantized_model_dir)
|
model.save_quantized(quantized_model_dir)
|
||||||
|
|
Loading…
Add table
Reference in a new issue