47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
import math
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from torch import LongTensor
|
|
|
|
from ._base import BaseTask
|
|
|
|
|
|
class LanguageModelingTask(BaseTask):
    """Evaluation task that reports perplexity derived from the model's loss.

    For every batch the model is called with the batch as keyword arguments
    and the scalar ``loss`` attribute of its output is collected. The final
    metric is ``exp(mean(collected losses))``, reported under the key
    ``"ppl"``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs
    ):
        """Forward every argument to ``BaseTask`` with prompt/label merging on.

        Args:
            model: Forwarded to ``BaseTask``. ``_predict`` later calls
                ``self.model(**batch_data)`` and reads ``.loss`` from the
                result (presumably ``BaseTask`` stores it as ``self.model``
                -- verify against ``BaseTask``).
            tokenizer: Forwarded to ``BaseTask`` unchanged.
            data_name_or_path: Forwarded to ``BaseTask`` unchanged.
            prompt_col_name: Forwarded to ``BaseTask`` unchanged.
            label_col_name: Forwarded to ``BaseTask`` unchanged.
            device: Forwarded to ``BaseTask`` unchanged; defaults to ``None``.
            **kwargs: Extra keyword arguments forwarded to ``BaseTask``.
                ``merge_prompt_label`` is always overwritten with ``True``,
                so any value supplied by the caller is ignored.
        """
        # Forced unconditionally: this task always runs with the flag set,
        # regardless of what the caller passed.
        kwargs["merge_prompt_label"] = True
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs
        )

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[float]:
        """Run the model on one batch and return its loss in a one-item list.

        Args:
            batch_data: Mapping unpacked as keyword arguments into the model
                call. Presumably it must contain whatever the model needs to
                compute a loss -- verify against the data pipeline.
            *args: Accepted but unused.
            **kwargs: Accepted but unused.

        Returns:
            A list holding a single Python float: the loss of this batch.
        """
        outputs = self.model(**batch_data)
        loss = outputs.loss.cpu().item()

        return [loss]

    def _parse_labels(self, label_ids: LongTensor) -> List[Any]:
        """Return an empty list; ``label_ids`` is not used by this task."""
        return []

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, float]:
        """Turn the collected per-batch losses into a perplexity value.

        Args:
            pred: Losses gathered from ``_predict``, one entry per batch.
            label: Unused.

        Returns:
            ``{"ppl": exp(mean(pred))}``. When ``pred`` is empty the value is
            NaN (the mean is undefined); when the exponential overflows a
            float the value is ``inf``.
        """
        if not pred:
            # No batches were evaluated: report NaN instead of failing with
            # ZeroDivisionError on the mean.
            return {"ppl": float("nan")}

        mean_loss = sum(pred) / len(pred)
        try:
            ppl = math.exp(mean_loss)
        except OverflowError:
            # math.exp raises once the result no longer fits in a float
            # (mean loss above roughly 709.78); report that as infinity.
            ppl = float("inf")
        return {"ppl": ppl}

    def run(self) -> Dict[str, float]:
        """Delegate to ``BaseTask.run`` without extra arguments."""
        return super().run()
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["LanguageModelingTask"]
|