1
0
Fork 0
mirror of https://github.com/deepfakes/faceswap synced 2025-06-07 10:43:27 -04:00
faceswap/plugins/model/Model_Original/AutoEncoder.py
torzdf 7f53911453
Logging (#541)
* Convert prints to logger. Further logging improvements. Tidy up

* Fix system verbosity. Allow SystemExit

* Fix reload extract bug

* Child Traceback handling

* Safer shutdown procedure

* Add shutdown event to queue manager

* landmarks_as_xy > property. GUI notes + linting. Aligner bugfix

* fix FaceFilter. Enable nFilter when no Filter is supplied

* Fix blurry face filter

* Continue on IO error. Better error handling

* Explicitly print stack trace to crash log

* Windows Multiprocessing bugfix

* Add git info and conda version to crash log

* Windows/Anaconda mp bugfix

* Logging fixes for training
2018-12-04 13:31:49 +00:00

77 lines
2.8 KiB
Python

# AutoEncoder base classes
import logging
from lib.utils import backup_file
from lib import Serializer
from json import JSONDecodeError
# File names used when persisting the model, keyed by role. The 'state'
# entry is a base name only: the code that uses it appends the serializer's
# extension before opening the file.
hdf = dict(
    encoderH5='encoder.h5',
    decoder_AH5='decoder_A.h5',
    decoder_BH5='decoder_B.h5',
    state='state',
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
class AutoEncoder:
    """Base class holding one encoder and two decoders (A and B).

    This class calls ``self.Encoder()``, ``self.Decoder()`` and
    ``self.initModel()`` from ``__init__`` but does not define them, so they
    must be supplied by a subclass.
    """

    def __init__(self, model_dir, gpus):
        """Store the configuration and build the networks.

        Parameters
        ----------
        model_dir:
            Directory holding the weight and state files. It is combined with
            file names using the ``/`` operator, so it must support that
            (presumably a ``pathlib.Path`` -- confirm against callers).
        gpus:
            Stored on the instance as ``self.gpus``; not otherwise used in
            this class.
        """
        self.model_dir = model_dir
        self.gpus = gpus

        # Give the epoch counter a defined value up front so that `epoch_no`
        # and `save_weights` work even when `load` has not been called.
        self._epoch_no = 0

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        """Restore the training state and the model weights from ``model_dir``.

        The epoch counter is read from the JSON state file. If that file
        cannot be read, cannot be decoded, or has no usable ``epoch_no``
        entry, a warning is logged and the counter is reset to 0.

        Parameters
        ----------
        swapped:
            When truthy, decoder A is loaded from the decoder B weight file
            and vice versa.

        Returns
        -------
        bool
            ``True`` if all three weight files loaded, ``False`` otherwise.
        """
        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        try:
            with open(str(self.model_dir / state_fn), 'rb') as fp:
                state = serializer.unmarshal(fp.read().decode('utf-8'))
            self._epoch_no = state['epoch_no']
        except IOError as e:
            logger.warning('Error loading training info: %s', str(e.strerror))
            self._epoch_no = 0
        except JSONDecodeError as e:
            logger.warning('Error loading training info: %s', str(e.msg))
            self._epoch_no = 0
        except (KeyError, TypeError, UnicodeDecodeError) as e:
            # The file was readable but is not valid UTF-8, or is valid JSON
            # that does not provide an 'epoch_no' entry. Treat it like any
            # other unusable state file instead of crashing.
            logger.warning('Error loading training info: %s', str(e))
            self._epoch_no = 0

        if swapped:
            face_A, face_B = hdf['decoder_BH5'], hdf['decoder_AH5']
        else:
            face_A, face_B = hdf['decoder_AH5'], hdf['decoder_BH5']

        try:
            self.encoder.load_weights(str(self.model_dir / hdf['encoderH5']))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            logger.info('loaded model weights')
            return True
        except Exception as e:  # pylint: disable=broad-except
            # Deliberately best-effort: any failure means "start fresh".
            logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir)
            # Keep the underlying cause available for troubleshooting.
            logger.debug('Reason for weight loading failure: %s', str(e))
            return False

    def save_weights(self):
        """Back up the existing files, then write the weights and the state.

        An ``IOError`` raised while writing the state file is logged and not
        re-raised.
        """
        model_dir = str(self.model_dir)
        # NOTE(review): this passes the bare 'state' name to backup_file,
        # whereas the state file written below is 'state.<ext>'. Whether the
        # state file is actually backed up depends on backup_file -- confirm.
        for model in hdf.values():
            backup_file(model_dir, model)

        self.encoder.save_weights(str(self.model_dir / hdf['encoderH5']))
        self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5']))
        self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5']))
        logger.info('saved model weights')

        serializer = Serializer.get_serializer('json')
        state_fn = ".".join([hdf['state'], serializer.ext])
        state_path = str(self.model_dir / state_fn)
        try:
            with open(state_path, 'wb') as fp:
                state_json = serializer.marshal({
                    'epoch_no': self.epoch_no
                })
                fp.write(state_json.encode('utf-8'))
        except IOError as e:
            logger.error(e.strerror)

    @property
    def epoch_no(self):
        """Get current training epoch number"""
        return self._epoch_no