# AutoGPTQ/auto_gptq/utils/data_utils.py

import copy
import random
from functools import partial
from typing import Callable, Dict, List, Optional

import torch
from datasets import load_dataset, DatasetDict, IterableDatasetDict
from torch import LongTensor
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


def make_data_block(
samples: Dict[str, List[str]],
prompt_col_name: str,
label_col_name: str,
tokenizer: PreTrainedTokenizer,
preprocess_fn: Optional[Callable] = None,
sample_max_len: int = 1024,
block_max_len: int = 2048,
add_eos_token: bool = False,
truncate_prompt: bool = True,
merge_prompt_label: bool = False
) -> Dict[str, List[List[List[int]]]]:
"""A simple implementation of text generation oriented smart batching to maximize VRAM usage when evaluation
:param samples: Dict[str, List[str]], samples that used to make data blocks
:param prompt_col_name: str, name of the key in samples whose value stores prompt
:param label_col_name: str, name of the key in samples whose value stores label
:param tokenizer: transformers.PretrainedTokenizer, tokenizer that used to tokenize samples
:param preprocess_fn: Optional[Callable], optional function that used to preprocess samples such as
refactor the data structure of samples, note the output of this function must be a dict whose keys
at least contains `prompt_col_name` and `label_col_name`
:param sample_max_len: int, defaults to 1024, max tokens number of each sample (before padding)
:param block_max_len: int, defaults to 2048, max tokens number of each data block (after padding)
:param add_eos_token: bool, defaults to False, whether add eos_token or not to the label
:param truncate_prompt: bool, defaults to True, whether to truncate prompt if the sample's total tokens
number exceeds `sample_max_len`, if not, will truncate label and drop this sample when all tokens
in label are truncated
:param merge_prompt_label: bool, defaults to False, will merge label into prompt if set to True, usually
this only required when doing language modeling task
:return: Dict[str, List[torch.LongTensor]], a dict whose keys are `input_ids`, `attention_mask` and
`label` and values are a list of torch.LongTensor
"""
if preprocess_fn:
samples = preprocess_fn(samples)
prompts = samples[prompt_col_name]
labels = samples[label_col_name]

    # tokenize prompts and labels without truncation; length is handled below
tokenized_prompts = tokenizer(prompts, truncation=False)["input_ids"]
tokenized_labels = tokenizer(labels, truncation=False)["input_ids"]

    # truncate samples whose total length exceeds `sample_max_len`, recording any
    # whose label is truncated away entirely so they can be dropped below
    dropped_indices = []
    for idx, (tokenized_prompt, tokenized_label) in enumerate(zip(tokenized_prompts, tokenized_labels)):
        if add_eos_token:
            tokenized_label += [tokenizer.eos_token_id]
        len_prompt = len(tokenized_prompt)
        len_label = len(tokenized_label)
        exceed_len = len_prompt + len_label - sample_max_len
        if exceed_len > 0:
            if truncate_prompt:
                # drop tokens from the left of the prompt
                tokenized_prompt = tokenized_prompt[exceed_len:]
            else:
                # drop tokens from the right of the label
                tokenized_label = tokenized_label[: -exceed_len]
        tokenized_prompts[idx] = tokenized_prompt
        tokenized_labels[idx] = tokenized_label
        if not tokenized_label:
            dropped_indices.append(idx)

    # sort the surviving samples by length so that samples of similar length end
    # up in the same block, minimizing padding waste
    tokenized_samples = sorted(
        [
            (p, l) for idx, (p, l) in enumerate(zip(tokenized_prompts, tokenized_labels))
            if idx not in dropped_indices
        ],
        key=lambda x: (len(x[0]) + len(x[1])) if merge_prompt_label else len(x[0])
    )
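    # greedily pack the sorted samples into blocks: a block is sealed once padding
    # every sample in it to the block's longest sample would exceed `block_max_len`
    # tokens in total. A worked example (illustrative numbers): with
    # block_max_len=8 and sorted sample lengths [2, 3, 3], the first two samples
    # fit in one block (2 rows padded to length 3 = 6 tokens), but adding the
    # third would cost 6 + 3 = 9 > 8 tokens, so it starts a new block.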
    sample_blocks = []
    sample_block = []
    blk_max_len = 0  # length of the longest sample in the current block
    blk_total_len = 0  # total tokens the current block will occupy after padding
    for tokenized_sample in tokenized_samples:
        prompt_ids, label_ids = tokenized_sample
        ori_sample_len = len(prompt_ids)
        if merge_prompt_label:
            ori_sample_len += len(label_ids)
        if ori_sample_len <= blk_max_len:
            # the sample will be padded up to the block's current max length
            additional_len = blk_max_len
            sample_len = blk_max_len
        else:
            # the sample raises the block's max length: every sample already in the
            # block gains extra padding, and the new sample itself is added on top
            additional_len = len(sample_block) * (ori_sample_len - blk_max_len) + ori_sample_len
            sample_len = ori_sample_len
        if blk_total_len + additional_len > block_max_len:
            # the current block is full: seal it and start a new one with this sample
            sample_blocks.append((copy.copy(sample_block), blk_max_len))
            sample_block = []
            blk_max_len = 0
            blk_total_len = 0
            sample_len = ori_sample_len
            additional_len = ori_sample_len
        sample_block.append(tokenized_sample)
        blk_max_len = max(blk_max_len, sample_len)
        blk_total_len += additional_len
    if sample_block:
        sample_blocks.append((copy.copy(sample_block), blk_max_len))
    del sample_block
    del blk_max_len
    del blk_total_len
new_samples = {
"input_ids": [],
"attention_mask": [],
"labels": []
}

    # pad each data block internally so every row in the block has the same length
for block, blk_max_len in sample_blocks:
input_ids = []
attention_mask = []
label_ids = []
label_max_len = max([len(sample[1]) for sample in block])
        for sample in block:
            tokenized_prompt, tokenized_label = sample
            sample_len = len(tokenized_prompt)
            if merge_prompt_label:
                sample_len += len(tokenized_label)
            pad_num = blk_max_len - sample_len
            if merge_prompt_label:
                # left-pad the merged sequence; mask out the pad and prompt positions
                # in labels with -100 so they are ignored by the loss
                input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt + tokenized_label)
                label_ids.append([-100] * (pad_num + len(tokenized_prompt)) + tokenized_label)
            else:
                # keep prompt and label separate; labels are left-padded with -100
                input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt)
                label_ids.append([-100] * (label_max_len - len(tokenized_label)) + tokenized_label)
            attention_mask.append([0] * pad_num + [1] * sample_len)
new_samples["input_ids"].append(input_ids)
new_samples["attention_mask"].append(attention_mask)
new_samples["labels"].append(label_ids)
return new_samples
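
# Usage sketch for `make_data_block` (illustrative only: the "gpt2" checkpoint
# and the toy samples below are assumptions, not part of this module):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token = tok.eos_token  # gpt2 has no pad token by default
#   blocks = make_data_block(
#       samples={"prompt": ["Q: 1+1=\nA:", "Q: 2+2=\nA:"], "label": [" 2", " 4"]},
#       prompt_col_name="prompt",
#       label_col_name="label",
#       tokenizer=tok,
#       block_max_len=32
#   )
#   # blocks["input_ids"][0] is one padded block: a 2-D list of token ids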


def collate_data(blocks: List[Dict[str, List[List[int]]]], pad_token_id: int) -> Dict[str, LongTensor]:
    def pad_block(block, pads):
        # prepend `pads` to `block` along the last dimension (left padding)
        return torch.cat((pads.to(block.device), block), dim=-1)

    input_ids_blocks = [LongTensor(block["input_ids"]) for block in blocks]
    attention_mask_blocks = [LongTensor(block["attention_mask"]) for block in blocks]
    label_blocks = [LongTensor(block["labels"]) for block in blocks]

    bsz = len(blocks)
    # left-pad every block up to the longest input / label length across all blocks
    inp_max_len = max([block.size(-1) for block in input_ids_blocks])
    label_max_len = max([block.size(-1) for block in label_blocks])
for i in range(bsz):
block_bsz, block_inp_len = input_ids_blocks[i].shape
block_label_len = label_blocks[i].shape[-1]
pad_num = inp_max_len - block_inp_len
if pad_num > 0:
input_ids_blocks[i] = pad_block(input_ids_blocks[i], torch.ones((block_bsz, pad_num)) * pad_token_id)
attention_mask_blocks[i] = pad_block(attention_mask_blocks[i], torch.zeros((block_bsz, pad_num)))
label_pad_num = label_max_len - block_label_len
if label_pad_num > 0:
label_blocks[i] = pad_block(label_blocks[i], torch.ones((block_bsz, label_pad_num)) * -100)
return {
"input_ids": torch.cat(input_ids_blocks, dim=0).long(),
"attention_mask": torch.cat(attention_mask_blocks, dim=0).long(),
"labels": torch.cat(label_blocks, dim=0).long()
}
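
# Sketch of how `collate_data` is typically used (shapes are illustrative; the
# `block_0` / `block_1` names are placeholders): it left-pads every block to the
# longest block in the batch and stacks them row-wise:
#
#   batch = collate_data([block_0, block_1], pad_token_id=tok.pad_token_id)
#   batch["input_ids"]       # LongTensor of shape (rows_0 + rows_1, inp_max_len)
#   batch["attention_mask"]  # 0 over padding positions, 1 elsewhere
#   batch["labels"]          # -100 over positions ignored by the loss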


def get_dataloader(
data_path_or_name: str,
prompt_col_name: str,
label_col_name: str,
tokenizer: PreTrainedTokenizer,
load_fn: Optional[Callable] = None,
preprocess_fn: Optional[Callable] = None,
num_samples: int = 128,
sample_max_len: int = 1024,
block_max_len: int = 2048,
add_eos_token: bool = False,
truncate_prompt: bool = True,
merge_prompt_label: bool = False,
load_fn_kwargs: Optional[dict] = None,
preprocess_fn_kwargs: Optional[dict] = None,
**kwargs
) -> DataLoader:
"""load dataset and build dataloader
:param data_path_or_name: str, dataset name in hf-hub or local file path
:param prompt_col_name: str, see `make_data_block`
:param label_col_name: str, see `make_data_block`
:param tokenizer: str, see `make_data_block`
:param load_fn: Optional[Callable], defaults to None, function used to load dataset, if not specified,
use `datasets.load_dataset`
:param preprocess_fn: Optional[Callable], see `make_data_block`
:param num_samples: int, defaults to 128, total samples used to evaluation
:param sample_max_len: int, see `make_data_block`
:param block_max_len: int, see `make_data_block`
:param add_eos_token: bool, see `make_data_block`
:param truncate_prompt: bool, see `make_data_block`
:param merge_prompt_label: bool, see `make_data_block`
:param load_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
for `load_fn` or `datasets.load_dataset`
:param preprocess_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
for `preprocess_fn`
:param kwargs: additional keyword arguments will be passed to torch's `DataLoader` initialization,
note values of `batch_size`, `shuffle` and `collate_fn` will always be overridden to fixed value
:return: torch.utils.data.DataLoader
"""
if not load_fn_kwargs:
load_fn_kwargs = dict()
if not preprocess_fn_kwargs:
preprocess_fn_kwargs = dict()
if load_fn:
ds = load_fn(data_path_or_name, **load_fn_kwargs)
else:
ds = load_dataset(data_path_or_name, **load_fn_kwargs)
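    # when the loaded object holds multiple splits, prefer evaluation-style splits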
if isinstance(ds, (DatasetDict, IterableDatasetDict)):
if "evaluation" in ds:
ds = ds["evaluation"]
elif "test" in ds:
ds = ds["test"]
else:
ds = ds["train"]
ds = ds.select(indices=random.sample(range(len(ds)), min(len(ds), num_samples)), keep_in_memory=True)
ds = ds.map(
make_data_block,
batched=True,
batch_size=len(ds),
num_proc=1,
remove_columns=ds.column_names,
keep_in_memory=True,
load_from_cache_file=False,
fn_kwargs={
"prompt_col_name": prompt_col_name,
"label_col_name": label_col_name,
"tokenizer": tokenizer,
"preprocess_fn": partial(preprocess_fn, **preprocess_fn_kwargs),
"sample_max_len": sample_max_len,
"block_max_len": block_max_len,
"add_eos_token": add_eos_token,
"truncate_prompt": truncate_prompt,
"merge_prompt_label": merge_prompt_label
}
)

    # always override these arguments, even if the user specified them in kwargs
kwargs["batch_size"] = 1
kwargs["shuffle"] = False
kwargs["collate_fn"] = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
dl = DataLoader(ds, **kwargs)
return dl


__all__ = ["make_data_block", "collate_data", "get_dataloader"]
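

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The "gpt2" checkpoint and the local
    # "data.json" file (with "prompt" and "label" fields) are assumptions here,
    # not part of this module; point them at your own tokenizer and data.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 has no pad token by default
    dl = get_dataloader(
        "json",  # let `datasets.load_dataset` read a local JSON file
        prompt_col_name="prompt",
        label_col_name="label",
        tokenizer=tok,
        num_samples=8,
        block_max_len=512,
        load_fn_kwargs={"data_files": "data.json", "split": "train"}
    )
    for batch in dl:
        print(batch["input_ids"].shape, batch["labels"].shape)
        break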