haiku_trainer
haiku_trainer module
| TrainerState | TrainerState data structure. |
| Trainer | Trainer class. |
TrainerState
-
class haiku_trainer.TrainerState(params: Mapping[str, Mapping[str, jax.numpy.ndarray]], aux: Mapping[str, Mapping[str, jax.numpy.ndarray]], optim: NamedTuple, rng: jax.numpy.ndarray)
TrainerState data structure.
-
params – Model parameters.
- Type
hk.Params
-
aux – Auxiliary model parameters.
- Type
hk.Params
-
optim – Optimizer state.
- Type
optax.OptState
-
rng – Random number generator state.
- Type
ndarray
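The page does not say how TrainerState is implemented, but its typed-field signature suggests a typing.NamedTuple. Assuming that, the sketch below uses a hypothetical stand-in, TrainerStateSketch, with the same four fields; plain Python lists replace jax.numpy arrays so it runs without JAX. It shows the usual functional style for such a state: it is never modified in place, and _replace returns a new object.

```python
from typing import Any, Mapping, NamedTuple

# Nested mapping: module name -> parameter name -> value (the shape of hk.Params).
# Plain lists stand in for jax.numpy arrays so the sketch runs without JAX.
Params = Mapping[str, Mapping[str, Any]]


class TrainerStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the four TrainerState fields."""
    params: Params  # model parameters
    aux: Params     # auxiliary model parameters
    optim: Any      # optimizer state, e.g. an optax.OptState
    rng: Any        # random number generator state


state = TrainerStateSketch(
    params={"linear": {"w": [0.5], "b": [0.0]}},
    aux={},
    optim=(),
    rng=0,
)

# NamedTuples are immutable: an "update" builds a new state object
# and leaves the original untouched.
new_state = state._replace(params={"linear": {"w": [0.4], "b": [0.1]}})
```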
-
Trainer
-
class haiku_trainer.Trainer(train_loss_fn, train_data_iter, val_loss_fn=None, val_data_iter=None, optimizer=GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), ckpt_freq=1000, logging_freq=100, out_dir='/tmp/ckpts', resume=False, wandb=None)
Trainer class.
-
__init__(train_loss_fn, train_data_iter, val_loss_fn=None, val_data_iter=None, optimizer=GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), ckpt_freq=1000, logging_freq=100, out_dir='/tmp/ckpts', resume=False, wandb=None) – Create a trainer object.
This object keeps all the information needed to train a Haiku model, including the model parameters, the current training step, etc.
- Parameters
train_loss_fn – a function that creates the model and computes the training loss.
train_data_iter – an iterable that yields training mini-batches.
val_loss_fn – a function that creates the model and computes the validation loss.
val_data_iter – an iterable that yields validation mini-batches.
optimizer – an optax optimizer, for example optax.adam(learning_rate=1e-3).
ckpt_freq (int) – checkpoint frequency.
logging_freq (int) – logging frequency.
out_dir (str) – output directory for saving checkpoints.
resume (bool) – resume the trainer from the latest checkpoint.
wandb – wandb module object for logging (or None to disable).
-
avg_training_loss() – Return the current average training loss.
-
avg_validation_loss() – Return the current average validation loss.
-
compile() – Compile the update function and the validation function.
The update function updates the network parameters: it calls the loss function, computes the gradients, and then applies the optimizer.
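The three stages of the update function (loss, gradients, optimizer update) can be illustrated without any framework. This is a minimal sketch under clearly stated assumptions: a hypothetical one-dimensional linear model y = w * x + b, a mean-squared-error loss, a hand-derived gradient, and plain SGD standing in for the optax optimizer. None of these names come from haiku_trainer.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]
Batch = Tuple[List[float], List[float]]  # (inputs, targets)


def loss_fn(params: Params, batch: Batch) -> float:
    # Mean squared error of the hypothetical model y = w * x + b.
    xs, ys = batch
    errs = [params["w"] * x + params["b"] - y for x, y in zip(xs, ys)]
    return sum(e * e for e in errs) / len(errs)


def grad_fn(params: Params, batch: Batch) -> Params:
    # Analytic gradient of the MSE loss with respect to w and b.
    xs, ys = batch
    n = len(xs)
    errs = [params["w"] * x + params["b"] - y for x, y in zip(xs, ys)]
    return {
        "w": 2.0 * sum(e * x for e, x in zip(errs, xs)) / n,
        "b": 2.0 * sum(errs) / n,
    }


def sgd_update(params: Params, grads: Params, lr: float = 0.1) -> Params:
    # Plain SGD, standing in for an optax optimizer.
    return {k: params[k] - lr * grads[k] for k in params}


def update(params: Params, batch: Batch) -> Tuple[Params, float]:
    loss = loss_fn(params, batch)       # 1. call the loss function
    grads = grad_fn(params, batch)      # 2. compute the gradients
    params = sgd_update(params, grads)  # 3. apply the optimizer update
    return params, loss


params = {"w": 0.0, "b": 0.0}
batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
losses = []
for _ in range(50):
    params, loss = update(params, batch)
    losses.append(loss)  # loss measured before each update
```

In a JAX setting, grad_fn would typically be replaced by jax.value_and_grad(loss_fn), sgd_update by an optax optimizer's update followed by optax.apply_updates, and the whole update function wrapped in jax.jit, which is presumably the "compile" step referred to here.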
-
find_latest_checkpoint() – Return the latest checkpoint in the output directory.
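The on-disk checkpoint layout is not documented on this page. Assuming, hypothetically, one file per checkpoint in out_dir named like ckpt_0001000.pkl, locating the latest checkpoint could look like the sketch below; latest_checkpoint_path and the naming pattern are illustrative, not part of haiku_trainer.

```python
import re
from pathlib import Path
from typing import Optional

# Hypothetical naming scheme: "ckpt_<step>.pkl", e.g. "ckpt_0001000.pkl".
CKPT_PATTERN = re.compile(r"^ckpt_(\d+)\.pkl$")


def latest_checkpoint_path(out_dir: str) -> Optional[Path]:
    """Return the checkpoint file with the highest step number, or None."""
    root = Path(out_dir)
    if not root.is_dir():
        return None  # no output directory yet, so nothing to resume from
    latest_step, latest_path = -1, None
    for path in root.iterdir():
        match = CKPT_PATTERN.match(path.name)
        if match is None:
            continue  # skip files that are not checkpoints
        step = int(match.group(1))
        if step > latest_step:
            latest_step, latest_path = step, path
    return latest_path
```

Parsing the step number, rather than sorting file names, keeps the result correct even when names are not zero-padded: compared as strings, ckpt_999.pkl sorts after ckpt_1000.pkl.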
-
fit(total_steps=1) – Fit the model to the training data.
Registered callback functions are called during training, and a checkpoint is also created at the last training step.
- Parameters
total_steps (int) – the total number of training steps (validation steps are not counted).
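The control flow described for fit() can be sketched with a hypothetical, framework-free class. Two simplifications are assumptions of this sketch, not documented behavior: callbacks receive only the step number (the real trainer runs them through run_func_with_state), and the loop simply runs until last_step reaches total_steps instead of reproducing the exact range returned by trange().

```python
from typing import Callable, List, Tuple


class FitLoopSketch:
    """Hypothetical stand-in for the control flow of fit()."""

    def __init__(self) -> None:
        self.last_step = 0
        self.callbacks: List[Tuple[int, Callable[[int], None]]] = []
        self.saved_steps: List[int] = []

    def register_callback(self, callback_freq: int,
                          callback_fn: Callable[[int], None]) -> None:
        self.callbacks.append((callback_freq, callback_fn))

    def training_step(self) -> None:
        # A real step would update the network parameters here.
        self.last_step += 1

    def save_step(self, step: int) -> None:
        # Stand-in for writing a checkpoint to out_dir.
        self.saved_steps.append(step)

    def fit(self, total_steps: int = 1) -> None:
        while self.last_step < total_steps:
            self.training_step()
            step = self.last_step
            # Fire every callback whose frequency divides the current step.
            for freq, fn in self.callbacks:
                if step % freq == 0:
                    fn(step)
        # Always checkpoint at the last training step.
        self.save_step(self.last_step)


trainer = FitLoopSketch()
fired: List[int] = []
trainer.register_callback(3, fired.append)
trainer.fit(total_steps=10)
```

Because the loop starts from last_step, calling fit again with a larger total_steps continues from where training stopped, which is what lets training pick up after a checkpoint is restored.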
-
load_state(file_obj) – Load the model states from the file object file_obj.
-
load_step(step) – Load the model states from the checkpoint of the given step.
-
register_callback(callback_freq, callback_fn) – Register a callback function for the training process.
The function callback_fn will be called every callback_freq training steps. The callback function is called indirectly through Trainer.run_func_with_state, which gives it access to the model states and parameters.
- Parameters
callback_freq (int) – the function is called at every step where step % callback_freq == 0.
callback_fn (Callable) – the function that carries out the task.
-
resume() – Find the latest checkpoint in the output directory and restore the network states from it.
-
run_func_with_state(fn) – Transform the function fn and run it.
This gives fn access to the model's parameters and states, which is useful, for example, for running predictions after the model has been trained.
- Parameters
fn (Callable) – the function to be transformed.
-
save_state(file_obj) – Save the model state to the file object file_obj.
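The file format used by save_state and load_state is not specified here. As a minimal sketch, assuming pickle serialization and a hypothetical dictionary layout for the state, a round trip through an in-memory file object looks like this:

```python
import io
import pickle
from typing import Any, BinaryIO, Dict


# Hypothetical state layout. A real trainer might also store aux and rng,
# and might first move arrays to host memory (e.g. with jax.device_get).
def save_state_sketch(state: Dict[str, Any], file_obj: BinaryIO) -> None:
    pickle.dump(state, file_obj)


def load_state_sketch(file_obj: BinaryIO) -> Dict[str, Any]:
    return pickle.load(file_obj)


state = {"step": 1000, "params": {"linear": {"w": [0.5], "b": [0.0]}}}
buffer = io.BytesIO()
save_state_sketch(state, buffer)
buffer.seek(0)  # rewind before reading back
restored = load_state_sketch(buffer)
```

Accepting a file object instead of a path keeps such functions easy to test with io.BytesIO; a helper in the spirit of save_step could open the file for a given step and delegate to them. Pickle should only be used to load files from trusted sources.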
-
save_step(step) – Save the model state at the given step.
-
tiktok() – Compute the elapsed time since the previous call to this function.
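A timer with the behavior described for tiktok() can be sketched as below. The class name and the choice to start the clock at construction are assumptions; time.perf_counter is used because it is monotonic and high-resolution, so system clock adjustments do not distort the measurement.

```python
import time


class TikTok:
    """Hypothetical helper returning seconds elapsed since the previous call."""

    def __init__(self) -> None:
        # Assumption: the first measurement starts at construction time.
        self._last = time.perf_counter()

    def tiktok(self) -> float:
        now = time.perf_counter()
        elapsed = now - self._last
        self._last = now  # the next call measures from this point
        return elapsed


timer = TikTok()
time.sleep(0.01)
first = timer.tiktok()   # includes the time spent sleeping
second = timer.tiktok()  # measured only from the previous call
```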
-
training_step() – Run one training step: update the network parameters and increment Trainer.last_step.
-
trange(total_steps) – Return a tqdm object over range(last_step+1, total_steps).
-
validation_step() – Run one validation step.
-