Welcome to rotor’s documentation!¶
rotor: Rematerializing Optimally with pyTORch¶
Description¶
Purpose: This code is meant to replace torch/utils/checkpoint.py by providing more efficient checkpointing strategies. The algorithm is also easier to tune, since its input parameter is the available memory rather than a number of segments.
For more details about the algorithm inside rotor, and for experimental validation, see our paper at https://hal.inria.fr/hal-02352969.
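The difference shows up directly in the API. Here is a minimal sketch of the contrast (the checkpoint_sequential call is the stock PyTorch interface; the rotor lines follow the usage described below, with an arbitrary 100 MB budget):
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
import rotor

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
x = torch.rand(32, 128, requires_grad=True)

# Stock PyTorch: the tuning knob is a segment count, which has no direct
# relation to an actual memory budget.
out = checkpoint_sequential(model, 2, x)

# rotor: the tuning knob is the memory budget itself, in bytes.
chk = rotor.Checkpointable(model, x, 100*1024*1024)
out = chk(x)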
Installation¶
Besides the obvious requirement on pytorch, rotor has only one prerequisite: psutil. For installation, just run:
python setup.py install
This will also compile the C extension, which contains faster implementations of the dynamic programming algorithms.
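A quick sanity check that the installation succeeded (a trivial sketch; it only verifies that rotor and its prerequisite import cleanly):
import psutil  # prerequisite; importing it checks that it is installed
import rotor   # raises ImportError if the installation failed

print(rotor.Checkpointable)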
Documentation¶
Usage¶
Standard usage¶
The main class is rotor.Checkpointable. To use it:
1. Describe your module as a torch.nn.Sequential module.
2. Create a Checkpointable from it:
   chk = Checkpointable(module)
3. Provide a sample input to perform timing and memory measurements of your module:
   input = torch.rand(shape)
   chk.measure(input)
4. Compute the optimal sequence for a given memory limit (in bytes):
   chk.compute_sequence(500*1024*1024)
5. Use the Checkpointable as a module:
   input.requires_grad = True
   output = chk(input)
   output.backward()
   grad = input.grad
As an alternative, the sample input and memory limit can be specified in the construction of Checkpointable:
input = torch.rand(shape)
chk = Checkpointable(module, input, 500*1024*1024)
output = chk(input)
Or only the sample input:
input = torch.rand(shape)
chk = Checkpointable(module, input)
chk.compute_sequence(500*1024*1024)
output = chk(input)
Implemented models¶
Rotor also contains a suitable adaptation of the main models available in torchvision
: ResNet, VGG, Inception, and
Densenet, available respectively as rotor.resnet
, rotor.vgg
, rotor.inception
, and
rotor.densenet
. Each of these modules has the same interface as the corresponding torchvision
version.
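Since the interfaces match, constructor arguments can be passed exactly as with torchvision. A minimal sketch (num_classes is a standard torchvision argument; resnet18 is the variant used in the example below):
import rotor
import torchvision

tv_net = torchvision.models.resnet18(num_classes=10)  # stock torchvision model
rt_net = rotor.models.resnet18(num_classes=10)        # rotor's adaptation, same arguments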
Complete example¶
Here is a complete working example, which runs ResNet18 with a batch size of 32 using 700 MB of memory, whereas the same run without checkpointing would use 843 MB:
import rotor
import torch

device = torch.device("cuda")
net = rotor.models.resnet18()
net.to(device=device)
net_check = rotor.Checkpointable(net)

# Measure timing and memory usage of each layer on a sample input
shape = (32, 3, 224, 224)
sample = torch.rand(*shape, device=device)
net_check.measure(sample)

# Compute the optimal checkpointing sequence for a 700 MB budget
net_check.compute_sequence(mem_limit=700*1024*1024)

# Use the checkpointed network like any other module
data = torch.rand(*shape, device=device)
data.requires_grad = True
result = net_check(data).sum()
result.backward()
grad = data.grad
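To check the memory claim on your own setup, the run can be bracketed with PyTorch's CUDA memory statistics (a sketch; the exact peak varies with GPU, driver, and PyTorch version):
torch.cuda.reset_peak_memory_stats(device)
result = net_check(data).sum()
result.backward()
peak = torch.cuda.max_memory_allocated(device)
print(peak / (1024*1024), "MB")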
Notes¶
- As of now, for technical reasons, rotor does not use the save_for_backward() interface that detects when tensors are changed between the forward and the backward phase. Make sure not to modify the tensors.
- Just like torch/utils/checkpoint.py, Checkpointable preserves RNG states by default at each checkpoint and restores them during the backward phase.
- It is unclear to us how to portably obtain CPU resource usage information. Our code (memory.py) uses import resource so that resource usage can be measured even when not running on CUDA (see the sketch below). If portability to non-Unix platforms is important, this feature will either have to be removed or reimplemented with a more portable alternative.
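For reference, the Unix-only measurement in question looks roughly like this (a standalone sketch using the standard resource module, not rotor's internal code; note that ru_maxrss is reported in kilobytes on Linux but in bytes on macOS):
import resource

usage = resource.getrusage(resource.RUSAGE_SELF)
print("peak resident set size:", usage.ru_maxrss)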