Empty cache before switching model

This commit is contained in:
PanQiWei 2023-04-26 15:22:30 +08:00
parent 1367677c45
commit a8e748c511
3 changed files with 9 additions and 5 deletions

View file

@ -1,9 +1,10 @@
import datasets
from argparse import ArgumentParser
from transformers import AutoTokenizer
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import LanguageModelingTask
from transformers import AutoTokenizer
DATASET = "tatsu-lab/alpaca"
@ -63,6 +64,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model

View file

@ -2,10 +2,10 @@ from argparse import ArgumentParser
from functools import partial
import datasets
from transformers import AutoTokenizer
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import SequenceClassificationTask
from transformers import AutoTokenizer
DATASET = "cardiffnlp/tweet_sentiment_multilingual"
@ -67,6 +67,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model

View file

@ -2,10 +2,10 @@ import os
from argparse import ArgumentParser
import datasets
from transformers import AutoTokenizer, GenerationConfig
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import TextSummarizationTask
from transformers import AutoTokenizer, GenerationConfig
os.system("pip install py7zr")
@ -61,6 +61,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model