Fix bug in quantization demo

commit d986a738e1 (parent e2c7cd4fb3)
Author: qwopqwop200
Date:   2023-05-01 08:03:11 +09:00 (committed by GitHub)

@@ -28,13 +28,14 @@ def get_wikitext2(nsamples, seed, seqlen, model):
     np.random.seed(0)
     torch.random.manual_seed(0)
-    trainloader = []
+    traindataset = []
     for _ in range(nsamples):
         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
         j = i + seqlen
         inp = trainenc.input_ids[:, i:j]
-        trainloader.append({'input_ids':inp})
-    return trainloader, testenc
+        attention_mask = torch.ones_like(inp)
+        traindataset.append({'input_ids':inp,'attention_mask': attention_mask})
+    return traindataset, testenc
 
 @torch.no_grad()
 def opt_eval(model, testenc, dev, seqlen = 2048):
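
Note on this hunk: model.quantize expects every calibration example to carry both "input_ids" and "attention_mask" as torch.LongTensor values, but the old get_wikitext2 only supplied "input_ids". A minimal sketch of one example after the fix (the random input_ids stand in for a real tokenized slice trainenc.input_ids[:, i:j]; the vocab size 50272 is illustrative only):

    # Sketch of one calibration example in the shape this commit produces.
    import torch

    seqlen = 2048
    inp = torch.randint(0, 50272, (1, seqlen), dtype=torch.long)  # stand-in for a tokenized slice
    example = {
        'input_ids': inp,                        # shape (1, seqlen), torch.LongTensor
        'attention_mask': torch.ones_like(inp),  # all ones: every position is real, none padded
    }

Since each slice is a full contiguous window of seqlen tokens with no padding, a mask of all ones is correct here.
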
@@ -131,7 +132,7 @@ def opt_eval(model, testenc, dev, seqlen = 2048):
 
 def main():
     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
-    trainloader,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
+    traindataset,testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)
 
     quantize_config = BaseQuantizeConfig(
         bits=4,  # quantize model to 4-bit
@@ -143,7 +144,7 @@ def main():
 
     # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
     # with value under torch.LongTensor type.
-    model.quantize(trainloader, use_triton=False)
+    model.quantize(traindataset, use_triton=False)
 
     # save quantized model
     model.save_quantized(quantized_model_dir)
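
Putting the patched pieces together, the demo flow now looks roughly like the sketch below. It assumes the example's imports and directory variables; the model name "facebook/opt-125m", output path, and group_size=128 are illustrative, not taken from this diff:

    # Sketch of the demo after this commit; get_wikitext2 is the function patched above.
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    from transformers import AutoTokenizer

    pretrained_model_dir = "facebook/opt-125m"  # assumption: any causal LM directory works
    quantized_model_dir = "opt-125m-4bit"       # assumption: output path

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
    traindataset, testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)

    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

    # Each dict in traindataset now carries 'input_ids' and 'attention_mask',
    # matching what model.quantize expects.
    model.quantize(traindataset, use_triton=False)
    model.save_quantized(quantized_model_dir)
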