Implement callbacks using hooks

Hook callbacks

This module provides both a standalone class and a callback for registering and automatically deregistering PyTorch hooks, along with some predefined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass.

We'll start by looking at the pre-defined hook ActivationStats, then we'll see how to create our own.

class ActivationStats

ActivationStats(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: HookCallback

Callback that records the mean and std of activations.

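ActivationStats saves the mean and standard deviation of the layer activations in self.stats for all modules passed to it. By default it tracks every module of the model that has a weight (for the model below, its three convolutional layers). For instance: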

from fastai.vision import *
from fastai.callbacks.hooks import *
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
# alternative: learn = cnn_learner(data, models.resnet18, callback_fns=ActivationStats)
learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=ActivationStats)
learn.fit(1)
epoch train_loss valid_loss time
0 0.142666 0.101166 00:03

The saved stats are a FloatTensor of shape (2, num_modules, num_batches), where num_batches counts training batches only. The first axis holds (mean, stdev).

len(learn.data.train_dl),len(learn.activation_stats.modules)
(193, 3)
learn.activation_stats.stats.shape
torch.Size([2, 3, 193])

So this shows the standard deviation (axis0==1) of the second-to-last layer (axis1==-2) for each batch (axis2):

import matplotlib.pyplot as plt
plt.plot(learn.activation_stats.stats[1][-2].numpy());
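
To make the layout of self.stats concrete, here is a minimal plain-PyTorch sketch (it does not use fastai) that imitates what ActivationStats does: forward hooks on every module with a weight, one (mean, std) pair per module per batch, and a final permute into (2, num_modules, num_batches). The toy model, the five random batches and the helper make_hook are illustrative assumptions, not fastai code.

```python
import torch
import torch.nn as nn

# Toy model (assumption): three conv layers with ReLUs in between
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 2, 3, stride=2, padding=1),
)

# Default-style module selection: only modules that have a weight (the three convs)
modules = [m for m in model.modules() if hasattr(m, "weight")]
current = [None] * len(modules)  # latest (mean, std) for each hooked module

def make_hook(idx):
    def hook(m, inp, out):
        out = out.detach()
        current[idx] = (out.mean().item(), out.std().item())
    return hook

handles = [m.register_forward_hook(make_hook(i)) for i, m in enumerate(modules)]

stats = []
for _ in range(5):                  # five fake training batches
    model(torch.randn(4, 3, 16, 16))
    stats.append(list(current))     # one row of (mean, std) pairs per batch

for h in handles:                   # deregister, as with do_remove=True
    h.remove()

# (num_batches, num_modules, 2) -> (2, num_modules, num_batches)
stats = torch.tensor(stats).permute(2, 1, 0)
print(stats.shape)  # torch.Size([2, 3, 5])
```

Indexing works the same way as in the plot above: stats[1][-2] would be the standard deviation of the second-to-last hooked layer across the five batches.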

Internal implementation

hook

hook(m:Module, i:Tensors, o:Tensors) → Tuple[Rank0Tensor, Rank0Tensor]

Some other tests where hook is used:

  • pytest -sv tests/test_callbacks_hooks.py::test_hook_output_basics

Take the mean and std of o.

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.

on_train_begin

on_train_begin(**kwargs)

Initialize stats.

on_batch_end

on_batch_end(train, **kwargs)

Take the stored results and put them in self.stats.

on_train_end

on_train_end(**kwargs)

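Polish the final result: turn the collected per-batch statistics into the (2, num_modules, num_batches) tensor stored in self.stats.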

class Hook

Hook(m:Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)

Create a hook on m with hook_func.

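Registers a PyTorch hook when created; you deregister it manually by calling remove(). Your hook_func will be called automatically whenever the forward pass (or, if is_forward is False, the backward pass) of your module m runs, and the result of that function is placed in self.stored. If detach is True, the inputs and outputs are detached before being passed to hook_func.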

remove

remove()

Remove the hook from the model.

Deregister the hook, if not called already.
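
The following is a minimal sketch of the register-on-creation, remove-on-demand pattern described above, written directly against PyTorch's register_forward_hook and register_full_backward_hook. SimpleHook and _detach are hypothetical names; this is not fastai's Hook class, and the backward registration call may differ from the one fastai uses. The __enter__/__exit__ pair lets the hook be removed automatically at the end of a with block.

```python
import torch
import torch.nn as nn

def _detach(x):
    # Detach a tensor, or each tensor in a tuple (backward hooks can pass None entries)
    if isinstance(x, tuple):
        return tuple(t.detach() if t is not None else None for t in x)
    return x.detach()

class SimpleHook:
    "Hypothetical minimal analogue of Hook: registered on creation, removed on demand."
    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        register = m.register_forward_hook if is_forward else m.register_full_backward_hook
        self.handle = register(self._hook_fn)
        self.removed = False

    def _hook_fn(self, module, inp, out):
        # Called by PyTorch; whatever hook_func returns is kept in self.stored
        if self.detach:
            inp, out = _detach(inp), _detach(out)
        self.stored = self.hook_func(module, inp, out)

    def remove(self):
        # Deregister once; later calls do nothing
        if not self.removed:
            self.handle.remove()
            self.removed = True

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()  # leaving a with block removes the hook

lin = nn.Linear(4, 2)
with SimpleHook(lin, lambda m, i, o: o.shape) as h:
    lin(torch.randn(3, 4))
print(h.stored)  # torch.Size([3, 2])
```

Because _hook_fn returns None, the hook only observes values; it never replaces the module's output or gradients.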

class Hooks

Hooks(ms:ModuleList, hook_func:HookFunc, is_forward:bool=True, detach:bool=True)

Create several hooks on the modules in ms with hook_func.

Acts as a Collection (i.e. len(hooks) and hooks[i]) and an Iterator (i.e. for hook in hooks) of a group of hooks, one for each module in ms, with the ability to remove all as a group. Use stored to get all hook results. hook_func and is_forward behavior is the same as Hook. See the source code for HookCallback for a simple example.

remove

remove()

Remove the hooks from the model.

Deregister all hooks created by this class, if not previously called.
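
Along the same lines, a group of forward hooks that share one hook_func and are removed together can be sketched as follows. SimpleHooks is a hypothetical name, it only handles modules whose output is a single tensor, and unlike fastai's Hooks, iterating over it yields PyTorch's hook handles rather than Hook objects; the per-module results live in the stored list.

```python
import torch
import torch.nn as nn

class SimpleHooks:
    "Hypothetical analogue of Hooks: one forward hook per module, removable as a group."
    def __init__(self, modules, hook_func):
        self.modules = list(modules)
        self.stored = [None] * len(self.modules)  # one result slot per module
        self.handles = [m.register_forward_hook(self._make_fn(i, hook_func))
                        for i, m in enumerate(self.modules)]

    def _make_fn(self, idx, hook_func):
        def fn(module, inp, out):
            self.stored[idx] = hook_func(module, inp, out.detach())
        return fn

    # Collection / iterator behaviour over the registered handles
    def __len__(self):
        return len(self.handles)

    def __getitem__(self, i):
        return self.handles[i]

    def __iter__(self):
        return iter(self.handles)

    def remove(self):
        for h in self.handles:
            h.remove()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
with SimpleHooks(model, lambda m, i, o: o.shape) as hooks:
    model(torch.randn(2, 8))
print(len(hooks), hooks.stored)
# 3 [torch.Size([2, 16]), torch.Size([2, 16]), torch.Size([2, 4])]
```

Iterating over an nn.Sequential yields its child modules, which is why passing the model itself hooks each of its three layers.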

Convenience functions for hooks

hook_output

hook_output(module:Module, detach:bool=True, grad:bool=False) → Hook

Tests found for hook_output:

  • pytest -sv tests/test_callbacks_hooks.py::test_hook_output_basics

Return a Hook that stores activations of module in self.stored

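Function that creates a Hook for module that simply stores the output of the layer. With grad=True, it registers a backward hook instead, so self.stored holds the gradients with respect to the layer's output.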
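
To illustrate what the grad flag switches between, here is a plain-PyTorch sketch with two hooks on one layer: a forward hook that keeps the layer's output (the grad=False case) and a backward hook that keeps the gradient with respect to that output (the grad=True case). The dict named stored and the two hook functions are assumptions for illustration; this is not how hook_output is implemented.

```python
import torch
import torch.nn as nn

layer = nn.Linear(5, 3)
stored = {}

def keep_output(m, inp, out):
    # grad=False analogue: keep the layer's (detached) output
    stored["act"] = out.detach()

def keep_output_grad(m, grad_input, grad_output):
    # grad=True analogue: keep the gradient with respect to the layer's output
    stored["grad"] = grad_output[0].detach()

fwd = layer.register_forward_hook(keep_output)
bwd = layer.register_full_backward_hook(keep_output_grad)

x = torch.randn(2, 5, requires_grad=True)
layer(x).sum().backward()
fwd.remove()
bwd.remove()

print(stored["act"].shape, stored["grad"].shape)  # torch.Size([2, 3]) torch.Size([2, 3])
# The loss is a plain sum, so its gradient w.r.t. every output element is 1
print(torch.equal(stored["grad"], torch.ones(2, 3)))  # True
```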

hook_outputs

hook_outputs(modules:ModuleList, detach:bool=True, grad:bool=False) → Hooks

Return Hooks that store activations of all modules in self.stored

Function that creates a Hook for each of the passed modules, each one storing the output of its layer. For example, the (slightly simplified) source code of model_sizes is:

def model_sizes(m, size=(64,64)):
    # register the hooks *before* running the dummy forward pass
    with hook_outputs(m) as hooks:
        x = dummy_eval(m, size)
        return [o.stored.shape for o in hooks]

model_sizes

model_sizes(m:Module, size:tuple=(64, 64)) → Tuple[Sizes, Tensor, Hooks]

Pass a dummy input through the model m to get the various sizes of activations.
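
The same idea as a plain-PyTorch sketch of what model_sizes does: register a forward hook on each child module, run one zero-filled input in evaluation mode, collect the output shapes in execution order, and always remove the hooks afterwards. sketch_model_sizes is a hypothetical helper, and its in_ch argument stands in (as an assumption) for fastai's automatic detection of the input channel count.

```python
import torch
import torch.nn as nn

def sketch_model_sizes(model, size=(64, 64), in_ch=3):
    "Hypothetical helper: output shape of each child of `model` for one dummy input."
    sizes = []
    handles = [child.register_forward_hook(lambda mod, inp, out: sizes.append(out.shape))
               for child in model.children()]
    try:
        model.eval()                     # run the dummy pass in evaluation mode
        with torch.no_grad():
            model(torch.zeros(1, in_ch, *size))
    finally:
        for h in handles:                # remove the hooks even if the forward pass fails
            h.remove()
    return sizes

net = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU(),
                    nn.MaxPool2d(2), nn.Conv2d(8, 16, 3, padding=1))
for s in sketch_model_sizes(net):
    print(s)
```

For this toy network the printed shapes are [1, 8, 32, 32] (twice, for the convolution and the ReLU), then [1, 8, 16, 16] and [1, 16, 16, 16].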

model_summary

model_summary(m:Learner, n:int=70)

Tests found for model_summary:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision

Print a summary of m using an output text width of n chars.

This method only works on a Learner object that has a train_ds. If the Learner was created by load_learner, there is no data to run through the model, so such a summary cannot be created.

A sample summary looks like:

======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [64, 176, 176]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 176, 176]       128        True      
______________________________________________________________________
ReLU                 [64, 176, 176]       0          False     
______________________________________________________________________
MaxPool2d            [64, 88, 88]         0          False     
______________________________________________________________________
Conv2d               [64, 88, 88]         36,864     False     
...

Column definition:

  1. Layer (type) is the name of the corresponding nn.Module.

  2. Output Shape is the shape of the output of the corresponding layer (minus the batch dimension, which is always the same and has no impact on the model params).

  3. Param # is the number of weights (plus biases, if any) in the layer, so it varies from layer to layer.

    The number of params is calculated differently for each layer type. Here is how it's calculated for some of the most common layer types:

    • Conv: kernel_size*kernel_size*ch_in*ch_out (plus ch_out if bias=True)
    • Linear: (n_in+bias) * n_out
    • Batchnorm: 2 * n_out
    • Embeddings: n_embed * emb_sz

  4. Trainable indicates whether a layer is trainable or not.

    • Layers with 0 parameters are always shown as False (e.g., ReLU and MaxPool2d).
    • Other layers may or may not be trainable, usually depending on whether they are frozen. See Discriminative layer training.
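
The formulas above can be checked against PyTorch by counting parameters with numel(). The helper functions are hypothetical, and the last lines approximate the Trainable column as "any parameter requires grad", which is an assumption about how the column is derived.

```python
import torch.nn as nn

def n_params(m):
    # Total number of parameter elements (weights and biases)
    return sum(p.numel() for p in m.parameters())

# Formulas from the list above (k = kernel size; a bias adds one weight per output)
def conv_params(k, ch_in, ch_out, bias):
    return k * k * ch_in * ch_out + (ch_out if bias else 0)

def linear_params(n_in, n_out, bias):
    return (n_in + int(bias)) * n_out

def batchnorm_params(n_out):
    return 2 * n_out  # weight and bias; running stats are buffers, not parameters

def embedding_params(n_embed, emb_sz):
    return n_embed * emb_sz

checks = [
    (nn.Conv2d(3, 64, kernel_size=7, bias=False), conv_params(7, 3, 64, False)),
    (nn.Conv2d(16, 32, kernel_size=3, bias=True), conv_params(3, 16, 32, True)),
    (nn.Linear(512, 37), linear_params(512, 37, True)),
    (nn.BatchNorm2d(64), batchnorm_params(64)),
    (nn.Embedding(1000, 50), embedding_params(1000, 50)),
]
for layer, expected in checks:
    print(type(layer).__name__, n_params(layer), n_params(layer) == expected)

# Trainable (assumed rule): any parameter requires grad; ReLU has no parameters at all
relu = nn.ReLU()
print(any(p.requires_grad for p in relu.parameters()))  # False
```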

To better understand this summary, it helps to also print learn.model and correlate the two outputs.

Example:

Let's feed a Learner a dataset of 3-channel images of size 352x352 and look at the model and its summary:

print(data.train_ds[0][0].data.shape)
learn = cnn_learner(data, models.resnet34, ...)
print(learn.model)
print(learn.summary())

Here are the outputs, with everything except the lines relevant to this example removed:

torch.Size([3, 352, 352])

    [...]
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    [...]
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    [...]
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): Linear(in_features=512, out_features=37, bias=True)


======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [64, 176, 176]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 176, 176]       128        True      
______________________________________________________________________
[...]
MaxPool2d            [64, 88, 88]         0          False    
______________________________________________________________________
Conv2d               [64, 88, 88]         36,864     False     
[...]
______________________________________________________________________
Linear               [37]                 18,981     True

So let's calculate some params:

For the Conv2d layers, multiply the first 4 numbers from the corresponding layer definition:

Conv2d(3, 64, kernel_size=(7, 7), ...)

3*64*7*7 = 9,408

Conv2d(64, 64, kernel_size=(3, 3), ...)

64*64*3*3 = 36,864

For the BatchNorm2d layer, multiply the first number by 2:

BatchNorm2d(64, ...)
64*2 = 128

For Linear we multiply the first 2 and include the bias if it's True:

Linear(in_features=512, out_features=37, bias=True)

(512+1)*37 = 18,981

Now let's calculate some output shapes:

We started with a 3x352x352 image and ran it through this layer:

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

How did we get: [64, 176, 176]

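The number of output channels is 64; that's the first dimension of the shape above. The 352x352 image was then convolved down to 176x176 because of the 2x2 stride: with a 7x7 kernel and padding 3, each side becomes floor((352 + 2*3 - 7)/2) + 1 = 176, i.e. the input size halved (352/2).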

Then we had:

MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

which reduced [64, 176, 176] to [64, 88, 88], again because of the stride of 2.

And so on, finishing with:

Linear(in_features=512, out_features=37, bias=True)

which reduced everything to just [37].
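
The shape arithmetic in this walk-through follows PyTorch's documented output-size formula for Conv2d and MaxPool2d (with ceil_mode=False): out = floor((n + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. A small sketch, with out_size as a hypothetical helper name, reproduces the steps from 352 to 176 to 88:

```python
def out_size(n, kernel, stride=1, padding=0, dilation=1):
    "Output length along one spatial dimension for Conv2d / MaxPool2d (ceil_mode=False)."
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

after_conv = out_size(352, kernel=7, stride=2, padding=3)        # Conv2d(3, 64, 7x7, stride 2, padding 3)
after_pool = out_size(after_conv, kernel=3, stride=2, padding=1)  # MaxPool2d(3, stride 2, padding 1)
print(after_conv, after_pool)  # 176 88
```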

num_features_model

num_features_model(m:Module) → int

Return the number of output features for model.

It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, these sizes depend on the size of the input. This function therefore passes a small dummy input through the model (via model_sizes) and returns the number of channels of the last layer's output.

dummy_batch

dummy_batch(m:Module, size:tuple=(64, 64)) → Tensor

Create a dummy batch to go through m with size.

dummy_eval

dummy_eval(m:Module, size:tuple=(64, 64))

Pass a dummy_batch in evaluation mode in m with size.
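
For illustration, dummy_batch and dummy_eval can be approximated in plain PyTorch as below. The names sketch_dummy_batch, sketch_dummy_eval and first_in_channels are hypothetical, the input channel count is taken from the first Conv2d (fastai has its own in_channels helper), and wrapping the forward pass in torch.no_grad() is a choice made here rather than something documented above.

```python
import torch
import torch.nn as nn

def first_in_channels(model):
    "Hypothetical helper: input channels of the first Conv2d in `model`."
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            return m.in_channels
    raise ValueError("no Conv2d layer found")

def sketch_dummy_batch(model, size=(64, 64)):
    # One random image with the model's dtype/device, values drawn from [-1, 1)
    p = next(model.parameters())
    return torch.empty(1, first_in_channels(model), *size,
                       dtype=p.dtype, device=p.device).uniform_(-1., 1.)

def sketch_dummy_eval(model, size=(64, 64)):
    model.eval()  # e.g. BatchNorm uses its running stats, dropout is disabled
    with torch.no_grad():
        return model(sketch_dummy_batch(model, size))

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten())
out = sketch_dummy_eval(net)
print(out.shape)  # torch.Size([1, 8])
```

The channel dimension of the final output, out.shape[1] (8 here), is roughly the quantity num_features_model reports for such a model.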

class HookCallback

HookCallback(learn:Learner, modules:Sequence[Module]=None, do_remove:bool=True) :: LearnerCallback

Callback that can be used to register hooks on modules. Implement the corresponding function in self.hook.

For each module in modules (by default, every module of the model that has a weight), this callback automatically registers a method self.hook, which you must define in a subclass, as a hook. This method must have the signature:

def hook(self, m:nn.Module, input:Tensors, output:Tensors)

If do_remove is True, the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
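
A minimal plain-PyTorch sketch of the HookCallback pattern: the base class registers self.hook on the chosen modules at the start of training and removes the hooks at the end, and a subclass only supplies hook. HookCallbackSketch and OutputMeans are hypothetical names, the training loop is replaced by manual calls, and this is not fastai's class (which receives a Learner and is driven by fastai's Callback system).

```python
import torch
import torch.nn as nn

class HookCallbackSketch:
    "Hypothetical stand-in for the HookCallback life cycle (not fastai's class)."
    def __init__(self, model, modules=None, do_remove=True):
        self.model, self.modules, self.do_remove = model, modules, do_remove

    def on_train_begin(self):
        if not self.modules:  # default: every module that has a weight
            self.modules = [m for m in self.model.modules() if hasattr(m, "weight")]
        self.stored = [None] * len(self.modules)
        self.handles = [m.register_forward_hook(self._wrap(i))
                        for i, m in enumerate(self.modules)]

    def _wrap(self, idx):
        # Route PyTorch's hook call to the subclass's hook method
        def fn(m, inp, out):
            self.stored[idx] = self.hook(m, inp, out.detach())
        return fn

    def on_train_end(self):
        if self.do_remove:
            for h in self.handles:
                h.remove()

    def hook(self, m, inp, out):
        raise NotImplementedError("subclasses must define hook")

class OutputMeans(HookCallbackSketch):
    "Example subclass: record the mean activation of each hooked module."
    def hook(self, m, inp, out):
        return out.mean().item()

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
cb = OutputMeans(model)
cb.on_train_begin()          # a training loop would normally make these calls
model(torch.randn(16, 4))    # one "batch"
cb.on_train_end()
print(len(cb.modules), len(cb.stored))  # 2 2
```

Only the two Linear layers are hooked here, because ReLU and the Sequential container have no weight attribute.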

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.

on_train_begin

on_train_begin(**kwargs)

Register the Hooks on self.modules.

on_train_end

on_train_end(**kwargs)

Remove the Hooks.