GradientStatsMonitor
- class lightning.pytorch.callbacks.GradientStatsMonitor(log_every_n_steps=50, track_epochs=True, per_layer=False, track_sparsity=True, explosion_threshold=10000.0)
Bases: Callback
A PyTorch Lightning callback that monitors and logs gradient statistics during training.
Gradients are captured in on_before_optimizer_step, i.e. before the Trainer applies gradient clipping, so all metrics reflect the true, unclipped gradients.
- Features:
Logs global gradient norm across all parameters
Optionally logs per-layer gradient norms
Computes mean and standard deviation of gradients
Measures gradient sparsity (fraction of near-zero values)
Detects potential exploding gradients via a configurable threshold
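The per-step features listed above can be sketched in plain Python, with gradients represented as flat lists of floats rather than tensors so the arithmetic is easy to follow. This is an illustrative sketch, not the callback's implementation: batch_grad_stats is a hypothetical name, and the near-zero cutoff sparsity_eps, the population (rather than sample) standard deviation, the per-layer key format, and every metric name except grad_norm and grad_mean are assumptions.

```python
import math


def batch_grad_stats(layer_grads, sparsity_eps=1e-8, explosion_threshold=1e4):
    """Return (metrics, exploding) for one optimizer step (hypothetical sketch).

    layer_grads maps a parameter name to a flat list of gradient values, or to
    None for a parameter without a gradient; such parameters are skipped.
    """
    grads = {name: g for name, g in layer_grads.items() if g is not None}
    values = [v for g in grads.values() for v in g]
    if not values:
        return {}, False  # no gradients at all, e.g. a frozen model
    # Per-layer L2 norms; the global norm is the L2 norm of those norms.
    layer_norms = {name: math.sqrt(sum(v * v for v in g)) for name, g in grads.items()}
    global_norm = math.sqrt(sum(n * n for n in layer_norms.values()))
    mean = sum(values) / len(values)
    # Population standard deviation over every gradient element (assumption).
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    # Fraction of elements whose magnitude is below the near-zero cutoff.
    sparsity = sum(abs(v) < sparsity_eps for v in values) / len(values)
    metrics = {
        "grad_norm": global_norm,
        "grad_mean": mean,
        "grad_std": std,  # hypothetical metric name
        "grad_sparsity": sparsity,  # hypothetical metric name
    }
    for name, norm in layer_norms.items():
        metrics[f"grad_norm/{name}"] = norm  # hypothetical per-layer key format
    exploding = explosion_threshold is not None and global_norm > explosion_threshold
    return metrics, exploding
```

Taking the L2 norm of the per-layer norms gives the same value as the L2 norm over all gradient elements at once, so the global norm comes for free once per-layer norms are computed.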
- Logging Behavior:
Per-step metrics are logged under train/ every log_every_n_steps global steps (e.g. train/grad_norm, train/grad_mean).
Per-epoch metrics are logged at the end of every epoch under train_epoch/ (e.g. train_epoch/grad_norm), aggregated over all optimizer steps regardless of log_every_n_steps.
Logging is performed only on global rank 0 (for distributed training safety).
Metrics are logged through Lightning’s log_dict for compatibility with all supported loggers.
The epoch accumulator and step counter are saved in checkpoints via state_dict/load_state_dict, so epoch aggregates remain correct after a mid-epoch resume.
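The split between step-gated and every-step work can be sketched in plain Python. This is a minimal illustration, assuming the callback applies the condition global_step % log_every_n_steps == 0 from the parameter description together with the rank-0 rule; should_log_step and simulate_epoch are hypothetical names, not part of the Lightning API, and the "norms" are stand-in numbers.

```python
def should_log_step(global_step, log_every_n_steps, global_rank):
    """Per-step logging rule: every n-th global step, on global rank 0 only."""
    if log_every_n_steps == 0:
        return False  # 0 disables per-step logging entirely
    if global_rank != 0:
        return False  # non-zero ranks never log
    return global_step % log_every_n_steps == 0


def simulate_epoch(num_steps, log_every_n_steps):
    """Return (steps that emit per-step metrics, epoch mean of all step norms) on rank 0."""
    logged_steps = []
    step_norms = []
    for global_step in range(num_steps):
        norm = float(global_step + 1)  # stand-in for a real gradient norm
        step_norms.append(norm)  # the epoch aggregate sees every optimizer step
        if should_log_step(global_step, log_every_n_steps, global_rank=0):
            logged_steps.append(global_step)
    return logged_steps, sum(step_norms) / len(step_norms)
```

With ten steps and log_every_n_steps=4, per-step metrics are emitted only at steps 0, 4 and 8, while the epoch mean is computed over all ten steps; setting log_every_n_steps=0 removes the per-step output but leaves the epoch mean unchanged.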
- Subclassing:
Override any of the following to customise what is computed or logged:
step_prefix: property that controls the per-step metric namespace (default "train/")
epoch_prefix: property that controls the per-epoch metric namespace (default "train_epoch/")
compute_batch_stats: metrics logged after each optimizer step
init_epoch_stats: initial accumulator state at the start of each epoch
update_epoch_stats: how each step updates the accumulator
compute_epoch_stats: metrics logged at the end of each epoch
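The override contract for the three epoch hooks can be illustrated without a Trainer. Because the accumulator's internal fields and the type of layer_grads are not documented here, the sketch below uses a hypothetical StandInMonitor base (not the real class) with assumed fields and plain lists of floats as gradients; only the hook names, chaining through super(), and returning None for an empty epoch reflect the description above. With the real callback, a subclass would presumably derive from GradientStatsMonitor and chain to its hooks in the same way.

```python
import math


def global_norm(layer_grads):
    # L2 norm over every value of every gradient that is not None.
    return math.sqrt(
        sum(v * v for g in layer_grads.values() if g is not None for v in g)
    )


class StandInMonitor:
    """Hypothetical stand-in exposing the documented epoch hooks.

    The accumulator fields ("steps", "norm_sum") are assumptions.
    """

    @property
    def epoch_prefix(self):
        return "train_epoch/"

    def init_epoch_stats(self):
        return {"steps": 0, "norm_sum": 0.0}

    def update_epoch_stats(self, state, layer_grads):
        state["steps"] += 1
        state["norm_sum"] += global_norm(layer_grads)

    def compute_epoch_stats(self, state):
        if state["steps"] == 0:
            return None  # no optimizer steps were recorded this epoch
        # Mean of the per-step global norms.
        return {f"{self.epoch_prefix}grad_norm": state["norm_sum"] / state["steps"]}


class MaxNormMonitor(StandInMonitor):
    """Adds the largest per-step global norm seen during the epoch."""

    def init_epoch_stats(self):
        state = super().init_epoch_stats()
        state["norm_max"] = 0.0  # extra field used by the two overrides below
        return state

    def update_epoch_stats(self, state, layer_grads):
        super().update_epoch_stats(state, layer_grads)
        state["norm_max"] = max(state["norm_max"], global_norm(layer_grads))

    def compute_epoch_stats(self, state):
        metrics = super().compute_epoch_stats(state)
        if metrics is not None:
            metrics[f"{self.epoch_prefix}grad_norm_max"] = state["norm_max"]
        return metrics
```

Adding a field in init_epoch_stats, updating it in update_epoch_stats, and reading it in compute_epoch_stats keeps the three overrides consistent with each other while leaving the base metrics intact.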
- Parameters:
log_every_n_steps (int) – Frequency (in global steps) at which per-step gradient statistics are logged, i.e. when trainer.global_step % log_every_n_steps == 0. Set to 0 to disable per-step logging entirely (epoch logging is unaffected).
track_epochs (bool) – If True, logs gradient statistics aggregated over each full epoch.
per_layer (bool) – If True, logs gradient norms for each parameter individually. Parameter names are formatted to be compatible with hierarchical loggers.
track_sparsity (bool) – If True, logs the fraction of gradients that are near zero (useful for detecting dead neurons or sparse updates).
explosion_threshold (float | None) – Threshold for the global gradient norm above which a warning is raised. Operates on pre-clip gradients, so it fires even when gradient_clip_val is set. Set to None to disable.
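Why the threshold still fires with gradient_clip_val set comes down to ordering: the norm is checked first, and clipping happens afterwards. The sketch below illustrates that ordering only and is not Lightning's clipping code: clip_by_norm is a hypothetical, simplified stand-in for norm-based clipping (as with gradient_clip_algorithm="norm"), and warnings.warn stands in for however the callback actually raises its warning.

```python
import math
import warnings


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def clip_by_norm(grads, max_norm):
    # Rescale so the L2 norm is at most max_norm (simplified norm-based clipping).
    norm = l2_norm(grads)
    if norm <= max_norm:
        return list(grads)
    return [v * (max_norm / norm) for v in grads]


def check_then_clip(grads, explosion_threshold, gradient_clip_val):
    """Return (clipped gradients, whether the explosion check fired)."""
    # 1) The explosion check sees the raw, pre-clip gradients.
    norm = l2_norm(grads)
    exploded = explosion_threshold is not None and norm > explosion_threshold
    if exploded:
        warnings.warn(
            f"gradient norm {norm:.3g} exceeds explosion_threshold={explosion_threshold}"
        )
    # 2) Clipping runs afterwards, so it cannot hide the spike from the check.
    return clip_by_norm(grads, gradient_clip_val), exploded
```

If the order were reversed, a gradient with norm 50000 clipped to norm 1.0 would never cross a threshold of 10000, and the spike would go unreported.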
Notes
With multiple optimizers, only the first on_before_optimizer_step call per global step is processed; subsequent calls for the same step are skipped.
Parameters with grad=None are safely ignored.
If no gradients are available (e.g. a frozen model, or a forward pass run under torch.no_grad()), the callback returns without logging anything.
The callback is designed to be lightweight and not to interfere with the training loop.
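The "first call per global step" rule amounts to remembering the last step that was processed. A minimal sketch, assuming the callback compares trainer.global_step against the last step it handled; FirstCallPerStep is a hypothetical name, not part of Lightning.

```python
class FirstCallPerStep:
    """Hypothetical helper: accept only the first hook call for each global step."""

    def __init__(self):
        self._last_step = None  # global step of the most recently accepted call

    def accept(self, global_step):
        if global_step == self._last_step:
            return False  # another optimizer in the same global step: skip it
        self._last_step = global_step
        return True


# Two optimizers reporting the same global step, e.g. a generator and a discriminator.
gate = FirstCallPerStep()
calls = [(0, "gen"), (0, "disc"), (1, "gen"), (1, "disc")]
processed = [(step, name) for step, name in calls if gate.accept(step)]
```

Only the first optimizer's call in each step is processed, so each global step contributes exactly once to both the per-step metrics and the epoch accumulator.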
- compute_batch_stats(layer_grads)
Compute and return the metric dict logged after each optimizer step.
The returned dict is passed directly to pl_module.log_dict. Override to add, remove, or rename metrics.
- compute_epoch_stats(state)
Compute and return the metric dict logged at the end of each epoch.
- Parameters:
state (dict[str, Any]) – the accumulator produced by init_epoch_stats and updated by update_epoch_stats.
- Returns:
None if no steps were recorded (e.g. an empty dataloader); otherwise the metric dict, which is passed directly to pl_module.log_dict. Override to add, remove, or rename metrics, or to derive additional values from extra state added in init_epoch_stats/update_epoch_stats.
Note
{epoch_prefix}grad_norm is the mean of the per-step global norms (i.e. mean(‖g_t‖₂) over optimizer steps t), not the true L2 norm of all gradients accumulated over the epoch. The same applies to the per-layer norm averages.
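A two-step toy example shows how far apart the two quantities can be. Gradients are plain lists here, and "accumulated" is read as the element-wise sum over steps, an interpretation made for this illustration.

```python
import math


def l2(vec):
    return math.sqrt(sum(x * x for x in vec))


# Two optimizer steps whose gradients point in opposite directions.
step_grads = [[3.0, 4.0], [-3.0, -4.0]]

# What {epoch_prefix}grad_norm reports: the mean of the per-step global norms.
mean_of_norms = sum(l2(g) for g in step_grads) / len(step_grads)

# The L2 norm of the gradient summed over the epoch, which is NOT what is logged.
epoch_sum = [sum(column) for column in zip(*step_grads)]
norm_of_sum = l2(epoch_sum)

print(mean_of_norms, norm_of_sum)  # prints: 5.0 0.0
```

Each step has norm 5, so the logged epoch value is 5.0, even though the two updates cancel exactly and their sum has norm 0.0. The epoch metric therefore describes typical step size, not net movement over the epoch.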
- init_epoch_stats()
Return a fresh accumulator for the start of an epoch.
Override to add extra fields that update_epoch_stats and compute_epoch_stats can then use.
- load_state_dict(state_dict)
Called when loading a checkpoint; reloads the callback’s state (the epoch accumulator and step counter) from the given state_dict.
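The mid-epoch resume guarantee follows from saving and restoring the accumulator and the step counter together. The sketch below is hypothetical: ResumableEpochStats, its field names, and its update method are illustrative, and the real callback's state_dict layout is not documented here.

```python
import copy


class ResumableEpochStats:
    """Hypothetical checkpointable state: epoch accumulator plus step counter."""

    def __init__(self):
        self.epoch_state = {"steps": 0, "norm_sum": 0.0}
        self.last_step = None

    def update(self, global_step, norm):
        self.epoch_state["steps"] += 1
        self.epoch_state["norm_sum"] += norm
        self.last_step = global_step

    def state_dict(self):
        # Copy so that later updates cannot mutate an already-saved checkpoint.
        return {"epoch_state": copy.deepcopy(self.epoch_state), "last_step": self.last_step}

    def load_state_dict(self, state_dict):
        self.epoch_state = copy.deepcopy(state_dict["epoch_state"])
        self.last_step = state_dict["last_step"]

    def epoch_mean(self):
        return self.epoch_state["norm_sum"] / self.epoch_state["steps"]


# Two steps, a mid-epoch checkpoint, then a resumed process finishes the epoch.
before = ResumableEpochStats()
before.update(0, 2.0)
before.update(1, 4.0)
checkpoint = before.state_dict()

resumed = ResumableEpochStats()
resumed.load_state_dict(checkpoint)
resumed.update(2, 6.0)
```

After resuming, the epoch mean covers all three steps (4.0); without restoring the accumulator, the resumed process would report 6.0 based on the final step alone.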
- on_before_optimizer_step(trainer, pl_module, optimizer)
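Called before optimizer.step() and before the Trainer applies gradient clipping; this is where the callback captures gradient statistics.
- Return type:
None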
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...  # compute the loss for this batch
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()

- Return type:
None
- update_epoch_stats(state, layer_grads)
Update the epoch accumulator in-place with one step’s gradients.
Override to accumulate additional fields introduced in init_epoch_stats.
- Return type:
None
- property epoch_prefix: str
Metric prefix used for per-epoch stats (e.g. "train_epoch/" → train_epoch/grad_norm).
- property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary via checkpoint["callbacks"][state_key]. A callback implementation needs to provide a unique state key if (1) the callback has state and (2) the state of multiple instances of that callback must be kept separately.
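The collision that a unique state key avoids can be shown with plain dictionaries shaped like checkpoint["callbacks"]. This is a hypothetical sketch: make_state_key and its "class name plus sorted constructor arguments" scheme are illustrative assumptions, the stored values stand in for real callback state, and the actual state_key of GradientStatsMonitor is not documented here.

```python
def make_state_key(class_name, **init_kwargs):
    # Hypothetical key scheme: class name plus the sorted constructor arguments.
    return f"{class_name}{sorted(init_kwargs.items())}"


configs = [{"log_every_n_steps": 50}, {"log_every_n_steps": 10}]

# Keying every instance by class name alone: the second entry overwrites the first.
naive = {"callbacks": {}}
for cfg in configs:
    naive["callbacks"]["GradientStatsMonitor"] = dict(cfg)

# Keying by class name plus configuration keeps one entry per instance.
checkpoint = {"callbacks": {}}
for cfg in configs:
    checkpoint["callbacks"][make_state_key("GradientStatsMonitor", **cfg)] = dict(cfg)
```

With the naive key only one of the two monitors' states survives in the checkpoint; including the configuration in the key lets each instance find its own entry again on load, as long as it is constructed with the same arguments.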