Module pysimt.mainloop

Training main loop.

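In a typical setup this loop is built and run by pysimt's training entry point; the sketch below only illustrates that pattern. build_model, get_device_manager and load_train_opts are hypothetical stand-ins for the experiment/configuration machinery, and only MainLoop itself comes from this module.

from pysimt.mainloop import MainLoop

model = build_model(config)            # a pysimt model instance (hypothetical helper)
dev_mgr = get_device_manager('gpu')    # exposes .dev, .req_cpu and .cuda_dev_ids (hypothetical helper)
train_opts = load_train_opts(config)   # dict of training options (hypothetical helper, see the class below)

loop = MainLoop(model, train_opts, dev_mgr)
loop()   # trains until early stopping, max_epochs or max_iterations is reached
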
Source code
"""Training main loop."""

import time
import logging

import torch

from .evaluator import Evaluator
from .optimizer import Optimizer
from .monitor import Monitor
from .utils.nn import get_module_groups
from .utils.misc import load_pt_file
from .utils.ml_metrics import Loss
from .utils.data import make_dataloader
from .utils.tensorboard import TensorBoard
from .translators import get_translator

logger = logging.getLogger('pysimt')


class MainLoop:
    def __init__(self, model, train_opts, dev_mgr):
        # Get all training options into this mainloop
        self.__dict__.update(train_opts)

        self.print = logger.info
        self.model = model
        self.dev_mgr = dev_mgr
        self.epoch_valid = (self.eval_freq == 0)
        self.oom_count = 0
        self.loss_meter = Loss()
        self._found_optim_state = None

        # Load training and validation data & create iterators
        self.print('Loading dataset(s)')
        self.train_iterator = make_dataloader(
            self.model.load_data('train', self.batch_size),
            self.pin_memory, self.num_workers)

        # Create monitor for validation, evaluation, checkpointing stuff
        self.monitor = Monitor(self.save_path / self.subfolder, self.exp_id,
                               self.model, logger, self.patience,
                               self.eval_metrics,
                               save_best_metrics=self.save_best_metrics,
                               n_checkpoints=self.n_checkpoints)

        # If a validation set exists
        if 'val_set' in self.model.opts.data and self.eval_freq >= 0:
            if 'LOSS' in self.monitor.eval_metrics:
                self.vloss_iterator = make_dataloader(
                    self.model.load_data('val', self.batch_size, mode='eval'))

            if self.monitor.beam_metrics is not None:
                self.beam_iterator = make_dataloader(
                    self.model.load_data('val', self.eval_batch_size, mode='beam'))
                # Create hypothesis evaluator
                self.evaluator = Evaluator(
                    self.model.val_refs, self.monitor.beam_metrics,
                    filters=self.eval_filters)
                self.translator = get_translator(self.model.opts.model['translator_type'])(
                    self.model, self.beam_iterator, out_prefix='',
                    batch_size=self.eval_batch_size,
                    max_len=self.eval_max_len)

        # Setup model
        self.model.setup()
        self.model.reset_parameters()

        ################################################
        # Initialize model weights with a pretrained one
        # This should come after model.setup()
        ################################################
        if train_opts['pretrained_file']:
            # Relax the strict condition for partial initialization
            data = load_pt_file(train_opts['pretrained_file'])
            weights = data['model']
            self._found_optim_state = data.get('optimizer', None)
            if train_opts['pretrained_layers']:
                prefixes = tuple(train_opts['pretrained_layers'].split(','))
                keys = [w for w in weights if w.startswith(prefixes)]
                weights = {k: weights[k] for k in keys}

            for name in get_module_groups(weights.keys()):
                self.print(
                    ' -> will initialize {}.* with pretrained weights.'.format(name))
            model.load_state_dict(weights, strict=False)

        ############################
        # Freeze layers if requested
        ############################
        if train_opts['freeze_layers']:
            frozen = []
            for layer in train_opts['freeze_layers'].split(','):
                for name, param in self.model.named_parameters():
                    if name.startswith(layer):
                        param.requires_grad = False
                        frozen.append(name)

            for name in frozen:
                self.print(f' -> froze parameter {name}')

        self.print(self.model)
        self.model = self.model.to(self.dev_mgr.dev)

        if self.dev_mgr.req_cpu or len(self.dev_mgr.cuda_dev_ids) == 1:
            self.net = self.model
        else:
            self.net = torch.nn.DataParallel(
                self.model, device_ids=self.dev_mgr.cuda_dev_ids, dim=1)

        # Create optimizer instance
        self.optim = Optimizer(
            self.optimizer, self.model, lr=self.lr, momentum=self.momentum,
            nesterov=self.nesterov, weight_decay=self.l2_reg,
            gclip=self.gclip, lr_decay=self.lr_decay,
            lr_decay_factor=self.lr_decay_factor,
            lr_decay_mode=self.monitor.lr_decay_mode,
            lr_decay_min=self.lr_decay_min,
            lr_decay_patience=self.lr_decay_patience,
            tf_model_dim=self.tf_model_dim,
            lr_warmup_steps=self.lr_warmup_steps,
            adam_betas=self.adam_betas,
        )
        self.print(self.optim)

        if self._found_optim_state:
            # NOTE: This restores the lr and weight_decay values from the
            # checkpoint, overriding those given in the new config file!
            self.optim.load_state_dict(self._found_optim_state)

        if self.save_optim_state:
            self.monitor.set_optimizer(self.optim)

        # Create TensorBoard logger if possible and requested
        self.tboard = TensorBoard(self.model, self.tensorboard_dir,
                                  self.exp_id, self.subfolder)
        self.print(self.tboard)

        # Models can also use tensorboard for custom purposes
        self.model.register_tensorboard(self.tboard)

    def train_batch(self, batch):
        """Trains a batch."""
        nn_start = time.time()

        # Reset gradients
        self.optim.zero_grad()

        # Forward pass with training progress
        # NOTE: Problematic for multi-gpu
        out = self.net(batch, uctr=self.monitor.uctr, ectr=self.monitor.ectr)
        self.loss_meter.update(out['loss'], out['n_items'])
        loss = out['loss'] / out['n_items']

        # Add other losses if any
        if self.net.aux_loss:
            loss += sum(list(self.net.aux_loss.values()))

        # Backward pass
        loss.backward()

        # Update parameters (includes gradient clipping logic)
        self.optim.step()

        return time.time() - nn_start

    def train_epoch(self):
        """Trains a full epoch."""
        self.print('Starting Epoch {}'.format(self.monitor.ectr))

        nn_sec = 0.0
        eval_sec = 0.0
        total_sec = time.time()
        self.loss_meter.reset()
        self.oom_count = 0

        for batch in self.train_iterator:
            batch.device(self.dev_mgr.dev)
            self.monitor.uctr += 1

            try:
                nn_sec += self.train_batch(batch)
            except RuntimeError as e:
                if self.handle_oom and 'out of memory' in e.args[0]:
                    torch.cuda.empty_cache()
                    self.oom_count += 1
                else:
                    raise e

            if self.monitor.uctr % self.disp_freq == 0:
                # Send statistics
                self.tboard.log_scalar(
                    'train_LOSS', self.loss_meter.batch_loss, self.monitor.uctr)

                msg = "Epoch {} - update {:10d} => loss: {:>7.3f}".format(
                    self.monitor.ectr, self.monitor.uctr,
                    self.loss_meter.batch_loss)
                for key, value in self.net.aux_loss.items():
                    val = value.item()
                    msg += ' [{}: {:.3f}]'.format(key, val)
                    self.tboard.log_scalar('train_' + key.upper(), val, self.monitor.uctr)
                msg += ' (#OOM: {})'.format(self.oom_count)
                self.print(msg)

            # Do validation?
            if (not self.epoch_valid and
                    self.monitor.ectr >= self.eval_start and
                    self.eval_freq > 0 and
                    self.monitor.uctr % self.eval_freq == 0):
                eval_start = time.time()
                self.do_validation()
                eval_sec += time.time() - eval_start

            if (self.checkpoint_freq and self.n_checkpoints > 0 and
                    self.monitor.uctr % self.checkpoint_freq == 0):
                self.print('Saving checkpoint...')
                self.monitor.save_checkpoint()

            # Check stopping conditions
            if self.monitor.early_bad == self.monitor.patience:
                self.print("Early stopped.")
                return False

            if self.monitor.uctr == self.max_iterations:
                self.print("Max iterations {} reached.".format(
                    self.max_iterations))
                return False

        # All time spent for this epoch
        total_min = (time.time() - total_sec) / 60
        # All time spent during forward/backward/step
        nn_min = nn_sec / 60
        # All time spent during validation(s)
        eval_min = eval_sec / 60
        # Rest is iteration overhead + checkpoint saving
        overhead_min = total_min - nn_min - eval_min

        # Compute epoch loss
        epoch_loss = self.loss_meter.get()
        self.monitor.train_loss.append(epoch_loss)

        self.print("--> Epoch {} finished with mean loss {:.5f}".format(
            self.monitor.ectr, epoch_loss))
        self.print("--> Overhead/Training/Evaluation: {:.2f}/{:.2f}/{:.2f} "
                   "mins (total: {:.2f} mins)   ({} samples/sec)".format(
                       overhead_min, nn_min, eval_min, total_min,
                       int(len(self.train_iterator.dataset) / nn_sec)))

        # Do validation?
        if self.epoch_valid and self.monitor.ectr >= self.eval_start:
            self.do_validation()

        # Check whether maximum epoch is reached
        if self.monitor.ectr == self.max_epochs:
            self.print("Max epochs {} reached.".format(self.max_epochs))
            return False

        self.monitor.ectr += 1
        return True

    def do_validation(self):
        """Do early-stopping validation."""
        results = []
        self.monitor.vctr += 1
        self.net.train(False)
        torch.set_grad_enabled(False)

        # Collect simple validation stats first
        self.print('Computing evaluation loss...')
        results.extend(self.net.test_performance(self.vloss_iterator))

        if self.monitor.beam_metrics:
            tr_args = self.model.opts.model.get('translator_args', {})
            self.print(f'Performing greedy search (args: {tr_args})')
            beam_time = time.time()
            # Use greedy search
            hyps, *_ = self.translator.run(**tr_args)
            beam_time = time.time() - beam_time

            # Compute metrics and update results
            score_time = time.time()
            results.extend(self.evaluator.score(hyps))
            score_time = time.time() - score_time

        # Log metrics to tensorboard
        self.tboard.log_metrics(results, self.monitor.uctr, suffix='val_')

        # Add new scores to history
        self.monitor.update_scores(results)

        # Do a scheduler LR step
        lr_change = self.optim.lr_step(self.monitor.get_last_eval_score())
        if lr_change and self.lr_decay_revert:
            self.print('Reloading previous best model parameters')
            self.monitor.reload_previous_best()

        # Check early-stop criteria and save snapshots if any
        self.monitor.save_models()

        # Dump summary and switch back to training mode
        self.monitor.val_summary()
        self.net.train(True)
        torch.set_grad_enabled(True)

    def __call__(self):
        """Runs training loop."""
        self.print('Training started on %s' % time.strftime('%d-%m-%Y %H:%M:%S'))
        self.net.train(True)
        torch.set_grad_enabled(True)

        # Evaluate once before even starting training
        if self.eval_zero:
            self.do_validation()

        while self.train_epoch():
            pass

        if self.monitor.vctr > 0:
            self.monitor.val_summary()
        else:
            # No validation done, save final model
            self.print('Saving final model.')
            self.monitor.save_model(suffix='final')

        self.print('Training finished on %s' % time.strftime('%d-%m-%Y %H:%M'))
        # Close tensorboard
        self.tboard.close()

Classes

class MainLoop (model, train_opts, dev_mgr)
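
The constructor copies every key of train_opts onto the instance via self.__dict__.update(train_opts), so the dictionary must supply all options that the loop later reads as attributes. The sketch below lists those keys as they appear in the source; the values are illustrative assumptions only, and the validation-related keys are detailed further under do_validation() below.

from pathlib import Path

train_opts = {
    # data and iteration
    'batch_size': 64, 'num_workers': 4, 'pin_memory': True,
    'max_epochs': 100, 'max_iterations': 1000000,
    'disp_freq': 30, 'handle_oom': False,
    # optimization (forwarded to Optimizer)
    'optimizer': 'adam', 'lr': 0.0004, 'adam_betas': (0.9, 0.98),
    'momentum': 0.0, 'nesterov': False, 'l2_reg': 0.0, 'gclip': 1.0,
    'lr_decay': False, 'lr_decay_factor': 0.5, 'lr_decay_min': 1e-6,
    'lr_decay_patience': 2, 'lr_decay_revert': False,
    'lr_warmup_steps': 4000, 'tf_model_dim': 512,
    # bookkeeping (forwarded to Monitor and TensorBoard)
    'save_path': Path('experiments'), 'subfolder': 'runs', 'exp_id': 'exp1',
    'patience': 10, 'save_best_metrics': True,
    'checkpoint_freq': 5000, 'n_checkpoints': 0, 'save_optim_state': False,
    'tensorboard_dir': '',
    # weight transfer and freezing
    'pretrained_file': '', 'pretrained_layers': '', 'freeze_layers': '',
    # validation (semantics detailed under do_validation below)
    'eval_freq': 0, 'eval_start': 1, 'eval_zero': False,
    'eval_metrics': 'loss', 'eval_batch_size': 32,
    'eval_filters': '', 'eval_max_len': 100,
}
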
Methods

def do_validation(self)

Runs one round of early-stopping validation: computes the validation loss, optionally decodes the validation set with the configured translator and scores the hypotheses with the beam metrics, logs the results to TensorBoard, updates the score history, steps the LR scheduler, checks the early-stopping criteria and saves the best model(s).
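
Whether and how often validation runs is controlled by the eval-related training options; the snippet below restates that wiring from the constructor and train_epoch(). The option names come from the source, while the example values and the metric string format are assumptions.

train_opts.update({
    'eval_freq': 1000,       # 0: validate once per epoch, N > 0: every N updates, < 0: never
    'eval_start': 1,         # first epoch at which validation may run
    'eval_zero': False,      # run one validation before training starts
    'eval_metrics': 'loss',  # 'LOSS' builds the loss iterator; beam metrics build the translator and Evaluator
    'eval_batch_size': 32,   # batch size used when decoding hypotheses
    'eval_max_len': 100,     # maximum hypothesis length during decoding
    'eval_filters': '',      # filters passed to the Evaluator before scoring
    'lr_decay_revert': False,  # reload the previous best weights after an LR decay step
})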

def train_batch(self, batch)

Trains a single batch: zeroes the gradients, runs the forward pass, normalizes the loss by n_items, adds any auxiliary losses, backpropagates and takes an optimizer step. Returns the elapsed time in seconds.

def train_epoch(self)

Trains a full epoch, periodically logging statistics, running validation and saving checkpoints as configured. Returns False when a stopping condition (early stopping, max_iterations or max_epochs) is reached, True otherwise.
