# AutoGPTQ/auto_gptq/utils/perplexity_utils.py
import sys

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


class Perplexity:
    """
    A class for calculating the perplexity of a language model.
    """
    def __init__(self, model, tokenizer, dataset_path='wikitext', dataset_name=None, split='test', text_column='text'):
        """
        Calculate perplexity using the same method as seen in llama.cpp.

        See the hedged usage sketch in the ``__main__`` block at the bottom
        of this file.

        Parameters
        ----------
        model : AutoModelForCausalLM
            The language model for which the perplexity is calculated.
        tokenizer : AutoTokenizer
            The tokenizer corresponding to the model.
        dataset_path : str, optional
            The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
        dataset_name : str, optional
            The name of the dataset. Default is None.
        split : str, optional
            The split of the dataset to use. Default is 'test'.
        text_column : str, optional
            The name of the column in the dataset that contains the text data. Default is 'text'.
        """
        self._model = model
        self._tokenizer = tokenizer
        self._dataset_path = dataset_path
        self._dataset_name = dataset_name
        self._split = split
        self._text_column = text_column
        self._text = self._prepare_data()
    def _get_device(self):
        # Note: this helper is not called anywhere in this class; the model's
        # own device is used directly in calculate_perplexity().
        if torch.backends.mps.is_available():
            return 'mps'
        elif torch.cuda.is_available():
            return 'cuda:0'
        else:
            return 'cpu'
    def _prepare_data(self):
        """
        Prepares the dataset by loading and formatting.

        Returns
        -------
        str
            The formatted dataset as a single string.
        """
        if self._dataset_path == 'wikitext':
            self._dataset_name = 'wikitext-2-raw-v1'

        # Load the dataset
        data = load_dataset(self._dataset_path, self._dataset_name, split=self._split)
        # Format the text column of the dataset
        text_list = [' \n' if s == '' else s for s in data[self._text_column]]
        return ''.join(text_list)
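
    # Worked example for _prepare_data above (illustrative values, not taken
    # from the dataset): a text column of ["Intro", "", "Body"] is mapped to
    # ["Intro", " \n", "Body"] and joined into the single string
    # "Intro \nBody", so blank rows survive as paragraph breaks.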
    @staticmethod
    def softmax(logits):
        """
        Static method for applying the softmax function.

        Parameters
        ----------
        logits : np.ndarray
            The input to the softmax function.

        Returns
        -------
        np.ndarray
            The output of the softmax function.
        """
        e_x = np.exp(logits - np.max(logits))
        return e_x / e_x.sum(axis=0)
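
    # Worked example for softmax above (values rounded, purely illustrative):
    #   logits = [1.0, 2.0, 3.0]  ->  shifted = [-2.0, -1.0, 0.0]
    #   e_x    = [0.135, 0.368, 1.0]
    #   probs  = e_x / 1.503 = [0.090, 0.245, 0.665]
    # Subtracting max(logits) first avoids overflow in np.exp without changing
    # the result, since softmax is invariant to a constant shift of its input.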
    def calculate_perplexity(self, n_ctx=512, n_batch=512):
        """
        Calculates the perplexity of the language model.

        Parameters
        ----------
        n_ctx : int
            The context size.
        n_batch : int
            The batch size.

        Returns
        -------
        list
            The list of running perplexity scores, one per context chunk.
        """
        # Tokenize the text
        self._tokenizer.model_max_length = sys.maxsize
        tokens = self._tokenizer(self._text, truncation=False, return_tensors='pt').input_ids.to(self._model.device)

        nll = 0.0  # Negative log likelihood
        count = 0  # Counter for processed tokens
        curr_ppl = 0
        all_perplexity = []

        with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress:
            for i in progress:
                # Process each chunk of n_ctx tokens
                nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)

                # Calculate and display the current perplexity
                curr_ppl = np.exp(nll / count)
                all_perplexity.append(curr_ppl)
                progress.set_description(f"Perplexity: {curr_ppl:.4f}")

        return all_perplexity
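
    # The running score above is the standard perplexity definition,
    #   PPL = exp( -(1/N) * sum_{i=1..N} log p(x_i | x_{<i}) ),
    # where N is `count` and the accumulated sum is `nll`. A minimal
    # restatement of the per-token update (names here are illustrative):
    #   nll += -np.log(p_true_next)   # probability of the actual next token
    #   count += 1
    #   ppl = np.exp(nll / count)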
    def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
        """
        Processes each batch of tokens.

        Parameters
        ----------
        i : int
            The batch index.
        n_ctx : int
            The context size.
        n_batch : int
            The batch size.
        tokens : torch.Tensor
            The tokenized text.
        nll : float
            The current negative log likelihood.
        count : int
            The current count of processed tokens.

        Returns
        -------
        float
            The updated negative log likelihood.
        int
            The updated count of processed tokens.
        """
        start = i * n_ctx
        end = start + n_ctx

        num_batches = (n_ctx + n_batch - 1) // n_batch  # ceil(n_ctx / n_batch)

        logits = []

        for j in range(num_batches):
            batch_start = start + j * n_batch
            batch_size = min(end - batch_start, n_batch)

            token_org = tokens[0][batch_start].item()

            if j == 0:
                # Replace the first token with the BOS token
                tokens[0][batch_start] = self._tokenizer.bos_token_id

            # Compute the logits for the current batch of tokens
            batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)

            # Restore the original token
            tokens[0][batch_start] = token_org

            logits.append(batch_logits)

        # We rely on the fact that attention in the forward pass only looks at
        # previous tokens here, so the logits returned for each position are an
        # accurate representation of what the model would have predicted at
        # that point.
        #
        # For example, with a context window of 512 we compute perplexity for
        # each of the last 256 tokens; the first half of the window serves
        # purely as context. The input is split into context-window-sized
        # chunks so the entire text is processed.
        #
        # Note: only logits[0] (the first sub-batch) is indexed below, which
        # assumes n_batch >= n_ctx so every scored position falls inside it.
        for j in range(min(512, n_ctx // 2), n_ctx - 1):
            tok_logits = logits[0][0][j].cpu().numpy()
            # Compute the probability of the true next token
            prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]

            # Update the negative log likelihood and the count of processed
            # tokens; the `where` guard avoids evaluating log(0).
            nll += -np.log(prob, where=prob > 0)
            count += 1

        return nll, count
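
    # Worked example for _process_batch above (illustrative numbers): with
    # n_ctx=512 and n_batch=512, chunk i covers token positions
    # [512*i, 512*i + 512). The scoring loop runs j = 256..510, so positions
    # 256..510 each predict targets 257..511 within the chunk; the first half
    # of each chunk only supplies context and is never scored.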
    def _compute_batch_logits(self, tokens, batch_start, batch_size):
        """
        Computes the logits for a batch of tokens.

        Parameters
        ----------
        tokens : torch.Tensor
            The tokenized text.
        batch_start : int
            The start index of the batch.
        batch_size : int
            The size of the batch.

        Returns
        -------
        torch.Tensor
            The logits for the batch of tokens.
        """
        # Compute the logits without keeping track of gradients
        with torch.no_grad():
            outputs = self._model(tokens[:, batch_start:batch_start + batch_size])
        return outputs.logits.detach()
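

# A minimal usage sketch, assuming a machine that can load a small causal LM
# locally. The model id below is an illustrative choice for a quick smoke
# test, not something this module prescribes.
if __name__ == "__main__":
    model_id = "facebook/opt-125m"  # hypothetical example model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).to(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    # Defaults load the wikitext-2-raw-v1 test split.
    ppl = Perplexity(model, tokenizer)
    scores = ppl.calculate_perplexity(n_ctx=512, n_batch=512)
    print(f"final perplexity: {scores[-1]:.4f}")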