simplified code

PanQiWei 2023-07-26 17:53:47 +08:00
parent 5d6862ee8d
commit 722a621aaa
2 changed files with 49 additions and 34 deletions
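In short, this commit moves model and tokenizer construction out of the Perplexity class and into the example script, and lets device placement follow the already-loaded model instead of a device argument. A minimal before/after sketch of the call site, with an illustrative checkpoint name that is not taken from the diff:

    # Before: Perplexity accepted a model name string and managed devices itself.
    ppl = Perplexity("facebook/opt-125m", device="auto")  # illustrative checkpoint

    # After: the caller loads both objects and Perplexity reads model.device.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from auto_gptq.utils import Perplexity

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    ppl = Perplexity(model, tokenizer)  # wikitext test split by default
    ppl.calculate_perplexity(512, 512)  # n_ctx, n_batch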

File 1 of 2: the Perplexity utility class (imported as auto_gptq.utils.Perplexity)

@@ -5,21 +5,21 @@ from tqdm import tqdm
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM


 class Perplexity:
     """
     A class for calculating the perplexity of a language model.
     """

-    def __init__(self, model, tokenizer=None, device='auto', dataset_path='wikitext',
-                 dataset_name=None, split='test', text_column='text'):
+    def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
         """
         Calculate perplexity using the same method as seen in llama.cpp.

         Parameters
         ----------
-        model : AutoModelForCausalLM or str
+        model : AutoModelForCausalLM
             The language model for which the perplexity is calculated.
-        tokenizer : AutoTokenizer or str
+        tokenizer : AutoTokenizer
             The tokenizer corresponding to the model.
         device : str, optional
             The device to run the calculations on. If auto, the device that your model uses
@@ -33,23 +33,7 @@ class Perplexity:
         text_column : str, optional
             The name of the column in the dataset that contains the text data. Default is 'text'.
         """
-        if tokenizer is None and type(model) == str:
-            tokenizer = AutoTokenizer.from_pretrained(model)
-        elif type(tokenizer) == str:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-        elif tokenizer is None and type(model) != str:
-            raise Exception('Tokenizer cannot be None if model is not a str. Please load a tokenizer first:\n' +
-                            'tokenizer = AutoTokenizer.from_pretrained(model_name)')
-
-        if type(model) == str:
-            model = AutoModelForCausalLM.from_pretrained(model)
-
-        self._device = self._get_device() if device == 'auto' else device
-        self._model = model.to(self._device)
+        self._model = model
         self._tokenizer = tokenizer
         self._dataset_path = dataset_path
         self._dataset_name = dataset_name
@@ -119,7 +103,7 @@ class Perplexity:
         """
         # Tokenize the text
         self._tokenizer.model_max_length = sys.maxsize
-        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._device)
+        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)

         nll = 0.0  # Negative log likelihood
         count = 0  # Counter for processed tokens
@@ -173,7 +157,7 @@ class Perplexity:
         for j in range(num_batches):
             batch_start = start + j * n_batch
             batch_size = min(end - batch_start, n_batch)

             token_org = tokens[0][batch_start].item()
@@ -228,4 +212,4 @@ class Perplexity:
         # Compute the logits without keeping track of gradients
         with torch.no_grad():
             outputs = self._model(tokens[:, batch_start:batch_start+batch_size])
             return outputs.logits.detach()
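A note on the device change in the third hunk above: token tensors now move to self._model.device rather than a stored self._device. For transformers models, model.device reports the device of the model's first parameters, so the same line also behaves sensibly when accelerate shards the model across devices. A small sketch under those assumptions (illustrative checkpoint, accelerate installed):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",   # illustrative checkpoint
        device_map="auto",     # accelerate decides the placement
        torch_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # Inputs placed on model.device are routed through the (possibly sharded)
    # model by accelerate's hooks, so no manual device bookkeeping is needed.
    tokens = tokenizer("hello world", return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        logits = model(tokens).logits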

File 2 of 2: the example script that loads a model and runs the perplexity benchmark

@@ -1,6 +1,9 @@
 import os
 import argparse

+import torch
+
 from auto_gptq.utils import Perplexity
+from transformers import AutoTokenizer

 if __name__ == "__main__":
     """
@@ -24,32 +27,60 @@ if __name__ == "__main__":
     parser.add_argument("--model_basename", type=str, default=None, help="Model file's basename.")
     parser.add_argument("--n_ctx", type=int, default=512, help="Context size.")
     parser.add_argument("--n_batch", type=int, default=512, help="Batch size.")
-    parser.add_argument("--device", type=str, default="auto", help="Device to use.")
     parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
     parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
     parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
     parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
-    parser.add_argument("--is_quantized", action=argparse.BooleanOptionalAction, default=False, help="Is the model GPTQ quantized?")
+    parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
+    parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
+    parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
+    parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
+    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
+    parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
     args = parser.parse_args()

     os.environ["TOKENIZERS_PARALLELISM"] = "false"

+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    max_memory = dict()
+    if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
+        if torch.cuda.is_available():
+            max_memory.update(
+                {i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
+            )
+    if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
+        max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
+    if not max_memory:
+        max_memory = None
+
     if args.is_quantized:
-        from transformers import AutoTokenizer
         from auto_gptq import AutoGPTQForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
         model = AutoGPTQForCausalLM.from_quantized(
             args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
             model_basename=args.model_basename,
-            use_safetensors=True,
-            trust_remote_code=True
+            use_safetensors=args.use_safetensors,
+            trust_remote_code=args.trust_remote_code,
+            inject_fused_mlp=False,
+            inject_fused_attention=False
         )
     else:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from transformers import AutoModelForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-        model = AutoModelForCausalLM.from_pretrained(args.model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
+            torch_dtype=torch.float16,
+            trust_remote_code=args.trust_remote_code
+        )

-    ppl = Perplexity(model, tokenizer, args.device, args.dataset_path, args.dataset_name, args.split, args.text_column)
+    ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column)
     ppl.calculate_perplexity(args.n_ctx, args.n_batch)
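For reference, the max_memory mapping built by the script is the per-device cap that accelerate consumes when it resolves device_map="auto". A worked example of what the block produces, assuming two visible GPUs and illustrative flag values of --per_gpu_max_memory 20 --cpu_max_memory 60:

    # Result of the max_memory block on a two-GPU machine:
    max_memory = {0: "20GIB", 1: "20GIB", "cpu": "60GIB"}

    # from_quantized()/from_pretrained() forward this to accelerate, capping
    # each CUDA device at 20 GiB and CPU offload at 60 GiB. With no GPUs, or
    # with the flags unset, max_memory stays None and accelerate's defaults apply.

Note that the CPU cap is only applied when at least one GPU entry exists (the trailing "and max_memory" guard), so --cpu_max_memory on its own has no effect.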