Training

pytorch-igniter constructs a training engine so you can focus on machine learning.

Features

  • Create a model and write functions that train and evaluate on a single batch. pytorch-igniter constructs the training engine that runs training and evaluation, logs metrics, and checkpoints your model.
  • Standardized and documented argparse command-line arguments like --batch-size, --max-epochs, and --learning-rate. Only write custom arguments that are unique to your script.
  • Save the model on Ctrl-C or kill. Automatically resume from the latest checkpoint. Checkpointing is configurable.
  • Simplify defining metrics. Metrics can average or otherwise accumulate data and can be saved, printed, and more depending on configuration.
  • Integrate with MLflow for tracking training runs, including hyperparameters and metrics.
  • Integrate with AWS SageMaker using aws-sagemaker-remote for tracking training runs and executing training remotely on managed containers.

Basic Usage

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from pytorch_igniter import train, RunSpec
from pytorch_igniter.args import train_kwargs
from pytorch_igniter.main import igniter_main


def main(
    args
):

    # Create data loaders
    train_loader = DataLoader(...)
    eval_loader = DataLoader(...)

    # Create model, optimizer, and criteria
    model = nn.Sequential(...)
    optimizer = torch.optim.Adam(...)
    criteria = nn.CrossEntropyLoss(...)

    # Single step of training
    def train_step(engine, batch):
        # Do training
        model.train()
        model.zero_grad()
        inputs, labels = batch
        outputs = model(inputs)
        loss = criteria(input=outputs, target=labels)
        loss.backward()
        optimizer.step()
        return {
            "loss": loss
        }

    # Single step of evaluation
    def eval_step(engine, batch):
        # Do evaluation without tracking gradients
        model.eval()
        inputs, labels = batch
        with torch.no_grad():
            outputs = model(inputs)
            loss = criteria(input=outputs, target=labels)
        return {
            "loss": loss
        }

    # Metrics average the outputs of the step functions and are printed and saved to logs
    metrics = {
        'loss': 'loss'
    }

    # Objects to save
    to_save = {
        "model": model,
        "optimizer": optimizer
    }

    train(
        to_save=to_save,
        # Training setup
        train_spec=RunSpec(
            step=train_step,
            loader=train_loader,
            metrics=metrics
        ),
        # Evaluation setup
        eval_spec=RunSpec(
            step=eval_step,
            loader=eval_loader,
            metrics=metrics
        ),
        **train_kwargs(args),
        parameters=vars(args)
    )


if __name__ == "__main__":
    igniter_main(
        main=main,
        inputs={
            'data': 'data'
        },
        # ...
    )
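
The comment in the example above says that metrics average the outputs of the step functions. As a rough illustration of that idea only, and not pytorch-igniter's actual implementation, the sketch below keeps a running mean of the value stored under one output key (such as "loss") across batches. The class name MeanOfKey and its methods are hypothetical.

```python
class MeanOfKey:
    """Hypothetical helper: average one key of the step output across batches."""

    def __init__(self, key):
        self.key = key
        self.total = 0.0
        self.count = 0

    def update(self, output):
        # `output` is the dict returned by a step function, e.g. {"loss": 0.5}
        self.total += float(output[self.key])
        self.count += 1

    def compute(self):
        if self.count == 0:
            raise ValueError("no batches seen yet")
        return self.total / self.count


metric = MeanOfKey("loss")
for step_output in ({"loss": 1.0}, {"loss": 0.5}, {"loss": 0.0}):
    metric.update(step_output)
average_loss = metric.compute()
print(average_loss)  # 0.5
```

Because float() also accepts a zero-dimensional tensor, the same accumulation would work on the loss tensors returned by train_step and eval_step.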

Command-Line Arguments

In addition to the arguments listed below, command-line arguments are generated for each item in the inputs and dependencies function arguments.
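
A minimal sketch of how a mapping such as inputs={'data': 'data'} could be turned into command-line flags, using only the standard argparse module. It assumes that each key becomes a flag of the same name whose default is the mapped value; the real generated names, types, and help text may differ. The helper add_input_arguments is hypothetical and is not part of pytorch-igniter.

```python
import argparse


def add_input_arguments(parser, inputs):
    """Hypothetical helper: add one --<name> flag per entry in `inputs`."""
    for name, default in inputs.items():
        flag = "--" + name.replace("_", "-")
        parser.add_argument(
            flag,
            dest=name,
            default=default,
            help="Input channel '{}' (default: {})".format(name, default),
        )
    return parser


parser = argparse.ArgumentParser()
add_input_arguments(parser, {"data": "data"})

default_args = parser.parse_args([])
custom_args = parser.parse_args(["--data", "/tmp/mnist"])
print(default_args.data)  # data
print(custom_args.data)   # /tmp/mnist
```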

usage: pytorch-igniter [-h] [--max-epochs N] [--n-saved N_SAVED]
                       [--save-event SAVE_EVENT] [--train-pbar [TRAIN_PBAR]]
                       [--train-print-event TRAIN_PRINT_EVENT]
                       [--train-log-event TRAIN_LOG_EVENT]
                       [--sagemaker-profile SAGEMAKER_PROFILE]
                       [--sagemaker-run [SAGEMAKER_RUN]]
                       [--sagemaker-wait [SAGEMAKER_WAIT]]
                       [--sagemaker-spot-instances [SAGEMAKER_SPOT_INSTANCES]]
                       [--sagemaker-script SAGEMAKER_SCRIPT]
                       [--sagemaker-source SAGEMAKER_SOURCE]
                       [--sagemaker-training-instance SAGEMAKER_TRAINING_INSTANCE]
                       [--sagemaker-training-image SAGEMAKER_TRAINING_IMAGE]
                       [--sagemaker-training-image-path SAGEMAKER_TRAINING_IMAGE_PATH]
                       [--sagemaker-training-image-accounts SAGEMAKER_TRAINING_IMAGE_ACCOUNTS]
                       [--sagemaker-training-role SAGEMAKER_TRAINING_ROLE]
                       [--sagemaker-base-job-name SAGEMAKER_BASE_JOB_NAME]
                       [--sagemaker-job-name SAGEMAKER_JOB_NAME]
                       [--sagemaker-experiment-name SAGEMAKER_EXPERIMENT_NAME]
                       [--sagemaker-trial-name SAGEMAKER_TRIAL_NAME]
                       [--sagemaker-volume-size SAGEMAKER_VOLUME_SIZE]
                       [--sagemaker-max-run SAGEMAKER_MAX_RUN]
                       [--sagemaker-max-wait SAGEMAKER_MAX_WAIT]
                       [--sagemaker-output-json SAGEMAKER_OUTPUT_JSON]
                       [--model-dir MODEL_DIR] [--output-dir OUTPUT_DIR]
                       [--checkpoint-dir CHECKPOINT_DIR]
                       [--sagemaker-checkpoint-s3 SAGEMAKER_CHECKPOINT_S3]
                       [--sagemaker-checkpoint-container SAGEMAKER_CHECKPOINT_CONTAINER]

Named Arguments

--max-epochs

Number of epochs to train (default: 10)

Default: 10

--n-saved

Number of checkpoints to keep (default: 10)

Default: 10
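
Keeping a fixed number of checkpoints amounts to deleting the oldest files once the limit is exceeded. The sketch below shows that idea in plain Python. It assumes files named checkpoint_<epoch>.pt, which is an illustrative naming scheme and not necessarily the one pytorch-igniter uses; the helper prune_checkpoints is hypothetical.

```python
import os
import tempfile


def prune_checkpoints(directory, n_saved):
    """Hypothetical helper: keep only the n_saved highest-numbered checkpoints."""
    def epoch_of(filename):
        # "checkpoint_12.pt" -> 12
        return int(filename[len("checkpoint_"):-len(".pt")])

    names = sorted(
        (f for f in os.listdir(directory)
         if f.startswith("checkpoint_") and f.endswith(".pt")),
        key=epoch_of,
    )
    stale = names[:-n_saved] if n_saved > 0 else names
    for name in stale:
        os.remove(os.path.join(directory, name))
    return names[len(stale):]


with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(1, 6):
        open(os.path.join(tmp, "checkpoint_{}.pt".format(epoch)), "w").close()
    kept = prune_checkpoints(tmp, n_saved=2)

print(kept)  # ['checkpoint_4.pt', 'checkpoint_5.pt']
```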

--save-event

Event on which to save a checkpoint

Default: “EPOCH_COMPLETED”

--train-pbar

Enable train progress bar

Default: True

--train-print-event

Event on which to print training metrics

--train-log-event

Event on which to log training metrics

--model-dir

Directory to save final model (default: output/model)

Default: “output/model”

--output-dir

Directory for logs, images, or other output files (default: “output/output”)

Default: “output/output”

SageMaker

SageMaker options

--sagemaker-profile
 

AWS profile for SageMaker session (default: [default])

Default: “default”

--sagemaker-run
 

Run training on SageMaker (yes/no default=False)

Default: False
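
The usage line shows this option as [--sagemaker-run [SAGEMAKER_RUN]], that is, a flag with an optional yes/no value. One common way to build such a flag with the standard argparse module is sketched below. It assumes that a bare flag means "yes" and that the listed spellings are accepted; the converter str_to_bool is hypothetical and the real parser may behave differently.

```python
import argparse


def str_to_bool(value):
    """Hypothetical converter for yes/no style flag values."""
    text = str(value).strip().lower()
    if text in ("yes", "y", "true", "t", "1"):
        return True
    if text in ("no", "n", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("expected yes or no, got {!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument(
    "--sagemaker-run",
    type=str_to_bool,
    nargs="?",      # the value after the flag is optional
    const=True,     # bare flag is treated as "yes" (assumption)
    default=False,  # flag omitted
)

omitted = parser.parse_args([]).sagemaker_run
bare = parser.parse_args(["--sagemaker-run"]).sagemaker_run
explicit_no = parser.parse_args(["--sagemaker-run", "no"]).sagemaker_run
print(omitted, bare, explicit_no)  # False True False
```

The same pattern would apply to the other options shown with an optional value, such as --sagemaker-wait and --sagemaker-spot-instances.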

--sagemaker-wait
 

Wait for SageMaker training to complete and tail log files (yes/no default=True)

Default: True

--sagemaker-spot-instances
 

Use spot instances for training (yes/no default=False)

Default: False

--sagemaker-script
 

Script to run on SageMaker. (default: [script.py])

Default: “script.py”

--sagemaker-source
 

Source to upload to SageMaker. Must contain the script. If blank, defaults to the directory containing the script. (default: [])

Default: “”

--sagemaker-training-instance
 

Instance type for training

Default: “ml.m5.large”

--sagemaker-training-image
 

Docker image for training

Default: “aws-sagemaker-remote-training:latest”

--sagemaker-training-image-path
 

Path to Dockerfile if image does not exist

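Default: the ecr/training directory inside the installed aws_sagemaker_remote package. On the machine that built this documentation it resolved to “/home/docs/checkouts/readthedocs.org/user_builds/pytorch-igniter/envs/stable/lib/python3.7/site-packages/aws_sagemaker_remote/ecr/training”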

--sagemaker-training-image-accounts
 

Accounts for docker build

Default: [‘763104351884’]

--sagemaker-training-role
 

IAM role for training

Default: “aws-sagemaker-remote-training-role”

--sagemaker-base-job-name
 

Base job name for tracking and organization on S3. A job name will be generated from the base job name unless a job name is specified.

Default: “training-job”

--sagemaker-job-name
 

Job name for tracking. Use --sagemaker-base-job-name instead and a job name will be automatically generated with a timestamp.

Default: “”
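
The relationship between the two job-name options can be sketched as follows: an explicit job name wins, otherwise a name is built from the base job name plus a timestamp. The timestamp format and the helper make_job_name are assumptions for illustration; the actual generated names may look different.

```python
from datetime import datetime, timezone


def make_job_name(base_job_name, job_name="", now=None):
    """Hypothetical helper: prefer an explicit job name, else base name + timestamp."""
    if job_name:
        return job_name
    now = now or datetime.now(timezone.utc)
    # Timestamp layout is an assumption, chosen to be sortable and hyphen-only
    return "{}-{}".format(base_job_name, now.strftime("%Y-%m-%d-%H-%M-%S"))


fixed_time = datetime(2020, 1, 2, 3, 4, 5, tzinfo=timezone.utc)
generated = make_job_name("training-job", now=fixed_time)
explicit = make_job_name("training-job", job_name="my-job", now=fixed_time)
print(generated)  # training-job-2020-01-02-03-04-05
print(explicit)   # my-job
```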

--sagemaker-experiment-name

Name of experiment in SageMaker tracking.

--sagemaker-trial-name

Name of experiment trial in SageMaker tracking.

--sagemaker-volume-size
 

Volume size in GB.

Default: 30

--sagemaker-max-run
 

Maximum runtime in seconds.

Default: 43200

--sagemaker-max-wait
 

Maximum time to wait for spot instances in seconds.

Default: 86400

--sagemaker-output-json

Output job details to JSON file.

Checkpoints

Checkpointing options

--checkpoint-dir
 

Local directory to store checkpoints for resuming training (default: “output/checkpoint”)

Default: “output/checkpoint”
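
Resuming from a checkpoint directory requires first locating the most recent checkpoint. A minimal sketch in plain Python is shown below; it picks the newest file by modification time and returns None when nothing has been saved yet. The helper latest_checkpoint is hypothetical, and pytorch-igniter may select checkpoints by a different rule.

```python
import os
import tempfile


def latest_checkpoint(checkpoint_dir):
    """Hypothetical helper: return the newest file in checkpoint_dir, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    paths = [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir)]
    files = [p for p in paths if os.path.isfile(p)]
    return max(files, key=os.path.getmtime) if files else None


with tempfile.TemporaryDirectory() as tmp:
    missing = latest_checkpoint(os.path.join(tmp, "does-not-exist"))
    for index, name in enumerate(["a.pt", "b.pt"]):
        path = os.path.join(tmp, name)
        open(path, "w").close()
        # Set explicit modification times so the ordering is deterministic
        os.utime(path, (1000 + index, 1000 + index))
    newest = os.path.basename(latest_checkpoint(tmp))

print(missing, newest)  # None b.pt
```

A training script could then load the returned path if it is not None and start from scratch otherwise.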

--sagemaker-checkpoint-s3
 

Location to store checkpoints on S3 or “default” (default: “default”)

Default: “default”

--sagemaker-checkpoint-container
 

Location to store checkpoints on container (default: “/opt/ml/checkpoints”)

Default: “/opt/ml/checkpoints”

See the aws-sagemaker-remote documentation for details on the SageMaker options.