haiku_trainer

haiku_trainer module

TrainerState(params, Mapping[str, …) – TrainerState data structure.

Trainer(train_loss_fn, train_data_iter[, …]) – Trainer class.

TrainerState

class haiku_trainer.TrainerState(params: Mapping[str, Mapping[str, jax._src.numpy.lax_numpy.ndarray]], aux: Mapping[str, Mapping[str, jax._src.numpy.lax_numpy.ndarray]], optim: NamedTuple, rng: jax._src.numpy.lax_numpy.ndarray)

TrainerState data structure.

params

Model parameters.

Type

hk.Params

aux

Auxiliary model parameters.

Type

hk.Params

optim

Optimizer state.

Type

optax.OptState

rng

Random number generator state.

Type

ndarray
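The four fields above can be pictured with a minimal sketch. It assumes TrainerState behaves like an immutable NamedTuple, which is an assumption since the module's definition is not shown here. Plain Python lists stand in for JAX arrays so the example runs without JAX, and the class name TrainerStateSketch and the module/parameter names ("linear", "w", "b") are hypothetical. The nesting follows Haiku's convention for hk.Params: module name, then parameter name, then array.

```python
from typing import Any, Mapping, NamedTuple

# Haiku-style nested mapping: module name -> parameter name -> array.
# Plain lists stand in for jax.numpy arrays in this sketch.
Params = Mapping[str, Mapping[str, Any]]


class TrainerStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented TrainerState fields."""
    params: Params  # model parameters (hk.Params)
    aux: Params     # auxiliary model parameters
    optim: Any      # optimizer state (optax.OptState)
    rng: Any        # random number generator state


state = TrainerStateSketch(
    params={"linear": {"w": [[0.1, 0.2]], "b": [0.0]}},
    aux={},
    optim=(),
    rng=[0, 42],  # stand-in for a JAX PRNG key
)

# NamedTuples are immutable: an update produces a new state object.
new_state = state._replace(params={"linear": {"w": [[0.3, 0.4]], "b": [0.1]}})
print(state.params["linear"]["w"])      # [[0.1, 0.2]]
print(new_state.params["linear"]["w"])  # [[0.3, 0.4]]
```

Keeping the state immutable fits JAX's functional style, where an update function returns a new state rather than mutating the old one.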

Trainer

class haiku_trainer.Trainer(train_loss_fn, train_data_iter, val_loss_fn=None, val_data_iter=None, optimizer=GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), ckpt_freq=1000, logging_freq=100, out_dir='/tmp/ckpts', resume=False, wandb=None)

Trainer class

__init__(train_loss_fn, train_data_iter, val_loss_fn=None, val_data_iter=None, optimizer=GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), ckpt_freq=1000, logging_freq=100, out_dir='/tmp/ckpts', resume=False, wandb=None)

Create a trainer object.

This object keeps all the information needed to train a Haiku model, including the model parameters and the current training step.

Parameters
  • train_loss_fn – a function that creates the model and computes the training loss.

  • train_data_iter – an iterable that yields training mini-batches.

  • val_loss_fn – a function that creates the model and computes the validation loss.

  • val_data_iter – an iterable that yields validation mini-batches.

  • optimizer – an Optax optimizer, for example optax.adam(learning_rate).

  • ckpt_freq (int) – checkpoint frequency, in training steps.

  • logging_freq (int) – logging frequency, in training steps.

  • out_dir (str) – output directory for saving checkpoints.

  • resume (bool) – if True, resume the trainer from the latest checkpoint in out_dir.

  • wandb – wandb module object for logging (or None to disable).
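train_data_iter and val_data_iter only need to yield mini-batches. A minimal sketch of such an iterator is an infinite generator that reshuffles the dataset every epoch. It assumes each mini-batch is a dict with "x" and "y" entries, but the batch format is whatever your loss function expects; the helper name minibatches and the keys are hypothetical.

```python
import random
from typing import Dict, Iterator, List, Sequence


def minibatches(
    xs: Sequence[float], ys: Sequence[float], batch_size: int, seed: int = 0
) -> Iterator[Dict[str, List[float]]]:
    """Yield shuffled mini-batches forever (hypothetical helper)."""
    if not 0 < batch_size <= len(xs):
        raise ValueError("batch_size must be between 1 and len(xs)")
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    while True:
        rng.shuffle(indices)  # new order every epoch
        # Drop the final partial batch so every batch has the same shape,
        # which avoids recompiling a jitted update function.
        for start in range(0, len(indices) - batch_size + 1, batch_size):
            idx = indices[start:start + batch_size]
            yield {"x": [xs[i] for i in idx], "y": [ys[i] for i in idx]}


train_iter = minibatches(xs=list(range(10)), ys=[2 * v for v in range(10)], batch_size=4)
batch = next(train_iter)
print(len(batch["x"]))  # 4
```

Because the generator never ends, the number of training steps, not the dataset size, decides when training stops.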

avg_training_loss()

Return the current average training loss.

avg_validation_loss()

Return the current average validation loss.

compile()

Compile the update function and validation function.

The update function is used to update the network parameters: it first calls the loss function, then computes the gradient, and finally calls the optimizer.
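The order described above (loss, then gradient, then optimizer) can be illustrated with a toy, dependency-free sketch: a one-parameter model y_hat = w * x with a squared-error loss, an analytic gradient, and a plain SGD step. The module itself presumably uses JAX autodiff, jax.jit and an Optax optimizer instead; loss_fn, grad_fn, make_update and the learning rate here are hypothetical.

```python
from typing import Callable, Tuple

Batch = Tuple[float, float]  # (x, y) pair in this toy example


def loss_fn(w: float, batch: Batch) -> float:
    """Squared error of the toy model y_hat = w * x."""
    x, y = batch
    return (w * x - y) ** 2


def grad_fn(w: float, batch: Batch) -> float:
    """Analytic d(loss)/dw; jax.grad would compute this automatically."""
    x, y = batch
    return 2.0 * (w * x - y) * x


def make_update(lr: float) -> Callable[[float, Batch], Tuple[float, float]]:
    def update(w: float, batch: Batch) -> Tuple[float, float]:
        loss = loss_fn(w, batch)  # 1. call the loss function
        g = grad_fn(w, batch)     # 2. compute the gradient
        new_w = w - lr * g        # 3. apply the optimizer (plain SGD here)
        return new_w, loss
    return update


update = make_update(lr=0.05)
w = 0.0
for _ in range(100):
    w, loss = update(w, (1.0, 3.0))
print(w)  # approaches 3.0
```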

find_latest_checkpoint()

Return the latest checkpoint in the output directory.
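One way to find the latest checkpoint is sketched below. It assumes, hypothetically, that checkpoints are files named like ckpt_1000.pkl in out_dir; the module's real naming scheme is not documented here. Parsing the step number, rather than sorting file names as strings, keeps the result correct when step numbers have different widths (ckpt_900 vs ckpt_1000).

```python
import os
import re
import tempfile
from typing import Optional

# Hypothetical naming scheme: ckpt_<step>.pkl
CKPT_PATTERN = re.compile(r"^ckpt_(\d+)\.pkl$")


def find_latest_checkpoint(out_dir: str) -> Optional[str]:
    """Return the path of the checkpoint with the highest step, or None."""
    if not os.path.isdir(out_dir):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(out_dir):
        match = CKPT_PATTERN.match(name)
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_path = os.path.join(out_dir, name)
    return best_path


with tempfile.TemporaryDirectory() as d:
    for name in ["ckpt_900.pkl", "ckpt_1000.pkl", "notes.txt"]:
        open(os.path.join(d, name), "w").close()  # create empty files
    latest = find_latest_checkpoint(d)
    print(os.path.basename(latest))  # ckpt_1000.pkl
```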

fit(total_steps=1)

Fit the model to the training data.

Registered callback functions are called during training.

A checkpoint is also created at the last training step.

Parameters

total_steps (int) – the total number of training steps (validation steps are not counted).
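The control flow of fit can be sketched as follows. The sketch assumes the step range documented for trange, range(last_step+1, total_steps), a periodic checkpoint whenever step % ckpt_freq == 0, and a final checkpoint at the last step if it was not already saved. fit_sketch, train_step and save_step are hypothetical callables, not the module's code.

```python
from typing import Callable, List


def fit_sketch(
    last_step: int,
    total_steps: int,
    ckpt_freq: int,
    train_step: Callable[[int], None],
    save_step: Callable[[int], None],
) -> int:
    """Run the training steps and save checkpoints; return the new last_step."""
    step = last_step
    for step in range(last_step + 1, total_steps):
        train_step(step)
        if step % ckpt_freq == 0:
            save_step(step)  # periodic checkpoint
    if step > last_step and step % ckpt_freq != 0:
        save_step(step)      # make sure the final step is checkpointed
    return step


saved: List[int] = []
new_last = fit_sketch(last_step=0, total_steps=11, ckpt_freq=4,
                      train_step=lambda s: None, save_step=saved.append)
print(saved)  # [4, 8, 10]
```

Starting the range at last_step + 1 is what lets a resumed trainer continue from a checkpoint instead of repeating steps.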

load_state(file_obj)

Load the model state from the file object file_obj.

load_step(step)

Load the model state from the checkpoint saved at step.

register_callback(callback_freq, callback_fn)

Register a callback function to be called during training.

The function callback_fn is called every callback_freq training steps.

The callback function is called indirectly through Trainer.run_func_with_state, which allows it to access the model state and parameters.

Parameters
  • callback_freq (int) – the function is called at every step for which step % callback_freq == 0.

  • callback_fn (Callable) – the function to call.
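The scheduling rule above (call at every step where step % callback_freq == 0) can be sketched with a small registry. This illustrates only the timing: in the real Trainer the callback runs through run_func_with_state, which this stdlib-only sketch does not reproduce. CallbackRegistry, its methods, and the convention that callbacks receive the step number are hypothetical.

```python
from typing import Callable, List, Tuple


class CallbackRegistry:
    """Hypothetical registry mirroring the documented scheduling rule."""

    def __init__(self) -> None:
        self._callbacks: List[Tuple[int, Callable[[int], None]]] = []

    def register(self, callback_freq: int, callback_fn: Callable[[int], None]) -> None:
        if callback_freq <= 0:
            raise ValueError("callback_freq must be positive")
        self._callbacks.append((callback_freq, callback_fn))

    def run(self, step: int) -> None:
        # Fire every callback whose frequency divides the current step.
        for freq, fn in self._callbacks:
            if step % freq == 0:
                fn(step)


calls: List[Tuple[str, int]] = []
registry = CallbackRegistry()
registry.register(3, lambda s: calls.append(("eval", s)))
registry.register(5, lambda s: calls.append(("log", s)))
for step in range(1, 11):
    registry.run(step)
print(calls)  # [('eval', 3), ('log', 5), ('eval', 6), ('eval', 9), ('log', 10)]
```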

resume()

Find the latest checkpoint in the output directory.

Restore the network state from that checkpoint.

run_func_with_state(fn)

Transform the function fn and run it.

This allows fn to access the model’s parameters and state, which is useful, for example, for running predictions after the model has been trained.

Parameters

fn (Callable) – the function to transform and run.

save_state(file_obj)

Save the model state to the file object file_obj.
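save_state and load_state both operate on file objects. A minimal sketch, assuming the state is serialized with pickle (the module's actual format is not documented here), round-trips a state through an in-memory io.BytesIO buffer; a file opened with open(path, "wb") or open(path, "rb") works the same way. The function bodies and the example state are hypothetical.

```python
import io
import pickle
from typing import Any, BinaryIO


def save_state(state: Any, file_obj: BinaryIO) -> None:
    """Serialize a state object to an open binary file object."""
    pickle.dump(state, file_obj)


def load_state(file_obj: BinaryIO) -> Any:
    """Deserialize a state object from an open binary file object."""
    return pickle.load(file_obj)


state = {"params": {"linear": {"w": [0.5, -1.0]}}, "last_step": 1000}
buf = io.BytesIO()
save_state(state, buf)
buf.seek(0)  # rewind before reading back
restored = load_state(buf)
print(restored == state)  # True
```

Taking a file object rather than a path lets the caller decide where the bytes go. If the format is pickle, only load checkpoint files you trust, since unpickling can execute arbitrary code.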

save_step(step)

Save model state at step.

tiktok()

Compute the elapsed time since this function was last called.
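The behaviour of tiktok can be sketched with a small timer that remembers when it was last called. The sketch assumes a monotonic clock (time.perf_counter) and that the first call measures from object creation; both are assumptions, and TikTok is a hypothetical name.

```python
import time


class TikTok:
    """Report the time elapsed since the previous call (hypothetical sketch)."""

    def __init__(self) -> None:
        self._last = time.perf_counter()  # reference point for the first call

    def tiktok(self) -> float:
        now = time.perf_counter()
        elapsed, self._last = now - self._last, now
        return elapsed


timer = TikTok()
time.sleep(0.05)
first = timer.tiktok()   # roughly 0.05 s
second = timer.tiktok()  # close to 0, since it was just called
print(f"{first:.2f}s then {second:.4f}s")
```

A monotonic clock is preferable to time.time() here because it cannot jump backwards if the system clock is adjusted.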

training_step()

Run one training step:
  • update the network parameters;

  • increment Trainer.last_step.

trange(total_steps)

Return a tqdm progress bar over range(last_step+1, total_steps).

validation_step()

Run one validation step.