Callbacks

Callbacks are small pieces of logic that affect the trainer they are attached to, log information, or do anything else you can think of.
Inside a Trainer, they can be called at four points:

  • At the beginning of a training epoch
  • At the end of a training epoch
  • At the end of a forward/backward pass
  • At the end of a testing epoch
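
To make the four call points concrete, here is a minimal sketch of where a training loop might fire them. `ToyTrainer`, `RecorderCB`, and the loop body are hypothetical stand-ins written for illustration, not implicitlab's actual Trainer. Only the four hook names come from the Callback base class described at the end of this page, and the assumption that a test pass follows every training epoch is mine.

```python
class Callback:
    """Stand-in mirroring the four hooks of the library's base class."""
    def callOnBeginTrain(self, trainer, model): pass
    def callOnEndTrain(self, trainer, model): pass
    def callOnEndForward(self, trainer, model): pass
    def callOnEndTest(self, trainer, model): pass


class RecorderCB(Callback):
    """Hypothetical callback that records the order in which hooks fire."""
    def __init__(self):
        self.events = []

    def callOnBeginTrain(self, trainer, model):
        self.events.append("begin_train")

    def callOnEndForward(self, trainer, model):
        self.events.append("end_forward")

    def callOnEndTrain(self, trainer, model):
        self.events.append("end_train")

    def callOnEndTest(self, trainer, model):
        self.events.append("end_test")


class ToyTrainer:
    """Hypothetical trainer: no real optimization, only the hook dispatch."""
    def __init__(self, n_epochs, n_batches):
        self.n_epochs = n_epochs
        self.n_batches = n_batches
        self.callbacks = []

    def add_callbacks(self, *cbs):
        self.callbacks.extend(cbs)

    def train(self, model=None):
        for _ in range(self.n_epochs):
            for cb in self.callbacks:
                cb.callOnBeginTrain(self, model)
            for _ in range(self.n_batches):
                # a forward/backward pass and optimizer step would go here
                for cb in self.callbacks:
                    cb.callOnEndForward(self, model)
            for cb in self.callbacks:
                cb.callOnEndTrain(self, model)
            # ASSUMPTION: evaluation on the test set runs after each training epoch
            for cb in self.callbacks:
                cb.callOnEndTest(self, model)


recorder = RecorderCB()
trainer = ToyTrainer(n_epochs=1, n_batches=2)
trainer.add_callbacks(recorder)
trainer.train()
print(recorder.events)
```

With one epoch of two batches, the recorder sees one "begin_train", two "end_forward", one "end_train", and one "end_test", in that order.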

Callbacks are registered before training starts, using the add_callbacks method of the Trainer class:

import torch

from implicitlab.training import SimpleRegressionTrainer, TrainingConfig
from implicitlab.training import callbacks

trainer = SimpleRegressionTrainer(TrainingConfig(), lossfun=torch.nn.MSELoss())
trainer.add_callbacks(
    callbacks.LoggerCB("output/training_log.txt"), # write losses and training info to a .txt file
    callbacks.Render2DCB("output", 10) # take a snapshot of a 2D neural implicit every 10 training epochs
)

List of implemented callbacks

CheckpointCB(save_folder, freq, only_weights=False)

Bases: Callback

A Callback responsible for saving the model being trained to a file.

Parameters:

  • save_folder (str, required): folder into which the model will be saved. The filename is formatted as model_e{epoch}.pt.
  • freq (int, required): frequency, in number of epochs, at which a file is saved.
  • only_weights (bool, default False): if True, only the model's state dict is saved, in which case the exact model architecture is not saved. If False, the whole model is saved (as a PyTorch trace).
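
As a rough illustration of how save_folder and freq combine, the sketch below lists the files one might expect after a run. The helpers `checkpoint_path` and `expected_checkpoints` are hypothetical and not part of the library. The filename pattern is the one documented above, but the convention that epochs are counted from 1 and a file is written whenever the epoch is a multiple of freq is an assumption.

```python
import os


def checkpoint_path(save_folder, epoch):
    """Build a path following the documented pattern model_e{epoch}.pt."""
    return os.path.join(save_folder, f"model_e{epoch}.pt")


def expected_checkpoints(save_folder, freq, n_epochs):
    """List the checkpoint files expected after n_epochs.

    ASSUMPTION: epochs are numbered from 1 and a checkpoint is written
    whenever epoch % freq == 0. The library may count differently.
    """
    return [
        checkpoint_path(save_folder, epoch)
        for epoch in range(1, n_epochs + 1)
        if epoch % freq == 0
    ]


for path in expected_checkpoints("output/checkpoints", freq=50, n_epochs=200):
    print(path)
```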

LoggerCB(file_path)

Bases: Callback

Records the training metrics in a log file.

Parameters:

  • file_path (str, required): path to the log text file.

MarchingCubeCB(save_folder, freq, domain=None, res=100, iso=0, prefix='')

Bases: Callback

A Callback that makes a snapshot of a 3D neural implicit by using the marching cubes algorithm to extract some level sets.

Parameters:

  • save_folder (str, required): output folder into which the images are saved.
  • freq (int, required): frequency, in number of epochs, at which a snapshot is taken.
  • domain (AABB, default None): AABB domain over which the grid is defined. If not provided, the domain is [-1.2, 1.2]^3.
  • res (int, default 100): grid resolution for marching cubes. res^3 values will be sampled from the neural model.
  • iso (int, default 0): iso-level to reconstruct. Several levels can be provided in a list.
  • prefix (str, default ''): prefix for the name of the saved file. The name will have the form _e_iso.
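
Because res^3 values are sampled at every snapshot, the cost of this callback grows quickly with res. The sketch below shows the sample count for the default settings and one way to normalize iso into a list. `grid_axis` and `as_iso_list` are hypothetical helpers written for this example, and including both endpoints of the domain in the grid is an assumption about how the grid is laid out.

```python
def grid_axis(lo, hi, res):
    """Return res evenly spaced coordinates from lo to hi, endpoints included."""
    step = (hi - lo) / (res - 1)
    return [lo + i * step for i in range(res)]


def as_iso_list(iso):
    """Accept a single iso-level or a list of levels; always return a list."""
    if isinstance(iso, (list, tuple)):
        return list(iso)
    return [iso]


res = 100
axis = grid_axis(-1.2, 1.2, res)  # one axis of the default domain [-1.2, 1.2]^3
n_samples = res ** 3              # model evaluations needed for one snapshot

print(n_samples)                  # 1000000
print(as_iso_list(0))             # [0]
print(as_iso_list([-0.1, 0, 0.1]))
```

Doubling res multiplies the number of model evaluations by eight, so a lower res or a larger freq keeps snapshots cheap during long runs.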

Render2DCB(save_folder, freq, plot_domain=None, resolution=800, output_contours=True, output_gradient_norm=True, prefix='')

Bases: Callback

A Callback that makes a snapshot of a 2D neural implicit by sampling its values on a grid. Can also sample the gradient's norm and make a contour plot.

Parameters:

  • save_folder (str, required): output folder into which the images are saved.
  • freq (int, required): frequency, in number of epochs, at which a snapshot is taken.
  • plot_domain (AABB, default None): spanning domain of the snapshot. If not provided, the domain is [-1.5, 1.5]^2.
  • resolution (int, default 800): resolution of the snapshot grid. resolution^2 samples will be computed from the neural implicit model.
  • output_contours (bool, default True): whether to output a contour plot of the neural field.
  • output_gradient_norm (bool, default True): whether to also output a plot of the norm of the neural field's gradient.
  • prefix (str, default ''): prefix for the name of the saved file. The name will have the form _e_iso.

Warning

Fails if the neural implicit being trained is not 2-dimensional.
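
The sketch below illustrates what such a snapshot samples, using a toy function in place of a neural model. `circle_sdf`, `sample_grid`, and `gradient_norm` are hypothetical names for this example only. The finite-difference gradient is a stand-in chosen to keep the example dependency-free, and the callback itself may well compute gradients differently (for instance with automatic differentiation).

```python
import math


def circle_sdf(x, y, radius=0.5):
    """Signed distance to a circle centered at the origin (toy 2D implicit)."""
    return math.hypot(x, y) - radius


def sample_grid(f, lo=-1.5, hi=1.5, resolution=5):
    """Evaluate f on a resolution x resolution grid spanning [lo, hi]^2."""
    step = (hi - lo) / (resolution - 1)
    axis = [lo + i * step for i in range(resolution)]
    return [[f(x, y) for x in axis] for y in axis]


def gradient_norm(f, x, y, h=1e-5):
    """Norm of the gradient of f at (x, y), by central finite differences."""
    dfdx = (f(x + h, y) - f(x - h, y)) / (2 * h)
    dfdy = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return math.hypot(dfdx, dfdy)


values = sample_grid(circle_sdf, resolution=5)
print(len(values) * len(values[0]))         # resolution^2 = 25 samples
print(values[2][2])                         # value at the grid center (0, 0): -0.5
print(gradient_norm(circle_sdf, 1.0, 0.0))  # close to 1 for a true distance field
```

A gradient-norm plot is a useful diagnostic when the model is meant to approximate a signed distance function, since an exact distance field has gradient norm 1 wherever it is differentiable.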

Write your own callback

All callbacks inherit from the base class Callback, which implements four methods:

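class Callback:
    # called at the beginning of a training epoch
    def callOnBeginTrain(self, trainer, model): pass

    # called at the end of a training epoch
    def callOnEndTrain(self, trainer, model): pass

    # called at the end of a forward/backward pass
    def callOnEndForward(self, trainer, model): pass

    # called at the end of a testing epoch
    def callOnEndTest(self, trainer, model): pass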

Every method receives the trainer it is linked to and the model that trainer is currently optimizing, so a callback has full access to the state of the optimization.

To implement your own custom callback, create a class that inherits from Callback and overrides one or more of these methods.
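
As a minimal sketch of this pattern, the hypothetical `EpochTimerCB` below measures how long each training epoch takes. To stay runnable on its own, the block defines a stand-in `Callback` with the four hooks listed above. In real code you would inherit from the library's Callback instead, and since its import path is not shown on this page, any path you use is an assumption to check against the package. The callback deliberately avoids reading attributes from trainer or model, since their fields are not documented here.

```python
import time


class Callback:
    """Stand-in for the library's base class (same four no-op hooks)."""
    def callOnBeginTrain(self, trainer, model): pass
    def callOnEndTrain(self, trainer, model): pass
    def callOnEndForward(self, trainer, model): pass
    def callOnEndTest(self, trainer, model): pass


class EpochTimerCB(Callback):
    """Hypothetical callback: records the wall-clock duration of each training epoch."""
    def __init__(self):
        self.durations = []   # one entry (in seconds) per completed epoch
        self._start = None

    def callOnBeginTrain(self, trainer, model):
        # remember when the epoch started
        self._start = time.perf_counter()

    def callOnEndTrain(self, trainer, model):
        # store the elapsed time, ignoring an end without a matching begin
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None


# Drive the hooks by hand, as a trainer would over three epochs.
timer = EpochTimerCB()
for _ in range(3):
    timer.callOnBeginTrain(trainer=None, model=None)
    timer.callOnEndTrain(trainer=None, model=None)
print(len(timer.durations))
```

An instance of such a class would then be registered like the built-in callbacks, by passing it to the trainer's add_callbacks method.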