Framework Migration Guide

https://chainer.github.io/migration-guide/

Authors: Chainer Team

General Information

Concepts and components in both frameworks

Array Library

Core Framework and Training Loop

Migration scenarios

I want to port my Chainer script to PyTorch, step by step

I want to let my Chainer code train a PyTorch model

Migration tools (cpm)

cpm.TorchModule

cpm.ChainerParameter

cpm.LinkAsTorchModel

cpm.ignite.add_trainer_extension

pytorch-pfn-extras (ppe)

ppe.cuda.use_torch_mempool_in_cupy

Porting Guides

Dataset and data pre/post-processing

Negative strides

Rewriting Custom Converter Functions to collate_fn

NumPy bridge

CuPy bridge

Difference between PyTorch and NumPy/CuPy

Division Behavior

Feature Mapping

Training loop

Evaluation loop

Training and evaluation using Ignite

Using Chainer extensions with Ignite

Snapshots

Porting custom updater using Ignite

Rewriting existing Chainer model

Functions and Links

Functions

Links

Configuration

Hooks

Function Hooks

Link Hooks

Optimizer Hooks

Training PyTorch model using Chainer

Distributed training

PyTorch model using torch.distributed

Invocation

Initialization

Dataset scattering

Data transfer to devices

Optimizer wrapping

Initial values broadcast

Metrics average and reductions


Synchronization

PyTorch model using Horovod

Horovod initialization

Dataset scattering

Optimizer wrapping

Initial values broadcast

Metrics average and reductions

Horovod code structure

Obtaining Horovod traces to measure performance

Tuning Horovod performance

Using Horovod with apex

Multi-Node Batch Normalization in Horovod

Gathering arbitrary objects using Horovod and mpi4py

Alternatives to Horovod

Chainer model using Horovod

PyTorch model using ChainerMN

Porting code that edits the computational graph

Unchaining nodes

Backprop modes

Train/Test modes

Ecosystem

PyTorch

Ignite

torchvision

torchtext

torchaudio

Fairseq

Other

This document provides technical information for migration from Chainer to PyTorch.

General Information

Concepts and components in both frameworks

Array Library

Chainer uses NumPy/CuPy (xp.ndarray) as an array library, and wraps them as chainer.Variable to support autograd. Similarly, PyTorch uses ATen (at::Tensor (C++)) as an array library ("tensor library" in PyTorch terms), and wraps it as torch::Tensor (C++ API) / torch.Tensor (Python API) to support autograd. torch.* provides API similar to (but not compatible with) NumPy, e.g. torch.dot, torch.float32, etc.
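For example (a minimal illustration; both calls below are standard NumPy/PyTorch APIs), even the name of the reduction axis argument differs:

import numpy as np
import torch

a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)

np.sum(a, axis=1)    # NumPy reduces along `axis`
torch.sum(t, dim=1)  # PyTorch uses `dim` for the same concept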

Core Framework and Training Loop

As both frameworks share the same concept, define-by-run, the look-and-feel of code written in PyTorch is pretty similar to Chainer. Here is the high-level mapping of features:

| Chainer | PyTorch | Notes |
| --- | --- | --- |
| Variable (chainer.Variable) | Tensor (torch.Tensor) | |
| Function (chainer.FunctionNode, chainer.functions.*) | Function (torch.autograd.Function, torch.nn.functional.*) | `torch.*` also provides NumPy-like (but not compatible) operations. |
| Link / Chain (chainer.{Link, Chain}, chainer.links.*) | Module (torch.nn.Module, torch.nn.*) | |
| Sequential (chainer.Sequential) | Sequential (torch.nn.Sequential) | You can use function modules as members (e.g., torch.nn.ReLU instead of torch.nn.functional.relu). |
| Dataset (chainer.dataset.DatasetMixin, chainer.datasets.*) | Dataset (torch.utils.data.Dataset) | There is no TransformDataset in PyTorch (there is one in CPM as cpm.TransformDataset); datasets conventionally accept a `transforms` argument that performs per-example preprocessing. |
| Iterator (chainer.iterators.*) | DataLoader (torch.utils.data.DataLoader) | Unlike Chainer's Iterator, DataLoader automatically collates all samples into one Tensor by default; use collate_fn to customize this behavior. DataLoader itself supports multi-process iteration (via the num_workers option). |
| Optimizer (chainer.Optimizer, chainer.optimizers.*) | Optimizer (torch.optim.Optimizer, torch.optim.*) | |
| Trainer (chainer.training.Trainer) | Engine (ignite.Engine, ignite.engine.create_supervised_trainer()) | |
| Updater with converter (chainer.training.Updater) | | As noted above, DataLoader collates examples by default. Transfer to the device is handled by the Engine (or by custom loop code if you don't use Ignite). |
| Evaluator (chainer.training.extensions.Evaluator) | ignite.engine.create_supervised_evaluator() | |
| Extension (chainer.training.Extension, chainer.training.extensions.*) | Handler (ignite.handlers.*, ignite.contrib.handlers.*) | |

Refer to the Porting Guide section for the details of the difference of each component.

Migration scenarios

I want to port my Chainer script to PyTorch, step by step

Arguably the model is the hardest part to port without affecting the outcome of the training.

It might be easier to port in this order:

  1. Training script (optimizer / updater / evaluator / ...)
  2. Dataset / preprocessing
  3. Model

I want to let my Chainer code train a PyTorch model

You can use cpm.TorchModule to wrap a PyTorch module as a Chainer model.

Migration tools (CPM)

The chainer-pytorch-migration Python module (referred to as CPM, or "cpm" as the module name, in this document) provides various utilities that help migration from Chainer to PyTorch.

Example code assumes that cpm is imported as follows:

import chainer_pytorch_migration as cpm

import chainer_pytorch_migration.ignite

cpm.TorchModule

This class wraps a PyTorch module as a Chainer link. It allows training PyTorch models in Chainer training scripts. The graph (forward/backward) must be constructed and traversed in PyTorch.

model = torchvision.models.resnet50()

model.cuda()

w_model = cpm.TorchModule(model)

w_model.to_gpu(device) # Just synchronizes the metadata, does not transfer data

cpm.ChainerParameter

This class wraps a Chainer parameter as a PyTorch parameter. It allows training of Chainer models (chainer.Link) in PyTorch training scripts (with torch.optim.Optimizer). The graph (forward/backward) must be constructed and traversed in Chainer. cpm.LinkAsTorchModel internally uses it.

# initialized parameter

arr = numpy.full(shape, 17, 'float32')

chainer_param = chainer.Parameter(arr)

torch_param = cpm.ChainerParameter(chainer_param)

cpm.LinkAsTorchModel

This class automatically creates cpm.ChainerParameter objects for all parameters of a given Chainer link and provides methods such as parameters(), named_parameters() or state_dict() required by PyTorch optimizers or tools such as Horovod.

model = ChainerModel()

model.to_device(ch_device)

# Initialize parameters before converting to `ChainerParameter`s.

model(ch_device.xp.zeros((1, 784)).astype('f'))

# Convert parameters to `ChainerParameter`s to share memory with PyTorch.

torched_model = cpm.LinkAsTorchModel(model)

optimizer = optim.SGD(torched_model.parameters(), lr=args.lr)

cpm.ignite.add_trainer_extension

This function registers a Chainer trainer extension to be used with Ignite.

The call takes the Ignite trainer, the Torch optimizer and the Chainer extension as parameters.

optimizer.target = model

trainer.out = 'path to store extension results'

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ExponentialShift('lr', 0.9, 1.0, 0.1))

pytorch-pfn-extras (ppe)

The pytorch-pfn-extras Python module (called PPE or "ppe" (module name) in this document) provides various supplementary components for PyTorch, including APIs similar to Chainer, e.g. Extensions, Reporter, and Lazy modules (which automatically infer shapes of parameters). Refer to the documentation for the full list of features.

PPE also provides the interoperability feature between CuPy and PyTorch memory pool.

ppe.cuda.use_torch_mempool_in_cupy

This function makes CuPy use a memory pool from PyTorch. You need to call it before any operations using CuPy.

# Enable using PyTorch memory allocator in CuPy.

ppe.cuda.use_torch_mempool_in_cupy()

# Revert back to CuPy's default memory pool.

ppe.cuda.use_default_mempool_in_cupy()

Note: This feature was originally implemented in CPM as cpm.use_torch_in_cupy_malloc, but has been moved to PPE. The CPM version is deprecated and no longer recommended.

Porting Guides

Dataset and data pre/post-processing

PyTorch datasets (torch.utils.data.Dataset) are basically compatible with Chainer's. In most cases they are interchangeable in both directions.
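For illustration, here is a minimal map-style dataset (the class name and contents are ours, not part of either framework) that can be consumed by torch.utils.data.DataLoader as well as by chainer.iterators.SerialIterator:

import numpy

class SquaresDataset:
    # __len__ and __getitem__ are all that both frameworks require.
    def __init__(self, n=100):
        self.data = numpy.arange(n, dtype=numpy.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        x = self.data[i]
        return x, x ** 2

dataset = SquaresDataset()
# PyTorch:  torch.utils.data.DataLoader(dataset, batch_size=10)
# Chainer:  chainer.iterators.SerialIterator(dataset, batch_size=10)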

Negative strides

As of PyTorch 1.2.0, PyTorch cannot handle data arrays with negative strides (which can result from numpy.flip or chainercv.transforms.flip, for example).

Perhaps the easiest way to circumvent this problem is to wrap the dataset with a transform that applies numpy.ascontiguousarray.

def avoid_negative_strides(in_data):

    data, label = in_data

    data = numpy.ascontiguousarray(data)

    return data, label

dataset = cpm.TransformDataset(dataset, avoid_negative_strides)

data_loader = torch.utils.data.DataLoader(dataset, ...)

Another way is to customize the collation function with collate_fn argument in torch.utils.data.DataLoader.

def collate(batch):

    data = numpy.stack([d for d, l in batch])

    label = numpy.stack([l for d, l in batch])

    data_tensor = torch.from_numpy(data)

    label_tensor = torch.from_numpy(label)

    return data_tensor, label_tensor

data_loader = torch.utils.data.DataLoader(dataset, ..., collate_fn=collate)

Rewriting Custom Converter Functions to collate_fn

In Chainer, it’s possible to specify custom converters for each batch via `training.updaters.StandardUpdater(train_iter, optimizer, device=device, converter=_converter)`. In PyTorch, similar functionality can be achieved via the data loader: `DataLoader(..., collate_fn=_converter)`.

There is, however, an important difference when used in conjunction with multiprocessing. In Chainer, `_converter` is run in the main process, so it is safe to access CUDA in the function when using multiprocessing's `fork` mode. In PyTorch, however, `_converter` is run inside each forked worker process of the data loader. This means that we cannot access CUDA there without getting a CUDA initialization error. In PyTorch, the correct usage is instead to do only CPU-related operations inside `_converter`, and to send the resulting tensors to the GPU *after* retrieving them from the data loader. The following is an example of correct PyTorch usage:

it = DataLoader(..., collate_fn=_converter)

for img, label, metadata in it:

     img = img.cuda()

     label = label.cuda()

     # metadata is still on CPU

     ...

Note that the above scenario is different from what we expect in Chainer, where `_converter` is called in the main process, which is why Chainer code might have CUDA-related operations inside the `_converter`.

Note that in the above use case, the DataLoader should also be created with `pin_memory=True` in order to speed up the transfer of `(img, label)` from CPU to GPU: https://discuss.pytorch.org/t/when-to-set-pin-memory-to-true/19723
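A minimal sketch of this pattern (`dataset` and `_converter` are placeholders from the discussion above): the converter does CPU-only work in the workers, the DataLoader pins memory, and the transfer to the GPU happens in the main process.

import torch

it = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    collate_fn=_converter,  # CPU-only work inside the worker processes
    pin_memory=True)        # page-locked buffers speed up the host-to-GPU copy

for img, label, metadata in it:
    img = img.cuda(non_blocking=True)    # transfer in the main process
    label = label.cuda(non_blocking=True)
    ...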

NumPy bridge

torch.utils.data.DataLoader automatically converts NumPy arrays to PyTorch tensors, but if you want to do that manually, refer to NumPy Bridge.
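For reference, the manual conversions look like this (both directions share memory with the original array rather than copying):

import numpy
import torch

a = numpy.ones((2, 3), dtype=numpy.float32)
t = torch.from_numpy(a)  # NumPy -> PyTorch, shares memory with `a`
b = t.numpy()            # PyTorch -> NumPy, also shares memory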

CuPy bridge

DLPack can be used to bridge between CuPy and torch.Tensor. Note that DLPack does not handle ownership, so you have to make sure the original buffer (the original cupy.ndarray object or dltensor capsule object returned by toDlpack()) survives while the converted tensor/array is in use.

If you allocate memory in both PyTorch and CuPy, it is also recommended to call ppe.cuda.use_torch_mempool_in_cupy before using CuPy so that CuPy uses the PyTorch memory pool. Otherwise, memory allocated and freed by CuPy is kept in the CuPy memory pool and cannot be used by PyTorch.
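A minimal sketch of both directions (API names as of CuPy and PyTorch at the time of writing); note the comments about keeping the source object alive, as explained above:

import cupy
import torch.utils.dlpack

# CuPy -> PyTorch
x = cupy.arange(10, dtype=cupy.float32)
t = torch.utils.dlpack.from_dlpack(x.toDlpack())  # keep `x` alive while `t` is in use

# PyTorch -> CuPy
t2 = torch.arange(10, dtype=torch.float32, device='cuda')
x2 = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(t2))  # keep `t2` alive while `x2` is in use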

Difference between PyTorch and NumPy/CuPy

Division Behavior

PyTorch's integer division behavior is different from NumPy/CuPy, which respect Python 3 division rules (dividing an integer array yields floats). You need to explicitly cast to float in PyTorch (discussion).

>>> x = numpy.arange(5)

>>> x

array([0, 1, 2, 3, 4])

>>> x / 5

array([0. , 0.2, 0.4, 0.6, 0.8])

>>> torch.from_numpy(x) / 5

tensor([0, 0, 0, 0, 0])

>>> torch.from_numpy(x).float() / 5

tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])

Feature Mapping

See PyTorch for Numpy users for the comparison table.

Training loop

This is example code for a training loop. Note the call to model.train().

device = torch.device('cuda:0')

for i_epoch in range(args.epoch):
    train_loss = 0
    train_correct = 0
    model.train()
    for x, t in data_loader:
        x = x.to(device)
        t = t.to(device).long()
        optimizer.zero_grad()
        y = model(x)
        loss = F.nll_loss(y, t)
        loss.backward()
        optimizer.step()
        # F.nll_loss averages over the batch; multiply back so the sum over the dataset is correct
        train_loss += loss.item() * x.shape[0]
        _, pred = torch.max(y, 1)
        train_correct += (pred == t).sum().item()
    train_loss /= len(data_loader.dataset)
    train_accuracy = train_correct / len(data_loader.dataset)
    print('Train average loss: {:.03f}'.format(train_loss))
    print('Train accuracy : {:.03f} %'.format(train_accuracy * 100))

Evaluation loop

This is example code for an evaluation loop. Note model.eval() and with torch.no_grad().

device = torch.device('cuda:0')

total_loss = 0
total_correct = 0
model.eval()
with torch.no_grad():
    for x, t in data_loader:
        x = x.to(device)
        t = t.to(device).long()
        y = model(x)
        total_loss += F.nll_loss(y, t, reduction='sum').item()
        _, pred = torch.max(y, 1)
        total_correct += (pred == t).sum().item()

average_loss = total_loss / len(data_loader.dataset)
accuracy = total_correct / len(data_loader.dataset)

Training and evaluation using Ignite

Ignite plays the role of chainer.training.Trainer in Chainer.

This Chainer code:

updater = chainer.training.StandardUpdater(
    train_iter,
    optimizer,
    device=device)
trainer = chainer.training.Trainer(updater, (100, 'epoch'))
trainer.extend(
    extensions.Evaluator(
        val_iter,
        model,
        device=device),
    trigger=(1, 'epoch'))
trainer.run()

can be written in PyTorch using Ignite:

trainer = ignite.engine.create_supervised_trainer(
    model,
    optimizer,
    F.nll_loss,
    device=device)

evaluator = ignite.engine.create_supervised_evaluator(
    model,
    metrics={
        'accuracy': ignite.metrics.Accuracy(),
        'loss': ignite.metrics.Loss(F.nll_loss),
    },
    device=device)

@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def validation(engine):
    evaluator.run(val_loader)
    average_accuracy = evaluator.state.metrics['accuracy']
    average_loss = evaluator.state.metrics['loss']
    print(average_accuracy, average_loss)

trainer.run(train_loader, max_epochs=100)

For a list of supported metrics, see https://pytorch.org/ignite/metrics.html.

Using Chainer extensions with Ignite

Using cpm.ignite.add_trainer_extension, it is possible to register a Chainer extension to be called within the Ignite training loop.

A list of the supported extensions follows:

Works: ExponentialShift, FailOnNonNumber, InverseShift, LinearShift, LogReport, MicroAverage, MultistepShift, ParameterStatistics, PlotReport, PolynomialShift, PrintReport, ProgressBar, snapshot (read docs), StepShift, observe_lr, VariableStatisticsPlot, WarmupShift

Doesn't work: DumpGraph, Evaluator, unchain_variables

One drawback is that metrics associated with the model or links might not be accessible by default.

For example, the user will need to report the loss or accuracy per iteration using an Ignite callback, since in Chainer this was done inside the model.

Also, for some extensions to work, the user must assign the torch or Chainer model to the optimizer's target attribute, and set the output directory path (trainer.out) for the LogReport, plotting and snapshot extensions.

from chainer import reporter
@trainer.on(Events.ITERATION_COMPLETED)

def report_loss(engine):

    reporter.report({'loss':engine.state.output})

An example of how to register multiple extensions:

# Torch optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)

# Ignite trainer

trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)


# Add the model to the target attribute of the optimizer

optimizer.target = model

# Set the output dir for some of the extensions

trainer.out = 'result'

# Restore the snapshot

cpm.ignite.load_chainer_snapshot(trainer, optimizer, 'result/snapshot_iter_4691')

# Add a bunch of extensions

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ProgressBar())

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.observe_lr())

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.MicroAverage('loss','lr','mav',(1, 'iteration')))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.LogReport())

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.FailOnNonNumber())

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ExponentialShift('lr', 0.9, 1.0, 0.1))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ParameterStatistics(model, prefix='model'))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.VariableStatisticsPlot(model))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PrintReport(

    ['epoch', 'iteration', 'loss', 'lr', 'mav', 'model/fc2/weight/grad/percentile/1']))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.PlotReport(['loss'],

                                  'epoch', filename='loss.png'))

cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.snapshot(writer=writer), trigger=(1, 'epoch'))  # writer is a SimpleWriter

Snapshots

When using the snapshot extension with an Ignite trainer, the PyTorch objects are saved to an additional "snapshot-torch" file in the output folder. This allows you to keep using these snapshots once the migration is finished and to directly load PyTorch models or the optimizer state from these files.

Additionally, if you are mixing Chainer models or optimizers with Ignite and PyTorch, these objects will be saved in the Chainer snapshot file.

The correct way to restore a snapshot is by using cpm.ignite.load_chainer_snapshot(engine, optimizer, snapshot_path) with the Chainer snapshot path.

Note that previously taken Chainer snapshots are not compatible.

Porting custom updater using Ignite

You can pass a step function to an Ignite engine.
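A minimal sketch, assuming a plain classification setup (`model`, `optimizer`, `loss_fn`, `device` and `train_loader` are placeholders): the body of the step function corresponds to what a custom Updater's update_core would have done in Chainer.

from ignite.engine import Engine

def train_step(engine, batch):
    model.train()
    x, t = batch
    x, t = x.to(device), t.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(x), t)
    loss.backward()
    optimizer.step()
    return loss.item()  # becomes engine.state.output

trainer = Engine(train_step)
trainer.run(train_loader, max_epochs=100)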

Rewriting existing Chainer model

Use Mapping of functions and links to find and replace with the corresponding feature in PyTorch. You can also find existing model implementations in:

Common pitfalls:

Functions and Links

You can find the PyTorch equivalent of Chainer's functions and links in tables below.

Notes:

Functions

F refers to chainer.functions (Chainer) / torch.nn.functional (PyTorch).

Chainer

PyTorch

Notes

Arithmetic functions

F.add

torch.add

Batched addition (accumulating multiple tensors in a single call) is not supported.

Activation functions

F.clipped_relu

Rewrite as:

x.clamp(0, z)

F.crelu

Rewrite as:

torch.cat((F.relu(x), F.relu(-x)))

F.elu

F.elu

F.hard_sigmoid

Rewrite as:

torch.clamp(x * 0.2 + 0.5, 0, 1)

F.leaky_relu

F.leaky_relu

The default slope value is different.

F.log_softmax

F.log_softmax

F.lstm

See L.LSTM.

F.maxout

Need to implement manually; see https://github.com/pytorch/pytorch/issues/805

F.prelu

F.prelu

F.rrelu

F.rrelu

The `training` option must be explicitly specified, instead of relying on the `train` config as in Chainer.

F.relu

F.relu

F.relu6

F.relu6

F.selu

F.selu

F.sigmoid

F.sigmoid

F.slstm

Some OSS implementations are available (e.g., https://github.com/reachtarunhere/S-LSTM-PyTorch)

F.softmax

F.softmax

F.softplus

F.softplus

PyTorch falls back to linear function by default; threshold option must be explicitly given.

F.swish

Rewrite as:

x * F.sigmoid(beta * x)

F.tanh

F.tanh

F.tree_lstm

Some OSS implementations are available (e.g., https://github.com/dasguptar/treelstm.pytorch)

Array manipulations

F.as_strided

torch.as_strided

F.broadcast

torch.broadcast_tensors

PyTorch operations perform broadcasting automatically, as in NumPy: https://pytorch.org/docs/stable/notes/broadcasting.html

F.broadcast_to

N/A: https://github.com/pytorch/pytorch/pull/17160

[a]

F.cast

Tensor.to

F.concat

torch.cat

F.copy

Tensor.to

F.depth2space

F.pixel_shuffle

F.diagonal

torch.diagonal

F.dstack

Rewrite as:

torch.cat([a,b],dim=2)

F.expand_dims

Rewrite as:

torch.unsqueeze(a, dim)

F.flatten

torch.flatten

F.flip

torch.flip

F.fliplr

torch.flip

Use dims=1

F.flipud

torch.flip

Use dims=0

F.get_item

Use direct indexing: `x[indexes]`. Negative strides are not supported.

F.hstack

Rewrite as:

torch.cat([a,b],dim=1)

F.im2col

F.unfold

Only NCHW is supported.

F.moveaxis

Tensor.permute

See: https://discuss.pytorch.org/t/swap-axes-in-pytorch/970/2

F.pad

F.pad

Replace `constant_values` argument with `value`. Modes other than `constant` are also available.

F.pad_sequence

nn.utils.rnn.pad_sequence

You cannot specify the length; the maximum length among the inputs is used.

F.permutate

Tensor.permute

F.repeat

Tensor.repeat

Behaves differently from Chainer's F.repeat; it is closer to Chainer's F.tile.

F.reshape

torch.reshape

F.resize_images

F.interpolate

F.rollaxis

Tensor.permute

See https://discuss.pytorch.org/t/swap-axes-in-pytorch/970/2

F.scatter_add

Tensor.scatter_add

F.select_item

Rewrite as:

torch.gather(x, 1, t[:, None])[:, 0]

F.separate

torch.split

Requires manual manipulation of the results to achieve some of the separate functionality.

F.space2depth

You need to implement it yourself. Ref:

https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14

F.spatial_transformer_grid

F.affine_grid

The second argument `size` takes `torch.Size` object that denotes the target output image size (N, C, H, W), while `F.spatial_transformer_grid` takes just a tuple of (H, W). The size of returned tensor is also different: (N x H x W x 2) is returned instead of (N x 2 x H x W). Also note the breaking change regarding align_corners in v1.3.0 (https://github.com/pytorch/pytorch/releases/tag/v1.3.0)

F.spatial_transformer_sampler

F.grid_sample

Grid shape is (N, 2, H, W) in Chainer while (N, H, W, 2) in PyTorch.

F.split_axis

torch.split

No `force_tuple`.

F.squeeze

torch.squeeze

F.stack

torch.stack

Use torch.stack or torch.cat([a,b],dim=axis)

F.swapaxes

Use permute instead: https://discuss.pytorch.org/t/swap-axes-in-pytorch/970/7

F.tile

F.repeat

F.transpose

torch.t

Use Tensor.permute or torch.t for no axes version

F.transpose_sequence

N/A

F.vstack

Rewrite as:

torch.cat([a,b],dim=0)

F.where

torch.where

Neural network connections

F.bilinear

F.bilinear

F.convolution_1d

F.conv1d

No `cover_all`.

F.convolution_2d

F.conv2d

No `cover_all`.

F.convolution_3d

F.conv3d

No `cover_all`.

F.convolution_nd

See https://discuss.pytorch.org/t/is-there-a-way-to-realize-4d-convolution-using-the-convnd-function/5999

F.deconvolution_1d

F.conv_transpose1d

F.deconvolution_2d

F.conv_transpose2d

F.deconvolution_3d

F.conv_transpose3d

F.deconvolution_nd

N/A, see https://discuss.pytorch.org/t/is-there-a-way-to-realize-4d-convolution-using-the-convnd-function/5999

F.depthwise_convolution_2d

F.conv2d

Use `groups` argument; see https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2

F.deformable_convolution_2d_sampler

Not implemented: https://github.com/pytorch/pytorch/issues/2260

F.dilated_convolution_2d

F.conv2d

Use `dilation` argument.

F.embed_id

F.embedding

F.linear

F.linear

There is no option for `n_batch_axes`.

F.local_convolution_2d

See https://github.com/pytorch/pytorch/pull/1583

F.n_step_bigru

Undocumented _C._VariableFunctions function torch.gru? The "link" is available https://pytorch.org/docs/stable/nn.html#torch.nn.GRU and this is probably the expected usage.

F.n_step_bilstm

See L.NStepBiLSTM.

F.n_step_birnn

See L.NStepBiRNNTanh or L.NStepBiRNNReLU.

F.n_step_gru

See L.NStepBiGRU.

F.n_step_lstm

See L.NStepLSTM.

F.n_step_rnn

See L.NStepRNNTanh or L.NStepRNNReLU.

F.shift

See https://github.com/pytorch/pytorch/issues/16408

Evaluation functions

F.accuracy

N/A, Ignite has an implementation: https://pytorch.org/ignite/metrics.html#ignite.metrics.Accuracy

F.binary_accuracy

N/A, Ignite has an implementation: https://pytorch.org/ignite/metrics.html#ignite.metrics.Accuracy

F.classification_summary

N/A

F.f1_score

See https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/1

F.precision

N/A, Ignite has an implementation: https://pytorch.org/ignite/metrics.html#ignite.metrics.Precision

F.r2_score

Not available. It's an evaluation metric that's not differentiable. It's implemented in Ignite though and could be used (as a reference) https://github.com/pytorch/ignite/pull/496

F.recall

N/A, Ignite has an implementation: https://pytorch.org/ignite/metrics.html#ignite.metrics.Recall

Loss functions

F.absolute_error

N/A

F.bernoulli_nll

Possibly: -torch.distributions.Bernoulli(y).log_prob(x).sum()

F.black_out

F.connectionist_temporal_classification

F.ctc_loss

F.contrastive

See https://github.com/adambielski/siamese-triplet

F.crf1d

Not available: https://github.com/pytorch/pytorch/issues/11134

F.argmax_crf1d

Not available: https://github.com/pytorch/pytorch/issues/11134

F.cross_covariance

F.decov

F.discriminative_margin_based_clustering_loss

Not available. See https://github.com/Wizaron/instance-segmentation-pytorch for a reproducing work.

F.gaussian_kl_divergence

N/A

F.gaussian_nll

N/A

F.hinge

F.hinge_embedding_loss

F.huber_loss

F.smooth_l1_loss

Use reduction='sum' to keep reduction method

F.mean_absolute_error

F.l1_loss

See also: ignite.metrics.MeanAbsoluteError

F.mean_squared_error

F.mse_loss

F.negative_sampling

See https://github.com/kefirski/pytorch_NEG_loss, https://github.com/theeluwin/pytorch-sgns

F.sigmoid_cross_entropy

F.binary_cross_entropy_with_logits

F.softmax_cross_entropy

nn.CrossEntropyLoss

F.squared_error

F.triplet

F.triplet_margin_loss

Mathematical functions

F.absolute

torch.abs

F.arccos

torch.acos

F.arcsin

torch.asin

F.arctan

torch.atan

F.arctan2

torch.atan2

F.arctanh

N/A: https://github.com/pytorch/pytorch/issues/10324

F.argmax

torch.argmax

F.argmin

torch.argmin

F.average

See F.mean.

F.batch_inv

torch.inverse

Batched linalg ops are in progress; batched inverse is already merged.

F.batch_l2_norm_squared

Rewrite as:

x.reshape(len(x), -1).norm(dim=1) ** 2

F.batch_matmul

torch.matmul

F.bias

Rewrite as

x + y[(...,) + (None,) * (x.ndim - y.ndim - axis)]

F.ceil

torch.ceil

F.clip

torch.clamp

F.cos

torch.cos

F.cosh

torch.cosh

F.cumprod

torch.cumprod

F.cumsum

torch.cumsum

F.det

torch.det

F.batch_det

torch.det

Arbitrary number of batch axes are supported.

F.digamma

torch.digamma

F.einsum

torch.einsum

F.erf

torch.erf

F.erfc

torch.erfc

F.erfcinv

Not available. Implement it similar to erf, erfinv? https://github.com/pytorch/pytorch/pull/2799

F.erfcx

N/A

F.erfinv

torch.erfinv

F.exp

torch.exp

F.expm1

torch.expm1

F.fft

torch.fft

Interface is quite different.

F.fix

N/A

F.fmod

torch.fmod

F.floor

torch.floor

F.identity

nn.Identity

F.ifft

torch.ifft

F.inv

torch.inverse

F.lgamma

torch.lgamma

Currently undocumented: https://github.com/pytorch/pytorch/pull/27812

F.linear_interpolate

Normal math should suffice.

F.log

torch.log

F.log10

torch.log10

F.log1p

torch.log1p

F.log2

torch.log2

F.log_ndtr

N/A

F.logsumexp

torch.logsumexp

F.matmul

torch.matmul

F.max

torch.max

F.maximum

torch.max

F.mean

torch.mean

Weighted average is not supported; rewrite as (without keepdims):

torch.tensordot(x, weights / weights.sum(), ([axis], [0]))

F.min

torch.min

F.minimum

torch.min

F.ndtr

F.gelu(x) corresponds to x * F.ndtr(x).

F.ndtri

N/A

F.prod

torch.prod

F.polygamma

torch.polygamma

Not documented: https://github.com/pytorch/pytorch/issues/25347

n>=2 not supported: https://github.com/pytorch/pytorch/blob/v1.3.1/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp#L179

F.rsqrt

torch.rsqrt

-

F.scale

Rewrite as:

x * y[(...,) + (None,) * (x.ndim - y.ndim - 1)]

F.sin

torch.sin

F.sinh

torch.sinh

F.sign

torch.sign

F.sparse_matmul

torch.sparse.mm

Only the sparse-dense product is supported. For the dense-sparse product, transpose the operands and the output.

F.sqrt

torch.sqrt

F.square

Rewrite as:

x * x

F.squared_difference

N/A

F.sum

torch.sum

F.sum_to

N/A

F.tanh

torch.tanh

F.tan

torch.tan

F.tensordot

torch.tensordot

Noise injections

F.dropout

F.dropout

No mask support, elements are randomly zeroed.

F.gaussian

torch.distributions.normal.Normal

F.gumbel_softmax

F.gumbel_softmax

The default value of tau is 1, while Chainer's function uses 0.1.

F.simplified_dropconnect

Not available. Use F.dropout on the weight, or try torchnlp.nn.WeightDrop.

F.zoneout

N/A

Normalization functions

F.batch_normalization

F.batch_norm

F.batch_renormalization

N/A

F.decorrelated_batch_normalization

N/A

F.fixed_batch_normalization

F.batch_norm

training=False?

F.fixed_batch_renormalization

Batch Renormalization not implemented: https://discuss.pytorch.org/t/support-for-batch-renormalization/2965,  https://discuss.pytorch.org/t/batch-renormalization-implementation-in-thcunn/5144

F.fixed_decorrelated_batch_normalization

N/A

F.group_normalization

F.group_norm

Currently undocumented.

F.layer_normalization

layer_norm

`gamma` corresponds to `weight`, and `beta` to `bias`.

F.local_response_normalization

F.local_response_norm

F.normalize

F.normalize

PyTorch's `F.normalize` is not limited to L2 normalization, but the default behavior is L2 normalization, i.e., the default value of the second argument `p` is 2.

Spatial pooling

F.average_pooling_1d

F.avg_pool1d

F.average_pooling_2d

F.avg_pool2d

Superset of Chainer's counterpart.

F.average_pooling_3d

F.avg_pool3d

F.average_pooling_nd

N/A

F.max_pooling_1d

F.max_pool1d

In addition to arguments documented in `nn.MaxPool1D`, `return_indices` argument is available to obtain index for unpooling.

F.max_pooling_2d

F.max_pool2d

ditto.

F.max_pooling_3d

F.max_pool3d

ditto.

F.max_pooling_nd

N/A

F.roi_average_align_2d

torchvision.ops.roi_align

Requires torchvision. torchvision has only two RoI functions: roi_align uses the average of the pixels, while roi_pool uses the max value.

F.roi_average_pooling_2d

N/A

F.roi_max_align_2d

N/A

F.roi_max_pooling_2d

torchvision.ops.roi_pool

The `roi_pool` function of Torchvision is meant to be roi max pooling according to the source: https://github.com/pytorch/vision/blob/ccd1b27d2b7312ebddb4d51b3a4f8ade1ba8fa8b/torchvision/csrc/cpu/ROIPool_cpu.cpp#L65

Regarding the API, the way to pass the batch indices of each set of RoI coordinates is different.

F.roi_pooling_2d

torchvision.ops.roi_pool

Requires torchvision.

F.spatial_pyramid_pooling_2d

Not available. Not too difficult to implement? It's a combination of existing functions in Chainer.

F.unpooling_1d

F.max_unpool1d

Pass `indices` returned from F.max_pool1d.

F.unpooling_2d

F.max_unpool2d

ditto.

F.unpooling_3d

F.max_unpool3d

ditto.

F.unpooling_nd

N/A

F.upsampling_2d

N/A

Utility functions

F.forget

See https://pytorch.org/docs/stable/checkpoint.html

Links

L refers to chainer.links (Chainer), and nn refers to torch.nn (PyTorch).

Chainer

PyTorch

Notes

Learnable connections

L.Bias

L.Bilinear

nn.Bilinear

L.ChildSumTreeLSTM

N/A. Reference user implementation at: https://github.com/ttpro1995/TreeLSTMSentiment

L.Convolution1D

nn.Conv1d

L.Convolution2D

nn.Conv2d

L.Convolution3D

nn.Conv3d

L.ConvolutionND

N/A

L.Deconvolution1D

nn.ConvTranspose1d

L.Deconvolution2D

nn.ConvTranspose2d

L.Deconvolution3D

nn.ConvTranspose3d

L.DeconvolutionND

L.DeformableConvolution2D

N/A: https://github.com/pytorch/pytorch/issues/2260

L.DepthwiseConvolution2D

nn.Conv2d

Use `groups` argument; see https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2

L.DilatedConvolution2D

nn.Conv2d

Use `dilation` argument.

L.EmbedID

nn.Embedding

L.GRU

nn.GRU

L.Highway

N/A

L.Inception

`torchvision.models.inception.InceptionA` seems to be the corresponding module for Chainer's `L.Inception`, but is not documented.

L.InceptionBN

See torchvision.models.inception for Inception v3

L.Linear

nn.Linear

L.LocalConvolution2D

N/A

L.LSTM

nn.LSTM

L.MLPConvolution2D

L.NaryTreeLSTM

N/A

L.NStepBiGRU

nn.GRU

bidirectional=True, no explicit activation, no stacking

L.NStepBiLSTM

nn.LSTM

bidirectional=True, no explicit activation, no stacking

L.NStepBiRNNReLU

nn.RNN

bidirectional=True, no explicit activation, no stacking

L.NStepBiRNNTanh

nn.RNN

bidirectional=True, no explicit activation, no stacking

L.NStepGRU

nn.GRU

L.NStepLSTM

nn.LSTM

L.NStepRNNReLU

nn.RNN

L.NStepRNNTanh

nn.RNN

L.Parameter

You could use torch.nn.modules.ParameterList with 1 element

L.Scale

N/A

L.StatefulGRU

nn.GRU

L.StatelessGRU

N/A

L.StatefulMGU

See https://github.com/jpeg729/pytorch_bits

L.StatelessMGU

N/A

L.StatefulPeepholeLSTM

See https://github.com/pytorch/pytorch/issues/630

L.StatefulZoneoutLSTM

N/A: https://github.com/pytorch/pytorch/pull/4838

L.StatelessLSTM

Activation/loss/normalization functions with parameters

L.BatchNormalization

nn.BatchNorm1d

nn.BatchNorm2d

nn.BatchNorm3d

The argument `momentum` in the PyTorch implementation seems to be equivalent to `1 - decay` in Chainer's link.

The default value of the argument `eps` (1e-5) is different from Chainer's default value (2e-5).

L.BatchRenormalization

N/A

L.DecorrelatedBatchNormalization

Not available. A reference implementation (not that well implemented?) https://github.com/huangleiBuaa/IterNorm-pytorch/blob/master/extension/normailzation/dbn.py. Otherwise look at the Torch lua official implementation https://github.com/princeton-vl/DecorrelatedBN.

L.GroupNormalization

nn.GroupNorm

affine=True

L.LayerNormalization

nn.LayerNorm

elementwise_affine=True

L.BinaryHierarchicalSoftmax

L.BlackOut

N/A

L.CRF1d

L.SimplifiedDropconnect

N/A

L.PReLU

nn.PReLU

L.Swish

See https://blog.ceshine.net/post/pytorch-memory-swish/

L.Maxout

L.NegativeSampling

Machine learning models

L.Classifier

N/A

Pre-trained models

L.VGG16Layers

torchvision.models.vgg*

Superset of Chainer's VGG variations in torchvision.

L.VGG19Layers

torchvision.models.vgg19*

ditto

L.model.vision.vgg.prepare

N/A

L.GoogLeNet

torchvision.models.googlenet

L.model.vision.googlenet.prepare

transform_input=True in torchvision.models.googlenet

L.model.vision.resnet.ResNetLayers

L.ResNet50Layers

torchvision.models.resnet50

torchvision only, pretrained=True

L.ResNet101Layers

torchvision.models.resnet101

torchvision only, pretrained=True

L.ResNet152Layers

torchvision.models.resnet152

torchvision only, pretrained=True

L.model.vision.resnet.prepare

N/A

L.TheanoFunction

N/A

L.caffe.CaffeFunction

See https://github.com/marvis/pytorch-caffe or https://github.com/Microsoft/MMdnn

Configuration

Here is the mapping of configurations in Chainer (chainer.config.*) and PyTorch:

| Chainer | PyTorch | Notes |
| --- | --- | --- |
| autotune | torch.backends.cudnn.benchmark | Not thread-local. |
| cudnn_deterministic | torch.backends.cudnn.deterministic | Not thread-local. |
| cudnn_fast_batch_normalization | N/A | Intentionally unsupported, as the precision is low in some models. |
| debug | N/A | Use the torch.autograd.detect_anomaly() context manager to check for NaN during backward and to display the corresponding forward stack trace when an error occurs in backward. |
| dtype | torch.set_default_dtype(dtype) | Mixed precision support is done via Apex. Not thread-local. |
| enable_backprop | torch.no_grad() / torch.enable_grad() | You can use them as a context manager or decorator. See also Backprop modes. |
| is_recomputing | N/A | See torch.utils.checkpoint.checkpoint for an F.forget equivalent (it also supports RNG). |
| keep_graph_on_report | N/A | |
| lazy_grad_sum | N/A | |
| train | N/A | The mode is configured per Module (using Module.train() and Module.eval()). See also Train/Test modes. |
| type_check | N/A | |
| use_cudnn | torch.backends.cudnn.enabled | Enabled by default. Not thread-local. |
| use_cudnn_tensor_core | N/A | Tensor Cores cannot be disabled. |
| use_ideep | N/A | PyTorch itself supports MKL-DNN. You can check availability using torch.backends.mkldnn.is_available(). |
| use_static_graph | N/A | |
| warn_nondeterministic | N/A | See the PyTorch notes on Reproducibility (including steps to fix seeds). |

Hooks

Function Hooks

There is no equivalent feature in PyTorch.

Replacements for Chainer built-in hooks:

Link Hooks

You can register Module hooks per module. There is no way to inject a hook for every Module called under a specific scope.
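For example, a forward hook on a single module is a rough analogue of a Chainer link hook scoped to that link (the hook function here is our own illustration):

import torch

def print_shape_hook(module, inputs, output):
    # called after every forward pass of the module it is registered on
    print(module.__class__.__name__, tuple(output.shape))

linear = torch.nn.Linear(10, 5)
handle = linear.register_forward_hook(print_shape_hook)
y = linear(torch.zeros(2, 10))
handle.remove()  # detach the hook when it is no longer needed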

Replacements for Chainer built-in hooks:

Optimizer Hooks

There is no direct equivalent in PyTorch, but you can register backward hooks per Tensor / Module to modify gradients.
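For example, a per-parameter gradient hook can play a role similar to chainer.optimizer_hooks.GradientClipping (a rough sketch only; it clips element-wise rather than by global norm):

import torch

model = torch.nn.Linear(10, 5)
for p in model.parameters():
    # clamp each gradient element to [-1, 1] before the optimizer sees it
    p.register_hook(lambda grad: grad.clamp(-1.0, 1.0))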

Replacements for Chainer built-in hooks:

Training PyTorch model using Chainer

To quickly try a PyTorch model in a training script using Chainer, cpm.TorchModule is the tool to use. Assuming you have a training script written with Chainer, take the following steps:

Distributed training

As of this writing, there are two major ways to run distributed deep learning applications: torch.distributed and Horovod. We recommend torch.distributed as the first option for the following reasons:

  1. torch.distributed is part of the standard modules of PyTorch.
  2. It supports some advanced features that Horovod doesn't, such as multi-node batch normalization (i.e. inter-process batch normalization).

In this document, we describe both approaches to migrate ChainerMN programs to PyTorch.

PyTorch model using torch.distributed

torch.distributed is the standard module for distributed deep learning in PyTorch.

torch.distributed supports three backends: "nccl", "mpi" and "gloo". For users who are migrating from Chainer and ChainerMN and have been using NCCL with MPI, the "nccl" backend is the most straightforward choice. In this section, we assume that you use NCCL and MPI to run your distributed deep learning programs. In particular, we assume Open MPI as the MPI implementation because it is the recommended option in ChainerMN, but other MPI implementations are mentioned as well.

Invocation

In ChainerMN, process invocation is completely coordinated by the MPI runtime. However, with PyTorch and torch.distributed, you may need a few more steps to invoke distributed deep learning processes. The simplest initialization method might be environment variable initialization.

The following environment variables are necessary (whatever system you use to invoke your script, including MPI). The other variables, WORLD_SIZE and RANK, are set from inside the snippet shown in the Initialization section below.

MASTER_ADDR: Address of the computing node where the rank 0 process runs.

MASTER_PORT: A free port on the MASTER_ADDR machine. The port will be used by the rank 0 process.

Note that process invocation is a highly system-dependent issue. PyTorch supports other options such as TCP initialization and shared file-system initialization. Please refer to the official documentation for more details.
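If every process is launched by mpirun, one simple option is to set these variables from inside the script before initialization; a sketch (the address and port below are placeholders and must be identical in all processes):

import os

os.environ.setdefault('MASTER_ADDR', '10.0.0.1')  # node where the rank 0 process runs
os.environ.setdefault('MASTER_PORT', '29500')     # a free port on that node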

Initialization

The following code snippet shows how to initialize the torch.distributed module.

# setup env for torch.distributed

comm_world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])

comm_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

os.environ["WORLD_SIZE"] = str(comm_world_size)

os.environ["RANK"] = str(comm_rank)

torch.cuda.set_device(comm_local_rank)

torch.distributed.init_process_group(backend='nccl', init_method='env://')

Environment variables set by the MPI runtime are used here (instead of communicator.intra_rank in ChainerMN) because torch.distributed does not provide the corresponding rank information. If you use MVAPICH2, use MV2_COMM_WORLD_SIZE, MV2_COMM_WORLD_RANK and MV2_COMM_WORLD_LOCAL_RANK respectively.

Dataset scattering

Each node can get a slice of a globally shared dataset using a DistributedSampler.

sampler = torch.utils.data.distributed.DistributedSampler(dataset,

                                                          num_replicas=comm_world_size,

                                                          rank=comm_rank)

loader_kwargs = {'num_workers': 1, 'pin_memory': True}  # Assuming we use GPUs

loader = torch.utils.data.DataLoader(train_dataset,

                                       batch_size=args.batch_size,

                                       sampler=sampler, **loader_kwargs)

This makes every worker load only a slice of the dataset; the sampler can be passed to the DataLoader as usual.

Also, you need to call DistributedSampler.set_epoch() at the beginning of each epoch so that shuffling differs between epochs. A typical training loop thus looks like:

for epoch in range(1, args.epochs + 1):

    train_sampler.set_epoch(epoch)

    train(args, model, device, train_loader, optimizer, epoch)

    test(args, model, device, test_loader, len(test_dataset))

    scheduler.step()

Data transfer to devices

We need to specify the device to which the data is transferred using comm_local_rank.

class MyNN(nn.Module):

    ...

device = torch.device("cuda:{}".format(comm_local_rank) if use_cuda else "cpu")

model = MyNN().to(device)

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[comm_local_rank])

Optimizer wrapping

In contrast to Horovod, we can use the same optimizer as in non-distributed execution.

Initial values broadcast

Parameter values are synchronized automatically (i.e. initial broadcast and allreduce in every iteration) by the DistributedDataParallel class, so no further modification is necessary.

Metrics average and reductions

https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions
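For example, a scalar metric can be averaged over all processes with all_reduce; a minimal sketch (`correct` and `total` are per-process counts computed by your evaluation loop, and `device` is the local CUDA device when the NCCL backend is used):

import torch
import torch.distributed as dist

def global_accuracy(correct, total, device):
    # pack the per-process counts into a tensor and sum them across all ranks
    stats = torch.tensor([correct, total], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    return (stats[0] / stats[1]).item()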

Synchronization

To avoid potential data races and other kinds of bugs, you may need to use torch.distributed.barrier() to synchronize processes, for example before or after data loading and before finishing the application.

PyTorch model using Horovod

PyTorch can use Horovod to do data-parallel training in a similar way to ChainerMN. Data is distributed across the nodes and the optimizer is wrapped with Horovod to automatically average the gradients of several MPI processes.

Horovod initialization

The following snippet shows how to import horovod and retrieve the current worker id and the total number of workers.

import horovod.torch as hvd

hvd.init()
print('My rank is {} of {} workers'.format(hvd.rank(), hvd.size()))

hvd.local_rank() is used to get the rank inside a single node; this is useful for assigning GPUs, similar to ChainerMN's intra_rank().

torch.cuda.set_device(hvd.local_rank())

Dataset scattering

Each node can get a slice of a globally shared dataset using a DistributedSampler.

torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=hvd.size(),

                                                rank=hvd.rank())

This makes every worker load only a slice of the dataset; the sampler can be passed to the DataLoader as usual.

Optimizer wrapping

The optimizer is wrapped in an hvd.DistributedOptimizer object with the following configuration parameters:

compression: one of {hvd.Compression.fp16, hvd.Compression.none}. Compression is used to reduce the size of the allreduce operations performed by the optimizer.

backward_passes_per_step: int, usually 1 by default. The number of batches processed locally before the gradient exchange is performed.

optimizer = hvd.DistributedOptimizer(

    optimizer, named_parameters=model.named_parameters(),

    compression=compression,

    backward_passes_per_step=args.batches_per_allreduce)


From the documentation:

DistributedOptimizer exposes the synchronize() method, which forces allreduce operations to finish before continuing the execution. It's useful in conjunction with gradient clipping, or other operations that modify gradients in place before step() is executed. Make sure to use optimizer.skip_synchronize() if you're calling synchronize() in your code.

Initial values broadcast

Before starting the training loop, initial model parameters and the optimizer state must be broadcasted to all the workers:

hvd.broadcast_parameters(model.state_dict(), root_rank=0)

hvd.broadcast_optimizer_state(optimizer, root_rank=0)

Metrics average and reductions

When computing the loss and other metrics such as accuracy, the values of multiple workers can be explicitly exchanged to compute averages:

self.sum += hvd.allreduce(val.detach().cpu(), name=metric_name)

Horovod also supports exchanging data using other MPI collectives (allreduce, allgather, broadcast).

There are _async versions of these functions; the returned handle can be queried with poll() or waited on with synchronize() until completion.
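A sketch of the asynchronous variants (`tensor` is a placeholder; the function names are from horovod.torch):

import horovod.torch as hvd

handle = hvd.allreduce_async(tensor, name='metric')
# ... overlap other work here ...
ready = hvd.poll(handle)          # non-blocking completion check
result = hvd.synchronize(handle)  # blocks until the allreduce finishes and returns the result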

Horovod code structure

import torch
import torch.optim as optim
import horovod.torch as hvd
from torchvision import datasets

def main():
    # Initialize Horovod
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    # Read the dataset and create the iterators
    dataset = datasets.ImageFolder(...)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank())
    loader = torch.utils.data.DataLoader(dataset, sampler=train_sampler, ...)

    # Set up the model, checkpoints, ...

    # Create the optimizer and wrap it with Horovod
    optimizer = optim.SGD(model.parameters(), ...)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=hvd.Compression.none,
        backward_passes_per_step=args.batches_per_allreduce)

    # Broadcast the initial parameters and optimizer state
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Start training
    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)
        ...

Obtaining Horovod traces to measure performance

Communication traces showing Horovod communications can be obtained by setting the HOROVOD_TIMELINE environment variable.

mpirun -bind-to-none -np 8 -x HOROVOD_TIMELINE=timeline.json ...

The resultant trace can be visualized in Chrome by using the browser built-in chrome://tracing feature.

Tuning Horovod performance

Horovod has several knobs to improve its performance

Using Horovod with apex

Horovod launches all-reduce in parallel with the backward computation, and apex unscales the gradients after the backward computation.

To avoid race conditions, we have to wait for all-reduce completion before unscaling:

from apex import amp

...

with amp.scale_loss(loss, optimizer) as scaled_loss:

    scaled_loss.backward()

    optimizer.synchronize()  # Wait for all-reduce completion

with optimizer.skip_synchronize():

    optimizer.step()

Also, backward_passes_per_step should be 1 when using Horovod together with apex. The current implementations of Horovod and apex do not work as expected when backward_passes_per_step is not 1.

Multi-Node Batch Normalization in Horovod

Horovod does not yet officially support MNBN (https://github.com/horovod/horovod/issues/1384), but there exists an unofficial implementation: https://github.com/atranitell/Synchronized-BatchNorm-PyTorch-Horovod/blob/master/sync_bn.py. Apex also has an implementation: https://nvidia.github.io/apex/parallel.html#apex.parallel.SyncBatchNorm

Gathering arbitrary objects using Horovod and mpi4py

Horovod supports simultaneous usage with mpi4py (https://github.com/horovod/horovod#mpi4py). You can directly work with mpi4py to e.g. rewrite ChainerMN's comm.gather_obj:

import horovod.torch as hvd

# Initialize Horovod

hvd.init()

# Verify that MPI multi-threading is supported.

assert hvd.mpi_threads_supported()

from mpi4py import MPI
mpi_comm = MPI.COMM_WORLD

assert hvd.size() == mpi_comm.Get_size()

mpi_comm.gather(obj, root=0)  # This is equal to ChainerMN's comm.gather_obj

Alternatives to Horovod

Horovod is introduced here because it greatly resembles ChainerMN and can be used in our computing infrastructure right away. Alternatives are:

Chainer model using Horovod

To train Chainer models in distributed environments using Horovod, the Chainer link should be wrapped using cpm.LinkAsTorchModel. The use of a PyTorch optimizer is required.

model = ChainerModel()

model.to_device(ch_device)

# Initialize parameters before converting to `ChainerParameter`s.

model(ch_device.xp.zeros((1, 784)).astype('f'))

# Convert parameters to `ChainerParameter`s to share memory with PyTorch.

torched_model = cpm.LinkAsTorchModel(model)

optimizer = optim.SGD(torched_model.parameters(), lr=args.lr)

optimizer = hvd.DistributedOptimizer(

    optimizer, named_parameters=torched_model.named_parameters())

hvd.broadcast_parameters(torched_model.state_dict(), root_rank=0)

hvd.broadcast_optimizer_state(optimizer, root_rank=0)

PyTorch model using ChainerMN

Using the cpm tool it is also possible to train a PyTorch model using ChainerMN.

The current support is limited to data-parallel training.

import chainer_pytorch_migration as cpm
from chainer_pytorch_migration import chainermn
import torch.optim as optim
import torchvision.models as models

comm = chainermn.create_communicator('pure_nccl')

# Set up a standard ResNet-50 model.
model = models.resnet50()
model.cuda()
w_model = cpm.TorchModule(model)
w_model.to_gpu(device)

optimizer = optim.SGD(model.parameters(), lr=lr)
optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)
optimizer.setup(w_model)

Porting code that edits the computational graph

Unchaining nodes

Explains differences of how variables can be unchained from the computational graph.
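For reference, a rough mapping (`y` is a chainer.Variable on the Chainer side and a torch.Tensor on the PyTorch side; the semantics are not identical, since Chainer cuts the history in place while PyTorch returns a new tensor detached from the graph):

# Chainer
y.unchain_backward()

# PyTorch
y = y.detach()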

Backprop modes

Explains differences of how backprop modes are switched.
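For reference, a rough mapping (`model` and `x` are placeholders; both forms are usable as context managers, see also the enable_backprop row in the Configuration table above):

# Chainer
with chainer.no_backprop_mode():
    y = model(x)

# PyTorch
with torch.no_grad():
    y = model(x)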

Train/Test modes

Explains differences of how train/test modes are switched.
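For reference, a rough mapping (`model` and `x` are placeholders; in Chainer the mode is a thread-local config flag, while in PyTorch it is a per-Module flag, see also the train row in the Configuration table above):

# Chainer
with chainer.using_config('train', False):
    y = model(x)

# PyTorch
model.eval()
with torch.no_grad():
    y = model(x)
model.train()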

Ecosystem

This section introduces some of the larger repositories under the PyTorch GitHub organization. It also refers to the official list of other ecosystem-libraries acknowledged by PyTorch.

PyTorch

Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration

GitHub: https://github.com/pytorch/pytorch

Ignite

Summary: High level utilities such as training loop abstraction.

GitHub: https://github.com/pytorch/ignite

torchvision

Summary: PyTorch for CV.

GitHub: https://github.com/pytorch/vision

Recommended by the official installation guide to be installed along with PyTorch.

Provides domain-agnostic (not limited to CV) data augmentation functionality.

Provides loaders for video data. Slow due to ffmpeg but this might be improved in the future?

torchtext

Summary: PyTorch for NLP.

GitHub: https://github.com/pytorch/text

torchaudio

Summary: PyTorch for audio data.

GitHub: https://github.com/pytorch/audio

Fairseq

Summary: Seq2seq models.

GitHub: https://github.com/pytorch/fairseq

Seq2seq models such as translation. Includes the Transformer and BERT-like models.

Other

There is an official list of libraries included in the PyTorch ecosystem (besides the domain specific libraries above), including e.g. Ignite.

https://pytorch.org/ecosystem

[a]We can do it with `expand` https://pytorch.org/docs/stable/tensors.html#torch.Tensor.expand