simplified code
This commit is contained in:
parent
5d6862ee8d
commit
722a621aaa
2 changed files with 49 additions and 34 deletions
@@ -5,21 +5,21 @@ from tqdm import tqdm
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM


 class Perplexity:
     """
     A class for calculating the perplexity of a language model.
     """

-    def __init__(self, model, tokenizer=None, device='auto', dataset_path='wikitext',
-                 dataset_name=None, split='test', text_column='text'):
+    def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
         """
         Calculate perplexity using the same method as seen in llama.cpp.

         Parameters
         ----------
-        model : AutoModelForCausalLM or str
+        model : AutoModelForCausalLM
             The language model for which the perplexity is calculated.
-        tokenizer : AutoTokenizer or str
+        tokenizer : AutoTokenizer
             The tokenizer corresponding to the model.
         device : str, optional
             The device to run the calculations on. If auto, the device that your model uses
@@ -33,23 +33,7 @@ class Perplexity:
         text_column : str, optional
             The name of the column in the dataset that contains the text data. Default is 'text'.
         """
-
-        if tokenizer is None and type(model) == str:
-            tokenizer = AutoTokenizer.from_pretrained(model)
-
-        elif type(tokenizer) == str:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-
-        elif tokenizer is None and type(model) != str:
-            raise Exception('Tokenizer cannot be None if model is not a str. Please load a tokenizer first:\n' +
-                            'tokenizer = AutoTokenizer.from_pretrained(model_name)')
-
-        if type(model) == str:
-            model = AutoModelForCausalLM.from_pretrained(model)
-
-        self._device = self._get_device() if device == 'auto' else device
-        self._model = model.to(self._device)
-
+        self._model = model
         self._tokenizer = tokenizer
         self._dataset_path = dataset_path
         self._dataset_name = dataset_name
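With the string-handling and device-selection branches gone, the constructor expects a ready-made model and tokenizer and simply stores them. A minimal sketch of the new calling convention, assuming a small hypothetical checkpoint name:

from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.utils import Perplexity

model_name = "facebook/opt-125m"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# tokenizer is now required, and the device is taken from the model itself
ppl = Perplexity(model, tokenizer)
ppl.calculate_perplexity(512, 512)  # n_ctx, n_batch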
@@ -119,7 +103,7 @@ class Perplexity:
         """
         # Tokenize the text
         self._tokenizer.model_max_length = sys.maxsize
-        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._device)
+        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)

         nll = 0.0  # Negative log likelihood
         count = 0  # Counter for processed tokens
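This relies on the `device` attribute that transformers models expose, which reports where the model's parameters live, so the class no longer needs its own `_get_device` bookkeeping. A tiny sketch of the pattern, with a hypothetical model name:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical small model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Move the token ids to wherever the model's weights already live
tokens = tokenizer("hello world", return_tensors="pt").input_ids.to(model.device)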
@@ -173,7 +157,7 @@ class Perplexity:

            for j in range(num_batches):
                batch_start = start + j * n_batch
                batch_size = min(end - batch_start, n_batch)

                token_org = tokens[0][batch_start].item()

@@ -228,4 +212,4 @@ class Perplexity:
         # Compute the logits without keeping track of gradients
         with torch.no_grad():
             outputs = self._model(tokens[:, batch_start:batch_start+batch_size])
         return outputs.logits.detach()
@@ -1,6 +1,9 @@
 import os
 import argparse

+import torch
 from auto_gptq.utils import Perplexity
+from transformers import AutoTokenizer

 if __name__ == "__main__":
     """
@@ -24,32 +27,60 @@ if __name__ == "__main__":
     parser.add_argument("--model_basename", type=str, default=None, help="Model file's basename.")
     parser.add_argument("--n_ctx", type=int, default=512, help="Context size.")
     parser.add_argument("--n_batch", type=int, default=512, help="Batch size.")
-    parser.add_argument("--device", type=str, default="auto", help="Device to use.")
     parser.add_argument("--dataset_path", type=str, default='wikitext', help="Path to the dataset.")
     parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
     parser.add_argument("--split", type=str, default='test', help="Dataset split to use.")
     parser.add_argument("--text_column", type=str, default='text', help="Column in the dataset containing the text.")
-    parser.add_argument("--is_quantized", action=argparse.BooleanOptionalAction, default=False, help="Is the model GPTQ quantized?")
+    parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="Max memory used in each GPU.")
+    parser.add_argument("--cpu_max_memory", type=int, default=None, help="Max memory used in CPU.")
+    parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
+    parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
+    parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
+    parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
     args = parser.parse_args()

     os.environ["TOKENIZERS_PARALLELISM"] = "false"

+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    max_memory = dict()
+    if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
+        if torch.cuda.is_available():
+            max_memory.update(
+                {i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
+            )
+    if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
+        max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
+    if not max_memory:
+        max_memory = None
+
     if args.is_quantized:
-        from transformers import AutoTokenizer
         from auto_gptq import AutoGPTQForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
         model = AutoGPTQForCausalLM.from_quantized(
             args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
             model_basename=args.model_basename,
-            use_safetensors=True,
-            trust_remote_code=True
+            use_safetensors=args.use_safetensors,
+            trust_remote_code=args.trust_remote_code,
+            inject_fused_mlp=False,
+            inject_fused_attention=False
         )
     else:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from transformers import AutoModelForCausalLM

-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-        model = AutoModelForCausalLM.from_pretrained(args.model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            max_memory=max_memory,
+            torch_dtype=torch.float16,
+            trust_remote_code=args.trust_remote_code
+        )

-    ppl = Perplexity(model, tokenizer, args.device, args.dataset_path, args.dataset_name, args.split, args.text_column)
+    ppl = Perplexity(model, tokenizer, args.dataset_path, args.dataset_name, args.split, args.text_column)
     ppl.calculate_perplexity(args.n_ctx, args.n_batch)
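Taken together, the quantized branch of the updated script boils down to roughly the following; the checkpoint name and the per-GPU memory cap are hypothetical placeholders, not values from the commit:

import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils import Perplexity

model_name = "TheBloke/open-llama-7b-GPTQ"  # hypothetical quantized checkpoint
per_gpu_max_memory = 10                     # hypothetical cap, in GiB

tokenizer = AutoTokenizer.from_pretrained(model_name)
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

max_memory = None
if torch.cuda.is_available():
    max_memory = {i: f"{per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}

model = AutoGPTQForCausalLM.from_quantized(
    model_name,
    low_cpu_mem_usage=True,
    device_map="auto",
    max_memory=max_memory,
    use_safetensors=True,
    inject_fused_mlp=False,
    inject_fused_attention=False,
)

ppl = Perplexity(model, tokenizer)
ppl.calculate_perplexity(512, 512)

This mirrors the quantized path in the diff above, with the CLI flags replaced by inline values.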