empty cache before switching models
parent 1367677c45
commit a8e748c511
3 changed files with 9 additions and 5 deletions
examples/evaluation/run_language_modeling_task.py
@@ -1,9 +1,10 @@
 import datasets
 from argparse import ArgumentParser
-from transformers import AutoTokenizer
 
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import LanguageModelingTask
+from transformers import AutoTokenizer
 
 
 DATASET = "tatsu-lab/alpaca"
@@ -63,6 +64,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
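This diff and the two below apply the same release-then-reload pattern: move the fp16 model to CPU, delete the last reference, then call torch.cuda.empty_cache() before from_quantized allocates the quantized weights on cuda:0. A minimal standalone sketch of why the new empty_cache() call matters (the Linear layer is a hypothetical stand-in for the fp16 model; only torch and a CUDA device are assumed):

import torch

# Stand-in for the fp16 model that was used during quantization.
model = torch.nn.Linear(8192, 8192).cuda()
print(f"reserved before release: {torch.cuda.memory_reserved() / 2**20:.0f} MiB")

model.cpu()   # move the weights back to host memory
del model     # drop the last Python reference so the storage can be freed

# Freed blocks stay cached inside PyTorch's CUDA caching allocator;
# empty_cache() returns them to the driver instead of holding peak-sized
# reservations across the model switch.
torch.cuda.empty_cache()
print(f"reserved after empty_cache: {torch.cuda.memory_reserved() / 2**20:.0f} MiB")

Note that cpu() plus del only releases the tensors into the caching allocator; empty_cache() is what hands the reserved blocks back so the subsequent from_quantized(..., device="cuda:0") load sees that memory as free.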
examples/evaluation/run_sequence_classification_task.py
@@ -2,10 +2,10 @@ from argparse import ArgumentParser
 from functools import partial
 
 import datasets
-from transformers import AutoTokenizer
-
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import SequenceClassificationTask
+from transformers import AutoTokenizer
 
 
 DATASET = "cardiffnlp/tweet_sentiment_multilingual"
@@ -67,6 +67,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
examples/evaluation/run_text_summarization_task.py
@@ -2,10 +2,10 @@ import os
 from argparse import ArgumentParser
 
 import datasets
-from transformers import AutoTokenizer, GenerationConfig
-
+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import TextSummarizationTask
+from transformers import AutoTokenizer, GenerationConfig
 
 
 os.system("pip install py7zr")
@@ -61,6 +61,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()
 
     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
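Since the block is repeated verbatim in each script, the change could also be expressed once as a helper. A sketch, assuming task.model still holds the fp16 model at the point of the swap (the helper name is hypothetical; the from_quantized arguments are the ones the scripts already use):

import torch
from auto_gptq import AutoGPTQForCausalLM


def swap_in_quantized(task, quantized_model_dir, use_triton=False):
    # Release the fp16 model first so its VRAM is available for the load.
    model = task.model
    task.model = None
    model.cpu()               # move weights to host memory
    del model                 # drop the last reference
    torch.cuda.empty_cache()  # hand cached CUDA blocks back to the driver

    # Same call the example scripts use to load the quantized checkpoint.
    task.model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir, device="cuda:0", use_triton=use_triton
    )
    return task.model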