Implementation of a flexible training API

TrainingPhase and General scheduler

Creates a scheduler that lets you train a model by following a schedule made of different TrainingPhase objects.

class TrainingPhase[source][test]

TrainingPhase(length:int)

Schedule hyper-parameters for a phase of length iterations.

You can then schedule any hyper-parameter you want by using the following method.

schedule_hp[source][test]

schedule_hp(name, vals, anneal=None)

Adds a schedule for name between vals using anneal.

The phase will make the hyper-parameter vary from the first value in vals to the second, following anneal. If an annealing function is specified but vals is a float, the hyper-parameter will decay from that value to 0. If no annealing function is specified, the default is linear annealing when vals is a tuple, and a constant value when it's a float.
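
As a quick sketch of these rules (the phase length and values below are arbitrary; the hyper-parameter names are the ones listed just after):

phase = TrainingPhase(1000)                                  # a phase of 1000 iterations
phase.schedule_hp('lr', (1e-2, 1e-4), anneal=annealing_cos)  # tuple + anneal: goes from 1e-2 to 1e-4 with cosine annealing
phase.schedule_hp('wd', 1e-2, anneal=annealing_cos)          # float + anneal: decays from 1e-2 to 0
phase.schedule_hp('mom', (0.95, 0.85))                       # tuple, no anneal: linear annealing by default
phase.schedule_hp('beta', 0.99)                              # float, no anneal: stays constant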

The basic hyper-parameters are named:

  • 'lr' for learning rate
  • 'mom' for momentum (or beta1 in Adam)
  • 'beta' for the beta2 in Adam or the alpha in RMSprop
  • 'wd' for weight decay

You can also add any hyper-parameter that is in your optimizer (even if it's custom or a GeneralOptimizer), like 'eps' if you're using Adam.
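
For instance, with fastai's default Adam optimizer, eps lives in the optimizer and can be scheduled like any built-in hyper-parameter. A minimal sketch (the values are arbitrary); note that schedule_hp returns the phase itself, so calls can be chained:

phase = (TrainingPhase(500)
         .schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos)
         .schedule_hp('eps', (1e-8, 1e-6)))  # optimizer-specific hyper-parameter, linear annealing by default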

Let's look at an example: using this API to implement SGD with warm restarts.

def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):
    n = len(learn.data.train_dl)  # iterations per epoch
    # One TrainingPhase per cycle: the i-th cycle lasts cycle_len * cycle_mult**i epochs
    # and cosine-anneals the learning rate from lr down to 0 while keeping momentum constant.
    phases = [(TrainingPhase(n * (cycle_len * cycle_mult**i))
                 .schedule_hp('lr', lr, anneal=annealing_cos)
                 .schedule_hp('mom', mom)) for i in range(n_cycles)]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    # Total epochs: cycle_len * (1 + cycle_mult + ... + cycle_mult**(n_cycles-1)), a geometric series.
    if cycle_mult != 1:
        total_epochs = int(cycle_len * (1 - cycle_mult**n_cycles)/(1 - cycle_mult))
    else: total_epochs = n_cycles * cycle_len
    learn.fit(total_epochs)
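
We can then use it to train a simple CNN on the MNIST sample data:
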
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)
epoch train_loss valid_loss accuracy time
0 0.162146 0.153532 0.942100 00:02
1 0.126112 0.117267 0.960255 00:02
2 0.112045 0.110586 0.962218 00:02
3 0.097603 0.090838 0.967615 00:02
4 0.086883 0.081375 0.973013 00:02
5 0.083673 0.076160 0.973994 00:02
6 0.084835 0.076211 0.973994 00:02
learn.recorder.plot_lr()
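
This plots the learning rate over the whole run; it should show the three cosine-annealed cycles of 1, 2 and 4 epochs.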

class GeneralScheduler[source][test]

GeneralScheduler(learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None) :: LearnerCallback

Schedule multiple TrainingPhase for a Learner.
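
For instance, a warm-up followed by a cosine decay can be written as two phases and passed directly to GeneralScheduler. A minimal sketch, reusing the learn and data objects from the example above (the proportions and values are arbitrary, and annealing_linear is fastai's linear annealing function):

n = len(data.train_dl) * 4                 # total number of iterations for 4 epochs
warmup = int(0.3 * n)                      # linear warm-up over the first 30% of iterations
phases = [TrainingPhase(warmup).schedule_hp('lr', (1e-4, 1e-3), anneal=annealing_linear),
          TrainingPhase(n - warmup).schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos)]
learn.fit(4, callbacks=[GeneralScheduler(learn, phases)])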

Callback methods

You don't call these yourself - they're called by fastai's Callback system automatically to enable the class's functionality.

on_batch_end[source][test]

on_batch_end(train, **kwargs:Any)

Takes a step in the current phase and prepares the hyper-parameters for the next batch.

on_train_begin[source][test]

on_train_begin(epoch:int, **kwargs:Any)

Initializes the hyper-parameters to the start values of the first phase.