empty cache before switch model
parent 1367677c45
commit a8e748c511
3 changed files with 9 additions and 5 deletions

@@ -1,9 +1,10 @@
 import datasets
 from argparse import ArgumentParser
-from transformers import AutoTokenizer
 
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import LanguageModelingTask
+from transformers import AutoTokenizer
 
 
 DATASET = "tatsu-lab/alpaca"
@@ -63,6 +64,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
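
All three evaluation scripts receive the same fix, repeated in the diffs below: drop every reference to the fp16 model, move it to CPU, delete it, and empty the CUDA cache before loading the quantized replacement. A minimal sketch of that pattern as a standalone helper (the name swap_to_quantized and the quantized_dir parameter are illustrative assumptions, not code from this commit):

import torch
from auto_gptq import AutoGPTQForCausalLM


# Hypothetical helper, not part of this commit: it just factors out the
# hunk shown above. `task` is any eval task object with a .model attribute.
def swap_to_quantized(task, model, quantized_dir, use_triton=False):
    task.model = None         # drop the task's reference to the fp16 model
    model.cpu()               # move the weights out of VRAM
    del model                 # release the last Python reference
    torch.cuda.empty_cache()  # hand the cached blocks back to the driver

    model = AutoGPTQForCausalLM.from_quantized(
        quantized_dir, device="cuda:0", use_triton=use_triton
    )
    task.model = model
    return model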

@@ -2,10 +2,10 @@ from argparse import ArgumentParser
 from functools import partial
 
 import datasets
-from transformers import AutoTokenizer
-
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import SequenceClassificationTask
+from transformers import AutoTokenizer
 
 
 DATASET = "cardiffnlp/tweet_sentiment_multilingual"
@@ -67,6 +67,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model

@@ -2,10 +2,10 @@ import os
 from argparse import ArgumentParser
 
 import datasets
-from transformers import AutoTokenizer, GenerationConfig
-
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import TextSummarizationTask
+from transformers import AutoTokenizer, GenerationConfig
 
 
 os.system("pip install py7zr")
@@ -61,6 +61,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
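
For context on why the added call is needed: PyTorch's caching allocator keeps freed blocks reserved for later reuse, so after del model the tensors are gone but the VRAM is not returned to the driver; torch.cuda.empty_cache() releases those cached blocks, reducing the chance of fragmentation-related OOM when the second model is loaded. A small self-contained demonstration of that distinction (the 1 GiB tensor size is arbitrary):

import torch

x = torch.empty(1024, 1024, 256, device="cuda:0")  # ~1 GiB of fp32
del x
print(torch.cuda.memory_allocated())  # 0: the tensor is gone
print(torch.cuda.memory_reserved())   # ~1 GiB: still cached by the allocator
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved())   # ~0: returned to the driver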