simplified code
This commit is contained in:
parent 5d6862ee8d
commit 722a621aaa
2 changed files with 49 additions and 34 deletions
auto_gptq/utils/perplexity_utils.py

@@ -5,21 +5,21 @@ from tqdm import tqdm
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM


 class Perplexity:
     """
     A class for calculating the perplexity of a language model.
     """

-    def __init__(self, model, tokenizer=None, device='auto', dataset_path='wikitext',
-                 dataset_name=None, split='test', text_column='text'):
+    def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
         """
         Calculate perplexity using the same method as seen in llama.cpp.

         Parameters
         ----------
-        model : AutoModelForCausalLM or str
+        model : AutoModelForCausalLM
             The language model for which the perplexity is calculated.
-        tokenizer : AutoTokenizer or str
+        tokenizer : AutoTokenizer
             The tokenizer corresponding to the model.
-        device : str, optional
-            The device to run the calculations on. If auto, the device that your model uses
@@ -33,23 +33,7 @@ class Perplexity:
         text_column : str, optional
             The name of the column in the dataset that contains the text data. Default is 'text'.
         """
-
-        if tokenizer is None and type(model) == str:
-            tokenizer = AutoTokenizer.from_pretrained(model)
-
-        elif type(tokenizer) == str:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-
-        elif tokenizer is None and type(model) != str:
-            raise Exception('Tokenizer cannot be None if model is not a str. Please load a tokenizer first:\n' +
-                            'tokenizer = AutoTokenizer.from_pretrained(model_name)')
-
-        if type(model) == str:
-            model = AutoModelForCausalLM.from_pretrained(model)
-
-        self._device = self._get_device() if device == 'auto' else device
-        self._model = model.to(self._device)
-
+        self._model = model
         self._tokenizer = tokenizer
         self._dataset_path = dataset_path
         self._dataset_name = dataset_name
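With the string-loading and device branches gone, Perplexity now assumes an already-loaded model and tokenizer. A minimal usage sketch under the new signature ("gpt2" and the wikitext config name are illustrative placeholders, not part of this commit):

    from transformers import AutoTokenizer, AutoModelForCausalLM

    # placeholder checkpoint; any causal LM works the same way
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

    # "wikitext-2-raw-v1" is an assumed config for the default 'wikitext' path
    ppl = Perplexity(model, tokenizer, dataset_name="wikitext-2-raw-v1")
    ppl.calculate_perplexity(512, 512)  # n_ctx, n_batch, mirroring the example script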
@@ -119,7 +103,7 @@ class Perplexity:
         """
         # Tokenize the text
         self._tokenizer.model_max_length = sys.maxsize
-        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._device)
+        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)

         nll = 0.0  # Negative log likelihood
         count = 0  # Counter for processed tokens
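Tokens now follow self._model.device instead of a separately tracked device string. Any loaded transformers model exposes a .device attribute (for sharded device_map loads it reports the first parameter's device), so no extra bookkeeping is needed. A quick illustration, with "gpt2" again a placeholder:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
    print(model.device)                               # e.g. cuda:0 when a GPU is present
    ids = torch.tensor([[1, 2, 3]]).to(model.device)  # inputs are moved to match the model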
examples/benchmark/perplexity.py

@@ -1,6 +1,9 @@
 import os
 import argparse

+import torch
+
 from auto_gptq.utils import Perplexity
+from transformers import AutoTokenizer

 if __name__ == "__main__":
     """
@@ -24,32 +27,60 @@ if __name__ == "__main__":
     parser.add_argument("--model_basename", type=str, default=None, help="Model file's basename.")
     parser.add_argument("--n_ctx", type=int, default=512, help="Context size.")
     parser.add_argument("--n_batch", type=int, default=512, help="Batch size.")
-    parser.add_argument("--device", type=str, default="auto", help="Device to use.")
     parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
     parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
     parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
     parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
-    parser.add_argument("--is_quantized", action=argparse.BooleanOptionalAction, default=False, help="Is the model GPTQ quantized?")
+    parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
+    parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
+    parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
+    parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file.")
+    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer.")
+    parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code.")
     args = parser.parse_args()

+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    max_memory = dict()
+    if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
+        if torch.cuda.is_available():
+            max_memory.update(
+                {i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
+            )
+    if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
+        max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
+    if not max_memory:
+        max_memory = None
+
     if args.is_quantized:
-        from transformers import AutoTokenizer
         from auto_gptq import AutoGPTQForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
         model = AutoGPTQForCausalLM.from_quantized(
             args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
             model_basename=args.model_basename,
-            use_safetensors=True,
-            trust_remote_code=True
+            use_safetensors=args.use_safetensors,
+            trust_remote_code=args.trust_remote_code,
+            inject_fused_mlp=False,
+            inject_fused_attention=False
         )
     else:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from transformers import AutoModelForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-        model = AutoModelForCausalLM.from_pretrained(args.model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
+            torch_dtype=torch.float16,
+            trust_remote_code=args.trust_remote_code
+        )

-    ppl = Perplexity(model, tokenizer, args.device, args.dataset_path, args.dataset_name, args.split, args.text_column)
+    ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column)
     ppl.calculate_perplexity(args.n_ctx, args.n_batch)
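The new --per_gpu_max_memory and --cpu_max_memory flags build the max_memory mapping that accelerate's device_map="auto" dispatch consumes. A sketch of the value the script would construct (two GPUs and the flag values 20 and 64 are assumed purely for illustration):

    # hypothetical result of --per_gpu_max_memory 20 --cpu_max_memory 64 on a 2-GPU machine
    max_memory = {0: "20GIB", 1: "20GIB", "cpu": "64GIB"}
    # passed with device_map="auto", this caps what each GPU may hold
    # before the remaining weights spill over to CPU RAM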