Module pysimt.monitor
Training progress monitor.
"""Training progress monitor."""
from collections import defaultdict
import torch
from .utils.io import FileRotator
from .utils.misc import load_pt_file
from .metrics import beam_metrics, metric_info
class Monitor:
"""Class that tracks training progress. The following informations are
kept as object attributes:
self.ectr: # of epochs done so far
self.uctr: # of updates, i.e. mini-batches done so far
self.vctr: # of evaluations done on val_set so far
self.early_bad: # of consecutive evaluations where the model did not improve
self.train_loss: List of training losses
self.val_scores: Dict of lists keeping tracking of validation metrics
"""
# Variables to save
VARS = ['uctr', 'ectr', 'vctr', 'early_bad', 'train_loss', 'val_scores']
def __init__(self, save_path, exp_id,
model, logger, patience, eval_metrics, history=None,
save_best_metrics=False, n_checkpoints=0):
self.print = logger.info
self.save_path = save_path
self.exp_id = exp_id
self.model = model
self.patience = patience
self.eval_metrics = [e.strip() for e in eval_metrics.upper().split(',')]
self.save_best_metrics = save_best_metrics
self.optimizer = None
self.checkpoints = FileRotator(n_checkpoints)
self.beam_metrics = None
if history is None:
history = {}
self.uctr = history.pop('uctr', 0)
self.ectr = history.pop('ectr', 1)
self.vctr = history.pop('vctr', 0)
self.early_bad = history.pop('early_bad', 0)
self.train_loss = history.pop('train_loss', [])
self.val_scores = history.pop('val_scores', defaultdict(list))
if len(self.eval_metrics) > 0:
# To keep current best metric validation id and score
self.cur_bests = {}
# First metric is considered to be early-stopping metric
self.early_metric = self.eval_metrics[0]
# Will be used by optimizer
self.lr_decay_mode = metric_info[self.early_metric]
# Get metrics requiring beam_search
bms = set(self.eval_metrics).intersection(beam_metrics)
if len(bms) > 0:
self.beam_metrics = list(bms)
@staticmethod
def best_score(scores):
"""Returns the best validation id and score for that."""
idx, score = sorted(enumerate(scores), key=lambda e: e[1],
reverse=scores[0].higher_better)[0]
return (idx + 1, score)
def set_optimizer(self, optimizer):
"""Sets the optimizer to save its parameters."""
self.optimizer = optimizer
def state_dict(self):
"""Returns a dictionary of stateful variables."""
return {k: getattr(self, k) for k in self.VARS}
def val_summary(self):
"""Prints a summary of validation results."""
self.print('--> This is model: {}'.format(self.exp_id))
for name, (vctr, score) in self.cur_bests.items():
self.print('--> Best {} so far: {:.2f} @ validation {}'.format(
name, score.score, vctr))
def save_checkpoint(self):
"""Saves a checkpoint by keeping track of file rotation."""
self.checkpoints.push(
self.save_model(suffix='update{}'.format(self.uctr)))
def reload_previous_best(self):
"""Reloads the parameters from the previous best checkpoint."""
fname = self.save_path / "{}.best.{}.ckpt".format(
self.exp_id, self.early_metric.lower())
data = load_pt_file(fname)
self.model.load_state_dict(data['model'], strict=True)
def save_model(self, metric=None, suffix='', do_symlink=False):
"""Saves a checkpoint with arbitrary suffix(es) appended."""
# Construct file name
fname = self.exp_id
if metric:
self.print('Saving best model based on {}'.format(metric.name))
fname += "-val{:03d}.best.{}_{:.3f}".format(
self.vctr, metric.name.lower(), metric.score)
if suffix:
fname += "-{}".format(suffix)
fname = self.save_path / (fname + ".ckpt")
# Save the file
model_dict = {
'opts': self.model.opts.to_dict(),
'model': self.model.state_dict(),
'history': self.state_dict(),
}
# Add optimizer states
if self.optimizer is not None:
model_dict['optimizer'] = self.optimizer.state_dict()
torch.save(model_dict, fname)
# Also create a symbolic link to the above checkpoint for the metric
if metric and do_symlink:
symlink = "{}.best.{}.ckpt".format(self.exp_id, metric.name.lower())
symlink = self.save_path / symlink
if symlink.exists():
old_ckpt = symlink.resolve()
symlink.unlink()
old_ckpt.unlink()
symlink.symlink_to(fname.name)
return fname
def update_scores(self, results):
"""Updates score lists and current bests."""
for metric in results:
self.print('Validation {} -> {}'.format(self.vctr, metric))
self.val_scores[metric.name].append(metric)
self.cur_bests[metric.name] = self.best_score(
self.val_scores[metric.name])
def get_last_eval_score(self):
return self.cur_bests[self.early_metric][-1].score
def save_models(self):
cur_bests = self.cur_bests.copy()
# Let's start with early-stopping metric
vctr, metric = cur_bests.pop(self.early_metric)
if vctr == self.vctr:
self.early_bad = 0
self.save_model(metric=metric, do_symlink=True)
else:
# Increment counter
self.early_bad += 1
# If requested, save all best metric snapshots
if self.save_best_metrics and cur_bests:
for (vctr, metric) in cur_bests.values():
if metric.name in self.eval_metrics and vctr == self.vctr:
self.save_model(metric=metric, do_symlink=True)
self.print('Early stopping patience: {}'.format(
self.patience - self.early_bad))
Classes
class Monitor (save_path, exp_id, model, logger, patience, eval_metrics, history=None, save_best_metrics=False, n_checkpoints=0)
-
Class that tracks training progress. The following information is kept as object attributes:
self.ectr: # of epochs done so far
self.uctr: # of updates, i.e. mini-batches done so far
self.vctr: # of evaluations done on val_set so far
self.early_bad: # of consecutive evaluations where the model did not improve
self.train_loss: List of training losses
self.val_scores: Dict of lists keeping track of validation metrics
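A minimal usage sketch, assuming pysimt is installed and that 'BLEU' is a metric name registered in pysimt.metrics.metric_info. FakeMetric and StubModel below are hypothetical stand-ins that only expose the attributes Monitor actually touches (name/score/higher_better/ordering for metrics; state_dict and opts.to_dict for the model); in real training these objects come from pysimt itself.

import logging
import tempfile
from dataclasses import dataclass, field
from pathlib import Path

import torch

from pysimt.monitor import Monitor


# Hypothetical stand-in for pysimt's metric objects.
@dataclass(order=True)
class FakeMetric:
    score: float
    name: str = field(default='BLEU', compare=False)
    higher_better: bool = field(default=True, compare=False)

    def __str__(self):
        return '{} = {:.2f}'.format(self.name, self.score)


# Hypothetical stub model: save_model() only calls state_dict() and opts.to_dict().
class StubModel(torch.nn.Linear):
    class _Opts:
        def to_dict(self):
            return {}

    def __init__(self):
        super().__init__(2, 2)
        self.opts = self._Opts()


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('train')

mon = Monitor(
    save_path=Path(tempfile.mkdtemp()),  # hypothetical checkpoint directory
    exp_id='demo-nmt',                   # hypothetical experiment id
    model=StubModel(),
    logger=logger,
    patience=10,
    eval_metrics='bleu',                 # first metric drives early stopping
    save_best_metrics=True,
    n_checkpoints=3,
)

# One simulated validation step, mirroring what a training loop would do.
mon.vctr += 1
mon.update_scores([FakeMetric(31.2)])   # record scores, refresh cur_bests
mon.save_models()                       # save best checkpoint, reset early_bad
mon.val_summary()                       # log best-so-far values
print(mon.get_last_eval_score())        # -> 31.2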
Class variables
var VARS
Static methods
def best_score(scores)
-
Returns the best validation id and its score.
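A small illustration of the 1-based validation id this returns. The Score dataclass below is a hypothetical stand-in, since best_score only needs ordering between scores and a higher_better flag.

from dataclasses import dataclass, field

from pysimt.monitor import Monitor


# Hypothetical stand-in for a metric score object.
@dataclass(order=True)
class Score:
    score: float
    higher_better: bool = field(default=True, compare=False)


history = [Score(20.1), Score(23.4), Score(22.8)]
vid, best = Monitor.best_score(history)
print(vid, best.score)  # -> 2 23.4 (validation ids are 1-based)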
Methods
def get_last_eval_score(self)
-
def reload_previous_best(self)
-
Reloads the parameters from the previous best checkpoint.
def save_checkpoint(self)
-
Saves a checkpoint by keeping track of file rotation.
def save_model(self, metric=None, suffix='', do_symlink=False)
-
Saves a checkpoint with arbitrary suffix(es) appended.
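The dict written by torch.save here has the keys 'opts', 'model', and 'history', plus 'optimizer' when one was registered via set_optimizer. A small inspection sketch; the checkpoint path is hypothetical and follows the "<exp_id>-val<NNN>.best.<metric>_<score>.ckpt" naming pattern built in this method.

import torch

# Hypothetical checkpoint path written by save_model().
ckpt = torch.load('/tmp/ckpts/demo-nmt-val001.best.bleu_31.200.ckpt',
                  map_location='cpu')
print(list(ckpt))                # ['opts', 'model', 'history'] (+ 'optimizer' if set)
print(ckpt['history']['vctr'])   # number of validations done when this was saved
# Passing ckpt['history'] back as Monitor(..., history=...) restores these
# counters, which is how resuming from a checkpoint is expected to work.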
def save_models(self)
-
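How the patience counter evolves over repeated validations, continuing the simulated setup from the class-level sketch above (mon and FakeMetric are the hypothetical objects defined there).

# A second, worse validation does not produce a new best checkpoint,
# so the early-stopping counter grows instead of resetting.
mon.vctr += 1
mon.update_scores([FakeMetric(29.7)])  # worse than the earlier 31.2
mon.save_models()                      # no checkpoint saved; early_bad -> 1

if mon.early_bad == mon.patience:
    logger.info('Patience exhausted, stop training.')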
def set_optimizer(self, optimizer)
-
Sets the optimizer to save its parameters.
def state_dict(self)
-
Returns a dictionary of stateful variables.
def update_scores(self, results)
-
Updates score lists and current bests.
def val_summary(self)
-
Prints a summary of validation results.