bug fix quantization demo
parent e2c7cd4fb3
commit d986a738e1
1 changed file with 6 additions and 5 deletions
@@ -28,13 +28,14 @@ def get_wikitext2(nsamples, seed, seqlen, model):
     np.random.seed(0)
     torch.random.manual_seed(0)

-    trainloader = []
+    traindataset = []
     for _ in range(nsamples):
         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
-        trainloader.append({'input_ids':inp})
-    return trainloader, testenc
+        attention_mask = torch.ones_like(inp)
+        traindataset.append({'input_ids':inp,'attention_mask': attention_mask})
+    return traindataset, testenc

 @torch.no_grad()
 def opt_eval(model, testenc, dev, seqlen = 2048):
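Each calibration example produced by the fixed get_wikitext2 is a dict of two LongTensors of shape (1, seqlen); the all-ones attention_mask is valid because every sample is a full-length slice of the tokenized corpus with no padding. A minimal sketch of that sampling step, using a random stand-in for the tokenized WikiText-2 data (the stand-in tensor and its sizes are illustrative, not part of the diff):

```python
import random
import torch

# Illustrative stand-in for the tokenized corpus: 10,000 fake token ids, batch dim 1.
input_ids = torch.randint(0, 50_000, (1, 10_000), dtype=torch.long)

seqlen = 2048
i = random.randint(0, input_ids.shape[1] - seqlen - 1)
inp = input_ids[:, i:i + seqlen]           # shape (1, seqlen), torch.LongTensor
attention_mask = torch.ones_like(inp)      # all ones: the slice contains no padding
example = {"input_ids": inp, "attention_mask": attention_mask}
print(example["input_ids"].shape, example["attention_mask"].dtype)
```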
@@ -131,7 +132,7 @@ def opt_eval(model, testenc, dev, seqlen = 2048):

 def main():
     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-    trainloader,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
+    traindataset,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)

     quantize_config = BaseQuantizeConfig(
         bits=4,  # quantize model to 4-bit
@@ -143,7 +144,7 @@ def main():

     # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
     # with value under torch.LongTensor type.
-    model.quantize(trainloader, use_triton=False)
+    model.quantize(traindataset, use_triton=False)

     # save quantized model
     model.save_quantized(quantized_model_dir)
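Taken together, the fix makes the calibration data match the format the comment above describes: a list of dicts whose only keys are "input_ids" and "attention_mask", with torch.LongTensor values. A minimal sketch of a check for that contract (the validate_examples helper and its name are illustrative, not part of the demo):

```python
import torch

def validate_examples(examples):
    """Illustrative helper: assert each calibration example matches the format
    expected by model.quantize() in this demo (keys limited to "input_ids" and
    "attention_mask", values being torch.LongTensor)."""
    for idx, ex in enumerate(examples):
        assert set(ex.keys()) <= {"input_ids", "attention_mask"}, f"unexpected keys in example {idx}"
        for key, value in ex.items():
            assert isinstance(value, torch.Tensor) and value.dtype == torch.long, \
                f"{key} in example {idx} is not a torch.LongTensor"

# Usage against the fixed demo (traindataset comes from get_wikitext2 above):
#   traindataset, testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
#   validate_examples(traindataset)
#   model.quantize(traindataset, use_triton=False)
```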