Learner class and training loop

Basic training functionality

The basic_train module wraps the data (in a DataBunch object) together with a PyTorch model to define a Learner object. The basic training loop is defined here, in the fit method. The Learner object is the entry point for most of the Callback objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the train module, notably:

  • Learner.lr_find will launch an LR range test that will help you select a good learning rate.
  • Learner.fit_one_cycle will launch a training using the 1cycle policy to help you train your model faster.
  • Learner.to_fp16 will convert your model to half precision and help you launch a training in mixed precision.

class Learner[source][test]

Learner(data:DataBunch, model:Module, opt_func:Callable='Adam', loss_func:Callable=None, metrics:Collection[Callable]=None, true_wd:bool=True, bn_wd:bool=True, wd:Floats=0.01, train_bn:bool=True, path:str=None, model_dir:PathOrStr='models', callback_fns:Collection[Callable]=None, callbacks:Collection[Callback]=<factory>, layer_groups:ModuleList=None, add_time:bool=True, silent:bool=None)

Tests found for Learner:

Some other tests where Learner is used:

  • pytest -sv tests/test_basic_train.py::test_memory [source]

To run tests please refer to this guide.

Trainer for model using data to minimize loss_func with optimizer opt_func.

The main purpose of Learner is to train model using Learner.fit. After every epoch, all metrics will be printed and also made available to callbacks.

The default weight decay will be wd, which will be handled using the method from Fixing Weight Decay Regularization in Adam if true_wd is set (otherwise it's L2 regularization). If true_wd is set it will affect all optimizers, not only Adam. If bn_wd is False, then weight decay will be removed from batchnorm layers, as recommended in Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. If train_bn, batchnorm layer learnable params are trained even for frozen layer groups.
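The difference between the two weight decay modes can be sketched on a single scalar weight. This is a minimal illustration, not fastai's implementation: the helper adam_step, its arguments, and the simplified first-step Adam update are all assumptions made for the example.

```python
import math

def adam_step(w, grad, lr=0.1, wd=0.01, true_wd=True, eps=1e-8):
    """One simplified Adam step on a scalar weight (first step, bias-corrected).

    Hypothetical helper for illustration only.
    """
    if not true_wd:
        # L2 regularization: the penalty is folded into the gradient,
        # so it is rescaled by Adam's adaptive denominator.
        grad = grad + wd * w
    # On the very first step, the bias-corrected moments are grad and grad**2.
    m_hat, v_hat = grad, grad * grad
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    if true_wd:
        # Decoupled weight decay: shrink the weight directly,
        # independently of the adaptive gradient scaling.
        w = w - lr * wd * w
    return w - update

w_decoupled = adam_step(1.0, grad=0.5, true_wd=True)
w_l2 = adam_step(1.0, grad=0.5, true_wd=False)
print(w_decoupled, w_l2)
```

In this toy step the L2 penalty is almost entirely normalized away by the adaptive denominator, while the decoupled version shrinks the weight by lr * wd * w regardless of the gradient scale.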

To use discriminative layer training, pass a list of nn.Module as layer_groups; each nn.Module will be used to customize the optimization of the corresponding layer group.

If path is provided, all the model files created will be saved in path/model_dir; if not, then they will be saved in data.path/model_dir.

You can pass a list of callbacks that you have already created, or (more commonly) simply pass a list of callback functions to callback_fns; each function will be called (passing self) on object initialization, and the results are stored as callback objects. For a walk-through, see the training overview page. You may also want to use an application-specific model. For example, if you are dealing with a vision dataset (here, MNIST), you might want to use the cnn_learner method:

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)

Model fitting methods

lr_find[source][test]

lr_find(learn:Learner, start_lr:Floats=1e-07, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None)

Tests found for lr_find:

  • pytest -sv tests/test_train.py::test_lr_find [source]
  • pytest -sv tests/test_vision_train.py::test_lrfind [source]

To run tests please refer to this guide.

Explore lr from start_lr to end_lr over num_it iterations in learn. If stop_div, stops when loss diverges.

Runs the learning rate finder defined in LRFinder, as discussed in Cyclical Learning Rates for Training Neural Networks.

learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
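The idea behind the range test can be sketched in plain Python. The exponential spacing of the learning rates and the "loss exceeds 4 times the best loss" divergence rule are assumptions made for this sketch, and both helper names are hypothetical; check LRFinder for the actual behaviour.

```python
def lr_range_test_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Exponentially spaced learning rates from start_lr to end_lr (inclusive)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

def run_range_test(losses, lrs, stop_div=True, div_factor=4.0):
    """Walk through (lr, loss) pairs; stop once the loss exceeds div_factor * best."""
    best, history = float("inf"), []
    for lr, loss in zip(lrs, losses):
        best = min(best, loss)
        history.append((lr, loss))
        if stop_div and loss > div_factor * best:
            break  # assumed divergence criterion
    return history

lrs = lr_range_test_schedule(num_it=5)
# Made-up losses: improving at first, then diverging at the highest rates.
history = run_range_test([2.0, 1.5, 1.0, 5.0, 50.0], lrs)
print(len(lrs), len(history))
```

With these made-up losses the test stops at the fourth iteration, so the last learning rate is never tried.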

fit[source][test]

fit(epochs:int, lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), wd:Floats=None, callbacks:Collection[Callback]=None)

Tests found for fit:

  • pytest -sv tests/test_train.py::test_fit [source]

Some other tests where fit is used:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]

To run tests please refer to this guide.

Fit the model on this learner with lr learning rate, wd weight decay for epochs with callbacks.

Uses discriminative layer training if multiple learning rates or weight decay values are passed. To control training behaviour, use the callback system or one or more of the pre-defined callbacks.

learn.fit(1)
epoch train_loss valid_loss accuracy time
1 0.135343 0.083190 0.972031 00:05

fit_one_cycle[source][test]

fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float, Collection[float], slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, final_div:float=None, wd:float=None, callbacks:Optional[Collection[Callback]]=None, tot_epochs:int=None, start_epoch:int=None)

Tests found for fit_one_cycle:

  • pytest -sv tests/test_train.py::test_fit_one_cycle [source]

Some other tests where fit_one_cycle is used:

  • pytest -sv tests/test_tabular_train.py::test_empty_cont [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split [source]

To run tests please refer to this guide.

Fit a model following the 1cycle policy.

Use cycle length cyc_len, a per cycle maximal learning rate max_lr, momentum moms, division factor div_factor, weight decay wd, and optional callbacks callbacks. Uses the OneCycleScheduler callback. Please refer to What is 1-cycle for conceptual background on the 1-cycle training policy and more technical details on what the method's arguments do.

learn.fit_one_cycle(1)
epoch train_loss valid_loss accuracy time
1 0.075838 0.061869 0.979882 00:05
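To make the roles of max_lr, moms, div_factor and pct_start more concrete, here is a rough sketch of a 1cycle schedule. It assumes cosine annealing in both phases and a default final_div of div_factor * 1e4; those details, and the helper names, are assumptions for illustration rather than a copy of OneCycleScheduler.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(n_iter, max_lr=3e-3, moms=(0.95, 0.85), div_factor=25.0,
              pct_start=0.3, final_div=None):
    """Return a list of (lr, momentum) pairs, one per iteration."""
    if final_div is None:
        final_div = div_factor * 1e4  # assumed default
    warm = int(n_iter * pct_start)    # iterations spent increasing the LR
    cool = n_iter - warm              # iterations spent decreasing it
    schedule = []
    for i in range(n_iter):
        if i < warm:
            pct = i / warm
            lr = cos_anneal(max_lr / div_factor, max_lr, pct)
            mom = cos_anneal(moms[0], moms[1], pct)
        else:
            pct = (i - warm) / cool
            lr = cos_anneal(max_lr, max_lr / final_div, pct)
            mom = cos_anneal(moms[1], moms[0], pct)
        schedule.append((lr, mom))
    return schedule

sched = one_cycle(100)
first_lr, first_mom = sched[0]
peak_lr = max(lr for lr, _ in sched)
print(first_lr, first_mom, peak_lr)
```

The learning rate starts at max_lr / div_factor, peaks at max_lr after roughly pct_start of the iterations, and then decays far below its starting value, while the momentum moves in the opposite direction.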

See results

predict[source][test]

predict(item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs)

Tests found for predict:

  • pytest -sv tests/test_vision_train.py::test_models_meta [source]
  • pytest -sv tests/test_vision_train.py::test_preds [source]

To run tests please refer to this guide.

Return predicted class, label and probabilities for item.

predict can be used to get a single prediction from the trained learner on one specific piece of data you are interested in.

learn.data.train_ds[0]
(Image (3, 28, 28), <fastai.core.Category at 0x7fb0e0dee1d0>)

Each element of the dataset is a tuple, where the first element is the data itself, while the second element is the target label. So to get the data, we need to index one more time.

data = learn.data.train_ds[0][0]
data
pred = learn.predict(data)
pred
(<fastai.core.Category at 0x7fb0e02f29b0>, tensor(0), tensor([0.5748, 0.4252]))

The first two elements of the tuple are, respectively, the predicted class and its label. The label is essentially an internal representation of the class, since a class name is a string and cannot be used in computation. To check what each label corresponds to, run:

learn.data.classes
['3', '7']

So category 0 is 3 while category 1 is 7.

probs = pred[2]

The last element in the tuple is the predicted probabilities. For a categorization dataset, the number of probabilities returned is the same as the number of classes; probs[i] is the probability that the item belongs to learn.data.classes[i].
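The relationship between the three elements can be reproduced by hand. The sketch below uses plain Python lists in place of tensors, with the values copied from the outputs above.

```python
classes = ['3', '7']          # learn.data.classes in the example above
probs = [0.5748, 0.4252]      # pred[2], copied from the output above

# The label is the index of the largest probability.
label = max(range(len(probs)), key=probs.__getitem__)
predicted_class = classes[label]
print(label, predicted_class)
```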

learn.data.valid_ds[0][0]

You can always check for yourself whether the predicted probabilities make sense.

get_preds[source][test]

get_preds(ds_type:DatasetType=<DatasetType.Valid: 2>, activ:Module=None, with_loss:bool=False, n_batch:Optional[int]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None) → List[Tensor]

Tests found for get_preds:

  • pytest -sv tests/test_basic_train.py::test_get_preds [source]

To run tests please refer to this guide.

Return predictions and targets on ds_type dataset.

It runs inference with the learner on the data in the ds_type dataset and returns the predictions. If n_batch is specified, only that many batches are processed; otherwise the whole dataset is processed, one batch of the default batch size at a time. If with_loss, it also returns the loss on each prediction.

Here is how you check the default batch size.

learn.data.batch_size
64
preds = learn.get_preds()
preds
[tensor([[9.9925e-01, 7.4895e-04],
         [9.8333e-01, 1.6672e-02],
         [9.9996e-01, 3.8919e-05],
         ...,
         [1.6180e-04, 9.9984e-01],
         [2.5164e-02, 9.7484e-01],
         [1.8179e-02, 9.8182e-01]]), tensor([0, 0, 0,  ..., 1, 1, 1])]

The first element of the tuple is a tensor that contains all the predictions.

preds[0]
tensor([[9.9925e-01, 7.4895e-04],
        [9.8333e-01, 1.6672e-02],
        [9.9996e-01, 3.8919e-05],
        ...,
        [1.6180e-04, 9.9984e-01],
        [2.5164e-02, 9.7484e-01],
        [1.8179e-02, 9.8182e-01]])

The second element of the tuple is a tensor that contains all the target labels.

preds[1]
tensor([0, 0, 0,  ..., 1, 1, 1])
preds[1][0]
tensor(0)

For more details about what each number means, refer to the documentation of predict.

Since get_preds gets predictions on all the data in the ds_type dataset, the number of predictions here is equal to the number of items in the validation dataset.

len(learn.data.valid_ds)
2038
len(preds[0]), len(preds[1])
(2038, 2038)

To get predictions on the entire training dataset, simply set the ds_type argument accordingly.

learn.get_preds(ds_type=DatasetType.Train)
[tensor([[9.9801e-01, 1.9876e-03],
         [1.7900e-06, 1.0000e+00],
         [1.3191e-03, 9.9868e-01],
         ...,
         [9.9991e-01, 8.6866e-05],
         [1.6420e-04, 9.9984e-01],
         [2.2937e-03, 9.9771e-01]]), tensor([0, 1, 1,  ..., 0, 1, 1])]

To also get prediction loss along with the predictions and the targets, set with_loss=True in the arguments.

learn.get_preds(with_loss=True)
[tensor([[9.9925e-01, 7.4895e-04],
         [9.8333e-01, 1.6672e-02],
         [9.9996e-01, 3.8919e-05],
         ...,
         [1.6180e-04, 9.9984e-01],
         [2.5164e-02, 9.7484e-01],
         [1.8179e-02, 9.8182e-01]]),
 tensor([0, 0, 0,  ..., 1, 1, 1]),
 tensor([7.4911e-04, 1.6813e-02, 3.8624e-05,  ..., 1.6165e-04, 2.5486e-02,
         1.8347e-02])]

Note that the third tensor in the output tuple contains the losses.
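As a small illustration of how the returned predictions and targets fit together, the following sketch computes an accuracy from them. It uses made-up values in plain Python lists rather than tensors, and accuracy_from_preds is a hypothetical helper, not a fastai function.

```python
# Made-up miniature of the [predictions, targets] pair returned by get_preds.
preds = [[0.99, 0.01], [0.30, 0.70], [0.60, 0.40], [0.05, 0.95]]
targets = [0, 1, 1, 1]

def accuracy_from_preds(preds, targets):
    """Fraction of rows whose most probable class matches the target."""
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == t
        for row, t in zip(preds, targets)
    )
    return hits / len(targets)

acc = accuracy_from_preds(preds, targets)
print(acc)
```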

validate[source][test]

validate(dl=None, callbacks=None, metrics=None)

Tests found for validate:

  • pytest -sv tests/test_collab_train.py::test_val_loss [source]
  • pytest -sv tests/test_text_train.py::test_val_loss [source]

To run tests please refer to this guide.

Validate on dl with potential callbacks and metrics.

Return the calculated loss and the metrics of the current model on the given data loader dl. The default data loader dl is the validation dataloader.

You can check the default metrics of the learner using:

str(learn.metrics)
'[<function accuracy at 0x7fb0e1880d08>]'
learn.validate()
[0.061868817, tensor(0.9799)]
learn.validate(learn.data.valid_dl)
[0.061868817, tensor(0.9799)]
learn.validate(learn.data.train_dl)
[0.036164965, tensor(0.9887)]

show_results[source][test]

show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)

No tests found for show_results. To contribute a test please refer to this guide and this discussion.

Show rows result of predictions on ds_type dataset.

Note that the number shown as text at the top is the ground truth (the target label), the one in the middle is the prediction, and the image of a number at the bottom is the image data itself.

learn.show_results()
learn.show_results(ds_type=DatasetType.Train)

pred_batch[source][test]

pred_batch(ds_type:DatasetType=<DatasetType.Valid: 2>, batch:Tuple=None, reconstruct:bool=False, with_dropout:bool=False, activ:Module=None) → List[Tensor]

No tests found for pred_batch. To contribute a test please refer to this guide and this discussion.

Return output of the model on one batch from ds_type dataset.

Note that the number of predictions returned equals the batch size.

learn.data.batch_size
64
preds = learn.pred_batch()
len(preds)
64

Since the total number of predictions is too large, we will only look at a part of them.

preds[:10]
tensor([[9.9925e-01, 7.4895e-04],
        [9.8333e-01, 1.6672e-02],
        [9.9996e-01, 3.8919e-05],
        [9.9998e-01, 1.7812e-05],
        [9.9993e-01, 6.8040e-05],
        [9.9533e-01, 4.6744e-03],
        [9.9838e-01, 1.6157e-03],
        [1.0000e+00, 1.4298e-06],
        [9.9942e-01, 5.8188e-04],
        [9.9999e-01, 1.2754e-05]])
item = learn.data.train_ds[0][0]
item
batch = learn.data.one_item(item)
batch
(tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0'),
 tensor([0], device='cuda:0'))
learn.pred_batch(batch=batch)
tensor([[0.5748, 0.4252]])

interpret[source][test]

interpret(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, tta=False)

Tests found for _learner_interpret:

  • pytest -sv tests/test_vision_train.py::test_interp_shortcut [source]

To run tests please refer to this guide.

Create a ClassificationInterpretation object from learner on ds_type with tta.

For more details, refer to ClassificationInterpretation

Model summary

model_summary[source][test]

model_summary(m:Learner, n:int=70)

Tests found for model_summary:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision [source]

To run tests please refer to this guide.

Print a summary of m using an output text width of n chars.

Test time augmentation

TTA[source][test]

TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=<DatasetType.Valid: 2>, activ:Module=None, with_loss:bool=False) → Tensors

No tests found for _TTA. To contribute a test please refer to this guide and this discussion.

Applies TTA to predict on ds_type dataset.

Applies Test Time Augmentation to learn on the dataset ds_type. We take the average of our regular predictions (with a weight beta) with the average of predictions obtained through augmented versions of the training set (with a weight 1-beta). The transforms decided for the training set are applied with a few changes: scale controls the scale for zoom (which isn't random), and the cropping isn't random either; instead, we make sure to get the four corners of the image. Flipping isn't random but is applied once to each of those corner images (which makes 8 augmented versions in total).
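The weighting by beta can be written out for a single item. This is a minimal sketch with made-up probabilities and a hypothetical helper name (tta_combine); it only shows the blending step, not how the augmented predictions are produced.

```python
def tta_combine(regular, augmented, beta=0.4):
    """Blend regular predictions with the mean of the augmented predictions.

    regular:   list of per-class probabilities for one item
    augmented: list of such lists, one per augmented version of the item
    """
    n_aug = len(augmented)
    avg_aug = [sum(col) / n_aug for col in zip(*augmented)]
    return [beta * r + (1 - beta) * a for r, a in zip(regular, avg_aug)]

regular = [0.9, 0.1]
# Made-up values; only two augmented versions are used here to keep it short.
augmented = [[0.7, 0.3], [0.5, 0.5]]
blended = tta_combine(regular, augmented)
print(blended)
```

With beta=0.4 the augmented average carries more weight than the regular prediction, so the blended result moves most of the way towards it.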

Gradient clipping

clip_grad[source][test]

clip_grad(learn:Learner, clip:float=0.1) → Learner

No tests found for clip_grad. To contribute a test please refer to this guide and this discussion.

Add gradient clipping of clip during training.
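For readers unfamiliar with the technique, here is a minimal sketch of gradient clipping. It assumes the gradients are clipped by their overall L2 norm; that assumption, and the helper name clip_by_norm, are not taken from this page.

```python
import math

def clip_by_norm(grads, clip=0.1):
    """Rescale grads so that their overall L2 norm is at most clip."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip:
        return list(grads)          # already small enough, left unchanged
    scale = clip / total_norm
    return [g * scale for g in grads]

clipped = clip_by_norm([3.0, 4.0], clip=0.1)   # norm is 5.0 before clipping
print(clipped)
```

The direction of the gradient is preserved; only its length is capped.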

Mixed precision training

to_fp16[source][test]

to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None, flat_master:bool=False, max_scale:float=16777216, loss_fp32:bool=True) → Learner

No tests found for to_fp16. To contribute a test please refer to this guide and this discussion.

Put learn in FP16 precision mode.

Uses the MixedPrecision callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), using all NVIDIA recommendations for ensuring speed and accuracy.
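One reason a loss scale is needed in fp16 can be shown with the standard library alone. The sketch below rounds values to half precision with struct; the gradient value and the scale factor are made up for illustration and say nothing about the defaults used by the callback.

```python
import struct

def to_half(x):
    """Round a Python float to the nearest fp16 value (returned as a float)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

grad = 1e-8            # a gradient too small to represent in fp16
loss_scale = 65536.0   # hypothetical scale factor

unscaled = to_half(grad)                              # underflows to zero
recovered = to_half(grad * loss_scale) / loss_scale   # scale in fp16, unscale afterwards
print(unscaled, recovered)
```

Scaling the loss (and therefore the gradients) before the fp16 backward pass keeps small gradients representable; dividing by the same factor in fp32 recovers their magnitude.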

to_fp32[source][test]

to_fp32(learn:Learner)

No tests found for to_fp32. To contribute a test please refer to this guide and this discussion.

Put learn back to FP32 precision mode.

Distributed training

If you want to use distributed training or torch.nn.DataParallel, these methods will directly wrap the model for you.

to_distributed[source][test]

to_distributed(learn:Learner, cuda_id:int, cache_dir:PathOrStr='tmp')

No tests found for _learner_distributed. To contribute a test please refer to this guide and this discussion.

Put learn on distributed training with cuda_id.

to_parallel[source][test]

to_parallel(learn:Learner)

No tests found for _learner_parallel. To contribute a test please refer to this guide and this discussion.

Use nn.DataParallel when training and remove when done

Discriminative layer training

When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each layer group (i.e. the parameters of each module in self.layer_groups). See the Universal Language Model Fine-tuning for Text Classification paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a Learner on which you've called split, you can set hyperparameters in four ways:

  1. param = [val1, val2 ..., valn] (n = number of layer groups)
  2. param = val
  3. param = slice(start,end)
  4. param = slice(end)

If we choose to set it in way 1, we must specify a number of values exactly equal to the number of layer groups. If we choose to set it in way 2, the chosen value will be repeated for all layer groups. See Learner.lr_range for an explanation of the slice syntax.

Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call Learner.split in this case, since fastai uses this exact function as the default split for resnet18; this is just to show how to customize it):

# creates 3 layer groups
learn.split(lambda m: (m[0][6], m[1]))
# only randomly initialized head now trainable
learn.freeze()
learn.fit_one_cycle(1)
epoch train_loss valid_loss accuracy time
1 0.059613 0.054604 0.981845 00:05
# all layers now trainable
learn.unfreeze()
# optionally, separate LR and WD for each group
learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))
epoch train_loss valid_loss accuracy time
1 0.026379 0.008763 0.998037 00:07

lr_range[source][test]

lr_range(lr:Union[float, slice]) → ndarray

No tests found for lr_range. To contribute a test please refer to this guide and this discussion.

Build differential learning rates from lr.

Rather than manually setting an LR for every group, it's often easier to use Learner.lr_range. This is a convenience method that returns one learning rate for each layer group. If you pass slice(start,end) then the first group's learning rate is start, the last is end, and the remaining are evenly geometrically spaced.

If you pass just slice(end) then the last group's learning rate is end, and all the other groups are end/10. For instance (for our learner that has 3 layer groups):

learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))
(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))
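The rules described above can be reproduced in a few lines of plain Python. This is a sketch of the described behaviour, not the library code: it returns lists instead of arrays and takes the number of layer groups as an explicit (hypothetical) n_groups argument.

```python
def lr_range(lr, n_groups=3):
    """Return one learning rate per layer group (plain-Python sketch)."""
    if not isinstance(lr, slice):
        return [lr] * n_groups            # a single value is repeated
    if lr.start is not None:
        # slice(start, end): geometric spacing from start to end
        # (needs at least two layer groups)
        step = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * step ** i for i in range(n_groups)]
    # slice(end): end for the last group, end/10 for all the others
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

spaced = lr_range(slice(1e-5, 1e-3))
last_only = lr_range(slice(1e-3))
print(spaced, last_only)
```

Up to floating-point rounding, the two results match the arrays printed above: 1e-5, 1e-4, 1e-3 and 1e-4, 1e-4, 1e-3.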

unfreeze[source][test]

unfreeze()

Tests found for unfreeze:

  • pytest -sv tests/test_basic_train.py::test_unfreeze [source]

To run tests please refer to this guide.

Unfreeze entire model.

Sets every layer group to trainable (i.e. requires_grad=True).

freeze[source][test]

freeze()

Tests found for freeze:

  • pytest -sv tests/test_basic_train.py::test_freeze [source]

To run tests please refer to this guide.

Freeze up to last layer group.

Sets every layer group except the last to untrainable (i.e. requires_grad=False).

What does 'the last layer group' mean?

In the case of transfer learning, such as learn = cnn_learner(data, models.resnet18, metrics=error_rate), learn.model will print out two large groups of layers, (0) Sequential and (1) Sequential, in the following structure. We can consider the last conv layer as the dividing line between the two groups.

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    ...

            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
             (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (1): Sequential(
    (0): AdaptiveConcatPool2d(
      (ap): AdaptiveAvgPool2d(output_size=1)
      (mp): AdaptiveMaxPool2d(output_size=1)
    )
    (1): Flatten()
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.25)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): ReLU(inplace)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5)
    (8): Linear(in_features=512, out_features=12, bias=True)
  )
)

learn.freeze freezes the first group and keeps the second (last) group free to train. Each group contains multiple layers (which is why it is called a 'group'), as you can see in the learn.summary() output. To learn how to read the table below, see the model summary docs.

======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
...
...
...
______________________________________________________________________
Conv2d               [1, 512, 4, 4]       2,359,296  False     
______________________________________________________________________
BatchNorm2d          [1, 512, 4, 4]       1,024      True      
______________________________________________________________________
AdaptiveAvgPool2d    [1, 512, 1, 1]       0          False     
______________________________________________________________________
AdaptiveMaxPool2d    [1, 512, 1, 1]       0          False     
______________________________________________________________________
Flatten              [1, 1024]            0          False     
______________________________________________________________________
BatchNorm1d          [1, 1024]            2,048      True      
______________________________________________________________________
Dropout              [1, 1024]            0          False     
______________________________________________________________________
Linear               [1, 512]             524,800    True      
______________________________________________________________________
ReLU                 [1, 512]             0          False     
______________________________________________________________________
BatchNorm1d          [1, 512]             1,024      True      
______________________________________________________________________
Dropout              [1, 512]             0          False     
______________________________________________________________________
Linear               [1, 12]              6,156      True      
______________________________________________________________________

Total params: 11,710,540
Total trainable params: 543,628
Total non-trainable params: 11,166,912

freeze_to[source][test]

freeze_to(n:int)

Tests found for freeze_to:

  • pytest -sv tests/test_basic_train.py::test_freeze_to [source]

To run tests please refer to this guide.

Freeze layers up to layer group n.

From the above we know what a layer group is, but what exactly does freeze_to do behind the scenes?

The freeze_to source code can be understood as the following pseudo-code:

def freeze_to(self, n:int)->None:
    for g in self.layer_groups[:n]: freeze(g)    # pseudo-code: groups 0 to n-1
    for g in self.layer_groups[n:]: unfreeze(g)  # pseudo-code: groups n to the end

In other words, freeze_to(1) freezes layer group 0 and unfreezes the remaining layer groups, and freeze_to(3) freezes layer groups 0, 1, and 2 and unfreezes the remaining layer groups (if there are any left).

Both freeze and unfreeze sources are defined using freeze_to:

  • When we say freeze, we mean that in the specified layer groups the requires_grad of all layers with weights (except BatchNorm layers) is set to False, so the layer weights won't be updated during training.
  • When we say unfreeze, we mean that in the specified layer groups the requires_grad of all layers with weights (except BatchNorm layers) is set to True, so the layer weights will be updated during training.
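The behaviour described above can be tried out on a toy model. The sketch below uses a small dataclass in place of real PyTorch layers; the Layer class, the train_bn handling, and the choice of freeze_to(-1) and freeze_to(0) for freeze and unfreeze are assumptions made for illustration.

```python
from dataclasses import dataclass

@dataclass
class Layer:
    """Toy stand-in for a layer with weights (not a PyTorch module)."""
    name: str
    is_batchnorm: bool = False
    requires_grad: bool = True

def freeze_to(layer_groups, n, train_bn=True):
    """Freeze groups [0, n) and unfreeze groups [n, end)."""
    for group in layer_groups[:n]:
        for layer in group:
            # With train_bn, batchnorm layers stay trainable even when frozen.
            if not (train_bn and layer.is_batchnorm):
                layer.requires_grad = False
    for group in layer_groups[n:]:
        for layer in group:
            layer.requires_grad = True

def freeze(layer_groups):
    freeze_to(layer_groups, -1)   # everything except the last group

def unfreeze(layer_groups):
    freeze_to(layer_groups, 0)    # nothing stays frozen

groups = [
    [Layer("conv1"), Layer("bn1", is_batchnorm=True)],
    [Layer("conv2")],
    [Layer("head")],
]
freeze(groups)
frozen_state = [(l.name, l.requires_grad) for g in groups for l in g]
unfreeze(groups)
unfrozen_state = [(l.name, l.requires_grad) for g in groups for l in g]
print(frozen_state)
print(unfrozen_state)
```

After freeze, only the batchnorm layer and the last group remain trainable; after unfreeze, every layer is trainable again.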

split[source][test]

split(split_on:SplitFuncOrIdxList)

No tests found for split. To contribute a test please refer to this guide and this discussion.

Split the model at split_on.

A convenience method that sets layer_groups based on the result of split_model. If split_on is a function, it calls that function and passes the result to split_model (see above for example).

Saving and loading models

Simply call Learner.save and Learner.load to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the path/model_dir directory.

save[source][test]

save(file:PathLikeOrBinaryStream=None, return_path:bool=False, with_opt:bool=True)

Tests found for save:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

Some other tests where save is used:

  • pytest -sv tests/test_basic_train.py::test_model_load_mem_leak [source]

To run tests please refer to this guide.

Save model and optimizer state (if with_opt) with file to self.model_dir. file can be file-like (file or buffer)

If argument file is a pathlib object that's an absolute path, it'll override the default base directory (learn.path), otherwise the model will be saved in a file relative to learn.path.

learn.save("trained_model")
learn.save("trained_model", return_path=True)
PosixPath('/home/ubuntu/.fastai/data/mnist_sample/models/trained_model.pth')
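The path rule described above follows from how pathlib joins paths. In this sketch, resolve_save_path is a hypothetical helper, the /tmp/checkpoints directory is made up, and appending the .pth suffix in both cases is an assumption based on the output shown above.

```python
from pathlib import PurePosixPath

def resolve_save_path(learn_path, model_dir, file):
    """Where a model file would land, following the rule described above."""
    # Joining onto an absolute path discards everything to its left,
    # so an absolute `file` overrides learn_path/model_dir.
    return PurePosixPath(learn_path) / model_dir / f"{file}.pth"

base = "/home/ubuntu/.fastai/data/mnist_sample"
relative = resolve_save_path(base, "models", "trained_model")
absolute = resolve_save_path(base, "models",
                             PurePosixPath("/tmp/checkpoints/trained_model"))
print(relative)
print(absolute)
```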

load[source][test]

load(file:PathLikeOrBinaryStream=None, device:device=None, strict:bool=True, with_opt:bool=None, purge:bool=False, remove_module:bool=False) → Learner

Tests found for load:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

Some other tests where load is used:

  • pytest -sv tests/test_basic_train.py::test_model_load_mem_leak [source]

To run tests please refer to this guide.

Load model and optimizer state (if with_opt) file from self.model_dir using device. file can be file-like (file or buffer)

This method only works after save (don't confuse with export/load_learner pair).

If the purge argument is True, load internally calls purge with clear_opt=False to preserve learn.opt.

learn = learn.load("trained_model")

Deploying your model

When you are ready to put your model in production, export the minimal state of your Learner with:

export[source][test]

export(file:PathLikeOrBinaryStream='export.pkl', destroy=False)

Tests found for export:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]

To run tests please refer to this guide.

Export the state of the Learner in self.path/file. file can be file-like (file or buffer)

If argument file is a pathlib object that's an absolute path, it'll override the default base directory (learn.path), otherwise the model will be saved in a file relative to learn.path.

Passing destroy=True will destroy the Learner, freeing most of its memory consumption. For specifics see Learner.destroy.

This method only works with the Learner whose data was created through the data block API.

Otherwise, you will have to create a Learner yourself at inference and load the model with Learner.load.

learn.export()
learn.export('trained_model.pkl')
path = learn.path
path
PosixPath('/home/ubuntu/.fastai/data/mnist_sample')

load_learner[source][test]

load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, tfm_y=None, **db_kwargs)

Tests found for load_learner:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]

To run tests please refer to this guide.

Load a Learner object saved with export_state in path/file with empty data, optionally add test and load on cpu. file can be file-like (file or buffer)

This function only works after export (don't confuse with save/load pair).

The db_kwargs will be passed to the call to databunch so you can specify a bs for the test set, or num_workers.

learn = load_learner(path)
learn = load_learner(path, 'trained_model.pkl')

WARNING: If you used any customized classes when creating your learner, you must define these classes before executing load_learner.

You can find more information and multiple examples in this tutorial.

Freeing memory

If you want to be able to do more without needing to restart your notebook, the following methods are designed to free memory when it's no longer needed.

Refer to this tutorial to learn how and when to use these methods.

purge[source][test]

purge(clear_opt:bool=True)

Tests found for purge:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_purge [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

To run tests please refer to this guide.

Purge the Learner of all cached attributes to release some GPU memory.

If learn.path is read-only, you can set the model_dir attribute of the Learner to a full pathlib path that is writable (by setting learn.model_dir or passing the model_dir argument to the Learner constructor).

destroy[source][test]

destroy()

Tests found for destroy:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]
  • pytest -sv tests/test_basic_train.py::test_memory [source]

To run tests please refer to this guide.

Free the Learner internals, leaving just an empty shell that consumes no memory

If you need to free the memory consumed by the Learner object, call this method.

It can also be automatically invoked through Learner.export via its destroy=True argument.

Other methods

init[source][test]

init(init)

No tests found for init. To contribute a test please refer to this guide and this discussion.

Initializes all weights (except batchnorm) using function init, which will often be from PyTorch's nn.init module.

mixup[source][test]

mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) → Learner

No tests found for mixup. To contribute a test please refer to this guide and this discussion.

Add mixup https://arxiv.org/abs/1710.09412 to learn.

backward[source][test]

backward(item)

No tests found for backward. To contribute a test please refer to this guide and this discussion.

Pass item through the model and compute the gradient. Useful if backward_hooks are attached.

create_opt[source][test]

create_opt(lr:Floats, wd:Floats=0.0)

No tests found for create_opt. To contribute a test please refer to this guide and this discussion.

Create optimizer with lr learning rate and wd weight decay.

You generally won't need to call this yourself - it's used to create the optim optimizer before fitting the model.

dl[source][test]

dl(ds_type:DatasetType=<DatasetType.Valid: 2>)

No tests found for dl. To contribute a test please refer to this guide and this discussion.

Return DataLoader for DatasetType ds_type.

learn.dl()
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7fb0e1700a58>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7fb0e24e92f0>)
learn.dl(DatasetType.Train)
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7fb0cae08a20>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7fb0e24e92f0>)

class Recorder[source][test]

Recorder(learn:Learner, add_time:bool=True, silent:bool=False) :: LearnerCallback

Tests found for Recorder:

  • pytest -sv tests/test_vision_train.py::test_1cycle_lrs [source]
  • pytest -sv tests/test_vision_train.py::test_1cycle_moms [source]

To run tests please refer to this guide.

A LearnerCallback that records epoch, loss, opt and metric data during training.

A Learner creates a Recorder object automatically - you do not need to explicitly pass it to callback_fns - because other callbacks rely on it being available. It stores the smoothed loss, hyperparameter values, and metrics for each batch, and provides plotting methods for each. Note that Learner automatically sets an attribute with the snake-cased name of each callback, so you can access this through Learner.recorder, as shown below.
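The snake-cased attribute naming mentioned above can be illustrated with a small regex helper. This is a minimal sketch of the convention, not necessarily the exact code fastai uses; the function name camel_to_snake and the Holder class are hypothetical.

```python
import re

# Two passes: split before a capitalised word, then before any remaining
# capital that directly follows a lowercase letter or digit.
_first_pass = re.compile(r"(.)([A-Z][a-z]+)")
_second_pass = re.compile(r"([a-z0-9])([A-Z])")

def camel_to_snake(name: str) -> str:
    """Convert a CamelCase class name to snake_case (hypothetical helper)."""
    partial = _first_pass.sub(r"\1_\2", name)
    return _second_pass.sub(r"\1_\2", partial).lower()

class Holder:
    """Hypothetical stand-in for an object that exposes callbacks as attributes."""

class Recorder:
    """Hypothetical empty callback class, used only for its name."""

holder = Holder()
callback = Recorder()
# Register the callback under the snake-cased version of its class name.
setattr(holder, camel_to_snake(type(callback).__name__), callback)
```

With this convention, a callback class called Recorder ends up reachable as holder.recorder, which mirrors how Learner.recorder is accessed in the examples on this page.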

Plotting methods

plot[source][test]

plot(skip_start:int=10, skip_end:int=5, suggestion:bool=False, return_fig:bool=None, show_grid:bool=False, **kwargs) → Optional[Figure]

No tests found for plot. To contribute a test please refer to this guide and this discussion.

Plot learning rate and losses, trimmed between skip_start and skip_end. Optionally plot and return the min gradient (when suggestion is True).

This is mainly used with the learning rate finder, since it plots the loss against the learning rate.
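To make the roles of skip_start, skip_end and the min-gradient suggestion concrete, here is a minimal pure-Python sketch. The helper names trim and steepest_point are hypothetical, the numbers are made up for illustration, and the forward-difference gradient is an assumption; the library may compute the gradient differently.

```python
from typing import List, Sequence

def trim(vals: Sequence[float], skip_start: int, skip_end: int) -> List[float]:
    """Drop the first skip_start and the last skip_end values."""
    # vals[a:-0] would be empty, so skip_end == 0 needs its own branch.
    if skip_end > 0:
        return list(vals[skip_start:-skip_end])
    return list(vals[skip_start:])

def steepest_point(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the learning rate at which the loss drops fastest."""
    # Forward differences of the loss; the most negative one is the steepest drop.
    diffs = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    idx = min(range(len(diffs)), key=diffs.__getitem__)
    return lrs[idx]

# Illustrative values only, not real LR finder output.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [2.30, 2.28, 2.10, 1.20, 1.50, 4.00]

kept_lrs = trim(lrs, skip_start=1, skip_end=1)
kept_losses = trim(losses, skip_start=1, skip_end=1)
suggested_lr = steepest_point(kept_lrs, kept_losses)
```

Trimming matters because the first few points are noisy and the last few usually show the loss diverging, both of which would distort the plot's scale.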

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

plot_losses[source][test]

plot_losses(skip_start:int=0, skip_end:int=0, return_fig:bool=None, show_grid:bool=False) → Optional[Figure]

No tests found for plot_losses. To contribute a test please refer to this guide and this discussion.

Plot training and validation losses.

Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch.

learn.fit_one_cycle(5)
learn.recorder.plot_losses()
epoch train_loss valid_loss accuracy time
1 0.228524 0.122285 0.958783 00:05
2 0.118838 0.075222 0.971050 00:05
3 0.066715 0.054920 0.981354 00:05
4 0.048155 0.048612 0.983317 00:05
5 0.037535 0.046014 0.982336 00:05
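Because validation losses arrive once per epoch while training losses arrive once per batch, the two series need a shared x-axis before they can be drawn together. One way to do this, sketched below with made-up numbers, is to place each validation loss at the cumulative batch count reached at the end of its epoch. The variable names are illustrative and this is an assumption about the approach, not a copy of the library code.

```python
from itertools import accumulate

# Illustrative values: 3 epochs of 4 batches each.
train_losses = [0.90, 0.80, 0.75, 0.70,
                0.62, 0.60, 0.55, 0.52,
                0.50, 0.47, 0.45, 0.44]
nb_batches = [4, 4, 4]          # batches seen in each epoch
val_losses = [0.72, 0.55, 0.48]  # one value per epoch

# x positions for the training curve: one per batch.
train_iter = list(range(len(train_losses)))
# x positions for the validation curve: the batch count at each epoch end.
val_iter = list(accumulate(nb_batches))

val_points = list(zip(val_iter, val_losses))
```

Here val_points pairs each validation loss with the iteration at which it was measured, so both curves can share one batch-indexed axis.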

plot_lr[source][test]

plot_lr(show_moms=False, skip_start:int=0, skip_end:int=0, return_fig:bool=None) → Optional[Figure]

No tests found for plot_lr. To contribute a test please refer to this guide and this discussion.

Plot learning rate, show_moms to include momentum.

The learning rate and momentum shown are those of the last layer group (opt.lr and opt.mom).

learn.recorder.plot_lr()
learn.recorder.plot_lr(show_moms=True)

plot_metrics[source][test]

plot_metrics(skip_start:int=0, skip_end:int=0, return_fig:bool=None, show_grid:bool=False) → Optional[Figure]

No tests found for plot_metrics. To contribute a test please refer to this guide and this discussion.

Plot metrics collected during training.

Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here.

learn.recorder.plot_metrics()

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality. Refer to Callback for more details.

on_backward_begin[source][test]

on_backward_begin(smooth_loss:Tensor, **kwargs:Any)

No tests found for on_backward_begin. To contribute a test please refer to this guide and this discussion.

Record the loss before any other callback has a chance to modify it.
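The smooth_loss argument is a smoothed version of the raw batch losses. A common way to produce such a value is a debiased exponential moving average; the sketch below assumes that scheme with beta=0.98, which may not match the library's exact smoothing. The function name smooth_losses is hypothetical.

```python
from typing import Iterable, List

def smooth_losses(raw_losses: Iterable[float], beta: float = 0.98) -> List[float]:
    """Debiased exponential moving average of a stream of batch losses."""
    mov_avg = 0.0
    smoothed: List[float] = []
    for n, loss in enumerate(raw_losses, start=1):
        mov_avg = beta * mov_avg + (1 - beta) * loss
        # Dividing by (1 - beta**n) corrects the bias toward the zero start value.
        smoothed.append(mov_avg / (1 - beta ** n))
    return smoothed

raw = [2.0, 1.0, 3.0, 1.5]   # illustrative raw batch losses
smooth = smooth_losses(raw)
constant = smooth_losses([2.0] * 5)
```

The debiasing step means the very first smoothed value equals the first raw loss, and a constant stream of losses stays constant after smoothing.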

on_batch_begin[source][test]

on_batch_begin(train, **kwargs:Any)

No tests found for on_batch_begin. To contribute a test please refer to this guide and this discussion.

Record learning rate and momentum at beginning of batch.

on_epoch_end[source][test]

on_epoch_end(epoch:int, num_batch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any) → bool

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

Save epoch info: num_batch, smooth_loss, metrics.

on_train_begin[source][test]

on_train_begin(pbar:PBar, metrics_names:StrList, **kwargs:Any)

No tests found for on_train_begin. To contribute a test please refer to this guide and this discussion.

Initialize recording status at beginning of training.
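The four hooks above can be pictured with a stripped-down stand-in. MiniRecorder and the driver loop below are hypothetical and only mimic the order in which a training loop might invoke the hooks. In particular, passing lr as a keyword argument is an assumption made to keep the sketch self-contained.

```python
from typing import Any, List

class MiniRecorder:
    """Hypothetical, simplified stand-in for a Recorder-like callback."""

    def on_train_begin(self, metrics_names: List[str], **kwargs: Any) -> None:
        # Reset all recorded state at the start of training.
        self.names = ["epoch", "train_loss", "valid_loss"] + list(metrics_names)
        self.losses: List[float] = []
        self.lrs: List[float] = []
        self.nb_batches: List[int] = []
        self.metrics: List[List[float]] = []

    def on_batch_begin(self, train: bool, lr: float, **kwargs: Any) -> None:
        # Only training batches contribute to the learning rate history.
        if train:
            self.lrs.append(lr)

    def on_backward_begin(self, smooth_loss: float, **kwargs: Any) -> None:
        self.losses.append(smooth_loss)

    def on_epoch_end(self, num_batch: int, last_metrics: List[float],
                     **kwargs: Any) -> None:
        self.nb_batches.append(num_batch)
        self.metrics.append(list(last_metrics))

# Hypothetical driver: 2 epochs of 3 batches with made-up values.
rec = MiniRecorder()
rec.on_train_begin(metrics_names=["accuracy"])
for epoch in range(2):
    for batch in range(3):
        rec.on_batch_begin(train=True, lr=0.01)
        rec.on_backward_begin(smooth_loss=1.0 / (epoch * 3 + batch + 1))
    rec.on_epoch_end(num_batch=3, last_metrics=[0.5, 0.9])
```

After this run the object holds one learning rate and one loss per batch, and one batch count and one metrics list per epoch, which is the kind of data the plotting methods draw from.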

Inner functions

The following functions are used along the way by the Recorder or can be called by other callbacks.

add_metric_names[source][test]

add_metric_names(names)

No tests found for add_metric_names. To contribute a test please refer to this guide and this discussion.

Add names to the inner metric names.

format_stats[source][test]

format_stats(stats:MetricsList)

No tests found for format_stats. To contribute a test please refer to this guide and this discussion.

Format stats before printing.

Module functions

Generally you'll want to use a Learner to train your model, since it provides a lot of functionality and makes things easier. However, for ultimate flexibility, you can call the same underlying functions that Learner calls behind the scenes:

fit[source][test]

fit(epochs:int, learn:BasicLearner, callbacks:Optional[Collection[Callback]]=None, metrics:OptMetrics=None)

Tests found for fit:

Some other tests where fit is used:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]

To run tests please refer to this guide.

Fit the model on data and learn using loss_func and opt.

Note that you have to create the Optimizer yourself if you call this function, whereas Learner.fit creates it for you automatically.

train_epoch[source][test]

train_epoch(model:Module, dl:DataLoader, opt:Optimizer, loss_func:LossFunction)

No tests found for train_epoch. To contribute a test please refer to this guide and this discussion.

Simple training of model for 1 epoch of dl using optimizer opt and loss function loss_func.

You won't generally need to call this yourself - it's what fit calls for each epoch.

validate[source][test]

validate(model:Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None, pbar:Union[MasterBar, ProgressBar, NoneType]=None, average=True, n_batch:Optional[int]=None) → Iterator[Tuple[IntOrTensor, Ellipsis]]

Tests found for validate:

  • pytest -sv tests/test_tabular_train.py::test_accuracy [source]

To run tests please refer to this guide.

Calculate loss_func of model on dl in evaluation mode.

This is what fit calls after each epoch. You can call it if you want to run inference on a DataLoader manually.
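When average is True, the per-batch losses have to be combined into a single number. The sketch below assumes the combination is a mean weighted by batch size, which matters when the last batch is smaller than the others. The function name and the numbers are illustrative only.

```python
from typing import Sequence

def weighted_mean_loss(batch_losses: Sequence[float],
                       batch_sizes: Sequence[int]) -> float:
    """Average per-batch losses, weighting each batch by its number of items."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Illustrative values: the last batch is half the size of the others.
batch_losses = [0.5, 0.5, 1.0]
batch_sizes = [64, 64, 32]

weighted = weighted_mean_loss(batch_losses, batch_sizes)
unweighted = sum(batch_losses) / len(batch_losses)
```

The weighted result is lower than the plain mean here because the high-loss batch contains fewer items, so it counts for less.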

get_preds[source][test]

get_preds(model:Module, dl:DataLoader, pbar:Union[MasterBar, ProgressBar, NoneType]=None, cb_handler:Optional[CallbackHandler]=None, activ:Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) → List[Tensor]

Tests found for get_preds:

Some other tests where get_preds is used:

  • pytest -sv tests/test_basic_train.py::test_get_preds [source]

To run tests please refer to this guide.

Return a tuple of predictions and targets, plus the losses if loss_func is given, computed on dl over at most n_batch batches.
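The effect of capping the number of batches can be sketched in plain Python. Everything here is a hypothetical stand-in: collect_preds is not the library function, the model is an ordinary callable, and batches are lists of (inputs, targets) pairs rather than tensors from a DataLoader.

```python
from itertools import islice
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[int]]

def collect_preds(model: Callable[[float], float],
                  batches: Iterable[Batch],
                  n_batch: Optional[int] = None) -> Tuple[List[float], List[int]]:
    """Gather predictions and targets from at most n_batch batches."""
    preds: List[float] = []
    targets: List[int] = []
    # islice with a stop of None consumes every batch.
    for xb, yb in islice(batches, n_batch):
        preds.extend(model(x) for x in xb)
        targets.extend(yb)
    return preds, targets

def double(x: float) -> float:
    """Toy model used only for this illustration."""
    return 2 * x

batches = [([1, 2], [0, 1]), ([3, 4], [1, 0]), ([5], [1])]
first_two = collect_preds(double, batches, n_batch=2)
everything = collect_preds(double, batches)
```

Limiting n_batch is handy for a quick sanity check on a large dataset, since only the first few batches are processed.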

loss_batch[source][test]

loss_batch(model:Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None, cb_handler:Optional[CallbackHandler]=None) → Tuple[Union[Tensor, int, float, str]]

No tests found for loss_batch. To contribute a test please refer to this guide and this discussion.

Calculate loss and metrics for a batch, call out to callbacks as necessary.

You won't generally need to call this yourself - it's what fit and validate call for each batch. It only does a backward pass if you set opt.
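The rule that the update step only happens when opt is set can be shown with a toy example. ScalarModel, PlainSGD and loss_on_batch are hypothetical names, and the hand-written gradient stands in for PyTorch autograd. This is a control-flow sketch, not the library implementation.

```python
from typing import Optional, Sequence

class ScalarModel:
    """Hypothetical one-weight model: prediction = w * x."""
    def __init__(self, w: float = 0.0) -> None:
        self.w = w

    def __call__(self, x: float) -> float:
        return self.w * x

class PlainSGD:
    """Hypothetical optimizer stand-in; not a PyTorch or fastai class."""
    def __init__(self, model: ScalarModel, lr: float) -> None:
        self.model, self.lr = model, lr

    def step(self, grad: float) -> None:
        self.model.w -= self.lr * grad

def loss_on_batch(model: ScalarModel, xb: Sequence[float], yb: Sequence[float],
                  opt: Optional[PlainSGD] = None) -> float:
    """Return the mean squared error; update the model only when opt is given."""
    n = len(xb)
    loss = sum((model(x) - y) ** 2 for x, y in zip(xb, yb)) / n
    if opt is not None:
        # Hand-written gradient of the MSE with respect to w replaces autograd.
        grad = sum(2 * (model(x) - y) * x for x, y in zip(xb, yb)) / n
        opt.step(grad)
    return loss

model = ScalarModel(w=0.0)
xb, yb = [1.0, 2.0], [2.0, 4.0]

eval_loss = loss_on_batch(model, xb, yb)   # validation-style call: no update
w_after_eval = model.w
train_loss = loss_on_batch(model, xb, yb, PlainSGD(model, lr=0.1))  # training-style call
w_after_train = model.w
```

The same function serves both phases: called without an optimizer it only measures the loss, and called with one it also moves the weight.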

Other classes

class LearnerCallback[source][test]

LearnerCallback(learn) :: Callback

No tests found for LearnerCallback. To contribute a test please refer to this guide and this discussion.

Base class for creating callbacks for a Learner.

class RecordOnCPU[source][test]

RecordOnCPU() :: Callback

No tests found for RecordOnCPU. To contribute a test please refer to this guide and this discussion.

Store the input and target going through the model on the CPU.
