GradientStatsMonitor

class lightning.pytorch.callbacks.GradientStatsMonitor(log_every_n_steps=50, track_epochs=True, per_layer=False, track_sparsity=True, explosion_threshold=10000.0)

Bases: Callback

A PyTorch Lightning callback that monitors and logs gradient statistics during training.

Gradients are captured in on_before_optimizer_step, i.e. before the Trainer applies gradient clipping, so all metrics reflect true unclipped gradients.

Features:
  • Logs global gradient norm across all parameters

  • Optionally logs per-layer gradient norms

  • Computes mean and standard deviation of gradients

  • Measures gradient sparsity (fraction of near-zero values)

  • Detects potential exploding gradients via a configurable threshold
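The statistics listed above can be sketched without any framework. This is a minimal illustration, not the callback's implementation. Gradients are plain Python lists standing in for tensors, and several details are assumptions: the function name gradient_stats, the metric keys grad_std and grad_sparsity, the per-layer key format, the use of a population standard deviation, and the near-zero cutoff eps.

```python
import math
import warnings
from typing import Optional


def gradient_stats(
    layer_grads: dict[str, Optional[list[float]]],
    per_layer: bool = False,
    track_sparsity: bool = True,
    eps: float = 1e-8,  # hypothetical "near zero" cutoff
    explosion_threshold: Optional[float] = 10000.0,
) -> dict[str, float]:
    """Hypothetical sketch: gradient statistics from flattened gradients.

    Each value is one parameter's flattened gradient, or None when the
    parameter's .grad is None.
    """
    # Parameters without gradients are ignored.
    grads = {name: g for name, g in layer_grads.items() if g is not None}
    values = [v for g in grads.values() for v in g]
    if not values:
        return {}  # no gradients available: nothing to report

    n = len(values)
    # Global L2 norm over every element of every parameter's gradient.
    global_norm = math.sqrt(sum(v * v for v in values))
    mean = sum(values) / n
    # Population standard deviation (assumption; the source does not specify).
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)

    stats = {"grad_norm": global_norm, "grad_mean": mean, "grad_std": std}
    if track_sparsity:
        # Fraction of elements whose magnitude is below eps.
        stats["grad_sparsity"] = sum(1 for v in values if abs(v) < eps) / n
    if per_layer:
        for name, g in grads.items():
            stats[f"grad_norm/{name}"] = math.sqrt(sum(v * v for v in g))

    if explosion_threshold is not None and global_norm > explosion_threshold:
        warnings.warn(f"Gradient norm {global_norm:.3g} exceeds {explosion_threshold:.3g}")
    return stats


stats = gradient_stats(
    {"layer.weight": [3.0, 0.0], "layer.bias": [4.0], "frozen.weight": None},
    per_layer=True,
)
print(stats["grad_norm"])  # 5.0
```

Here one of the three gradient elements is exactly zero, so grad_sparsity is 1/3, and the frozen parameter contributes nothing.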

Logging Behavior:
  • Per-step metrics are logged under train/ every log_every_n_steps global steps (e.g. train/grad_norm, train/grad_mean).

  • When track_epochs is True, per-epoch metrics are logged at the end of every epoch under train_epoch/ (e.g. train_epoch/grad_norm), aggregated over all optimizer steps regardless of log_every_n_steps.

  • Logging is performed only on global rank 0 (for distributed training safety).

  • Uses Lightning’s log_dict for compatibility with all supported loggers.

  • The epoch accumulator and step counter are saved in checkpoints via state_dict / load_state_dict, so epoch aggregates remain correct after a mid-epoch resume.
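The logging cadence described above can be modelled in a few lines. This is a sketch under stated assumptions, not the callback's code: should_log_step and simulate_epoch are hypothetical helpers, each step is reduced to a single precomputed norm, and the is_global_zero flag plays the role of Lightning's trainer.is_global_zero.

```python
def should_log_step(global_step: int, log_every_n_steps: int) -> bool:
    """Per-step cadence: log on multiples of log_every_n_steps; 0 disables it."""
    if log_every_n_steps == 0:
        return False
    return global_step % log_every_n_steps == 0


def simulate_epoch(
    step_norms: list[float],
    first_step: int,
    log_every_n_steps: int,
    is_global_zero: bool = True,
    track_epochs: bool = True,
) -> list[tuple[object, dict[str, float]]]:
    """Hypothetical sketch of the metric dicts one epoch would log."""
    logged: list[tuple[object, dict[str, float]]] = []
    norm_sum = 0.0
    for offset, norm in enumerate(step_norms):
        step = first_step + offset
        norm_sum += norm  # the epoch accumulator sees every optimizer step
        if is_global_zero and should_log_step(step, log_every_n_steps):
            logged.append((step, {"train/grad_norm": norm}))
    if track_epochs and is_global_zero and step_norms:
        logged.append(("epoch_end", {"train_epoch/grad_norm": norm_sum / len(step_norms)}))
    return logged


print(simulate_epoch([1.0, 2.0, 3.0, 4.0], first_step=0, log_every_n_steps=2))
```

With log_every_n_steps=2, only steps 0 and 2 produce per-step entries, while the epoch entry still averages all four steps (2.5). Setting log_every_n_steps=0 leaves only the epoch entry.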

Subclassing:

Override any of the following to customise what is computed or logged:

  • step_prefix — property that controls the per-step metric namespace (default "train/")

  • epoch_prefix — property that controls the per-epoch metric namespace (default "train_epoch/")

  • compute_batch_stats — per-step metrics, logged every log_every_n_steps global steps

  • init_epoch_stats — initial accumulator state at the start of each epoch

  • update_epoch_stats — how each step updates the accumulator

  • compute_epoch_stats — metrics logged at the end of each epoch
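The three epoch hooks form an init / update / compute contract. In practice they would be overridden on a GradientStatsMonitor subclass; to keep this sketch runnable without Lightning, a hypothetical stand-in base class (not the real callback) models that contract. Its accumulator fields steps and norm_sum are assumptions, and a plain class attribute stands in for the epoch_prefix property. The subclass adds one extra field and reports the largest per-step norm of the epoch.

```python
import math
from typing import Any, Optional


def _global_norm(layer_grads: dict[str, Optional[list[float]]]) -> float:
    # L2 norm over all non-None gradients.
    return math.sqrt(sum(v * v for g in layer_grads.values() if g is not None for v in g))


class EpochHooksStandIn:
    """Hypothetical stand-in for the documented epoch hooks (NOT the real class)."""

    epoch_prefix = "train_epoch/"

    def init_epoch_stats(self) -> dict[str, Any]:
        return {"steps": 0, "norm_sum": 0.0}  # field names are assumptions

    def update_epoch_stats(self, state: dict[str, Any], layer_grads) -> None:
        state["steps"] += 1
        state["norm_sum"] += _global_norm(layer_grads)

    def compute_epoch_stats(self, state: dict[str, Any]) -> Optional[dict[str, float]]:
        if state["steps"] == 0:
            return None  # no optimizer steps were recorded
        return {f"{self.epoch_prefix}grad_norm": state["norm_sum"] / state["steps"]}


class MaxNormMonitor(EpochHooksStandIn):
    """Extends all three hooks to also report the largest per-step norm."""

    def init_epoch_stats(self) -> dict[str, Any]:
        state = super().init_epoch_stats()
        state["norm_max"] = 0.0  # extra field introduced by the subclass
        return state

    def update_epoch_stats(self, state: dict[str, Any], layer_grads) -> None:
        super().update_epoch_stats(state, layer_grads)
        state["norm_max"] = max(state["norm_max"], _global_norm(layer_grads))

    def compute_epoch_stats(self, state: dict[str, Any]) -> Optional[dict[str, float]]:
        metrics = super().compute_epoch_stats(state)
        if metrics is not None:
            metrics[f"{self.epoch_prefix}grad_norm_max"] = state["norm_max"]
        return metrics


monitor = MaxNormMonitor()
state = monitor.init_epoch_stats()
monitor.update_epoch_stats(state, {"w": [3.0, 4.0]})
monitor.update_epoch_stats(state, {"w": [6.0, 8.0], "b": None})
print(monitor.compute_epoch_stats(state))
# {'train_epoch/grad_norm': 7.5, 'train_epoch/grad_norm_max': 10.0}
```

Calling super() in each override keeps the base metrics intact, so the subclass only adds what it introduces.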

Parameters:
  • log_every_n_steps (int) – Frequency (in global steps) at which per-step gradient statistics are logged, i.e. when trainer.global_step % log_every_n_steps == 0. Set to 0 to disable per-step logging entirely (epoch logging unaffected).

  • track_epochs (bool) – If True, logs gradient statistics aggregated over each full epoch.

  • per_layer (bool) – If True, logs gradient norms for each parameter individually. Parameter names are formatted to be compatible with hierarchical loggers.

  • track_sparsity (bool) – If True, logs the fraction of gradients that are near zero (useful for detecting dead neurons or sparse updates).

  • explosion_threshold (float | None) – Threshold for the global gradient norm above which a warning is raised. Operates on pre-clip gradients, so it fires even when gradient_clip_val is set. Set to None to disable.
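Because explosion_threshold is checked on pre-clip gradients, clipping cannot hide a spike. The sketch below illustrates why the order matters. clip_by_global_norm is a hypothetical, simplified stand-in for norm-based clipping, check_explosion is a hypothetical helper, and gradients are plain lists.

```python
import math
import warnings
from typing import Optional


def global_norm(grads: list[float]) -> float:
    return math.sqrt(sum(v * v for v in grads))


def clip_by_global_norm(grads: list[float], max_norm: float) -> list[float]:
    """Simplified norm-based clipping: rescale so the global norm is at most max_norm."""
    norm = global_norm(grads)
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [v * scale for v in grads]


def check_explosion(grads: list[float], explosion_threshold: Optional[float]) -> bool:
    """Warn and return True if the global norm exceeds the threshold; None disables."""
    if explosion_threshold is None:
        return False
    norm = global_norm(grads)
    if norm > explosion_threshold:
        warnings.warn(f"Exploding gradients: norm {norm:.3g} > {explosion_threshold:.3g}")
        return True
    return False


raw = [30000.0, 40000.0]                          # pre-clip global norm 50000
clipped = clip_by_global_norm(raw, max_norm=1.0)  # e.g. with gradient_clip_val=1.0
print(check_explosion(raw, 10000.0))      # True: the pre-clip check fires
print(check_explosion(clipped, 10000.0))  # False: a post-clip check would miss it
```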

Notes

  • With multiple optimizers, only the first on_before_optimizer_step call per global step is processed; subsequent calls for the same step are skipped.

  • Parameters with grad=None are safely ignored.

  • If no gradients are available (e.g., frozen model or inside no_grad), the callback returns without logging anything.

  • Designed to be lightweight and not interfere with the training loop.
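The rules in these notes can be sketched as a small guard. This is a hypothetical illustration rather than the callback's code: the class OncePerStepGuard and its field last_step are invented names, and gradients are plain lists, with None standing for grad=None.

```python
from typing import Optional


class OncePerStepGuard:
    """Hypothetical sketch of the documented notes; not the callback's code."""

    def __init__(self) -> None:
        self.last_step = -1  # hypothetical name for the last processed global step
        self.processed: list[tuple[int, list[str]]] = []

    def on_before_optimizer_step(
        self, global_step: int, layer_grads: dict[str, Optional[list[float]]]
    ) -> None:
        # With several optimizers the hook runs once per optimizer;
        # only the first call for a given global step is handled.
        if global_step == self.last_step:
            return
        self.last_step = global_step
        # Parameters whose grad is None are ignored.
        names = sorted(name for name, g in layer_grads.items() if g is not None)
        if not names:
            return  # no gradients at all: return without logging
        self.processed.append((global_step, names))


guard = OncePerStepGuard()
guard.on_before_optimizer_step(0, {"gen.w": [0.1], "disc.w": None})  # first optimizer
guard.on_before_optimizer_step(0, {"disc.w": [0.2]})  # second optimizer, same step: skipped
guard.on_before_optimizer_step(1, {"gen.w": None})    # no gradients: nothing recorded
guard.on_before_optimizer_step(2, {"gen.w": [0.3]})
print(guard.processed)  # [(0, ['gen.w']), (2, ['gen.w'])]
```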

compute_batch_stats(layer_grads)

Compute and return the metric dict logged as per-step statistics (every log_every_n_steps global steps).

The returned dict is passed directly to pl_module.log_dict. Override to add, remove, or rename metrics.

Return type:

dict[str, float]

compute_epoch_stats(state)

Compute and return the metric dict logged at the end of each epoch.

Parameters:

state (dict[str, Any]) – the accumulator produced by init_epoch_stats and updated by update_epoch_stats.

Return type:

Optional[dict[str, float]]

Returns None if no steps were recorded (empty dataloader). The returned dict is passed directly to pl_module.log_dict. Override to add, remove, or rename metrics, or to derive additional values from extra state added in init_epoch_stats / update_epoch_stats.

Note

{epoch_prefix}grad_norm is the mean of per-step global norms (i.e. mean(‖g_t‖₂) over optimizer steps t), not the true L2 norm of all gradients accumulated over the epoch. The same applies to per-layer norm averages.
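The distinction matters when step gradients point in different directions. The sketch below uses plain Python and hypothetical helper names (l2, epoch_norms). It compares the reported quantity, the mean of per-step norms, with the norm of the gradients summed over the epoch.

```python
import math


def l2(vec: list[float]) -> float:
    return math.sqrt(sum(x * x for x in vec))


def epoch_norms(step_grads: list[list[float]]) -> tuple[float, float]:
    """Return (mean of per-step norms, norm of the summed gradient)."""
    mean_of_norms = sum(l2(g) for g in step_grads) / len(step_grads)
    summed = [sum(column) for column in zip(*step_grads)]
    return mean_of_norms, l2(summed)


# Two steps whose gradients point in opposite directions.
print(epoch_norms([[1.0, 0.0], [-1.0, 0.0]]))  # (1.0, 0.0)
```

Opposing updates cancel in the summed gradient but not in the mean of norms, so {epoch_prefix}grad_norm measures typical step magnitude, not net movement over the epoch.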

init_epoch_stats()

Return a fresh accumulator for the start of an epoch.

Override to add extra fields that update_epoch_stats and compute_epoch_stats can then use.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Called when loading a checkpoint; restores the epoch accumulator and step counter from the callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_before_optimizer_step(trainer, pl_module, optimizer)

Called before optimizer.step(). This is where the callback captures gradients, before the Trainer applies gradient clipping.

Return type:

None

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type:

None

state_dict()

Called when saving a checkpoint; returns the callback’s state, including the epoch accumulator and step counter.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
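The point of saving the accumulator is that a run resumed mid-epoch ends up with the same epoch aggregate as an uninterrupted one. The sketch below checks that property with a hypothetical stand-in class. The real checkpoint layout is not documented here, so the field names epoch_state and last_step, the record helper, and the "GradientStatsMonitor" key (a placeholder for state_key) are all assumptions.

```python
import copy
from typing import Any


class ResumableAccumulator:
    """Hypothetical stand-in; the real checkpoint layout is not documented."""

    def __init__(self) -> None:
        self.epoch_state: dict[str, Any] = {"steps": 0, "norm_sum": 0.0}
        self.last_step = -1

    def record(self, global_step: int, norm: float) -> None:
        self.last_step = global_step
        self.epoch_state["steps"] += 1
        self.epoch_state["norm_sum"] += norm

    def state_dict(self) -> dict[str, Any]:
        # Copy so later updates cannot mutate an already-saved checkpoint.
        return {"epoch_state": copy.deepcopy(self.epoch_state), "last_step": self.last_step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.epoch_state = copy.deepcopy(state_dict["epoch_state"])
        self.last_step = state_dict["last_step"]


norms = [1.0, 2.0, 3.0, 4.0]

# Uninterrupted run over one epoch.
full = ResumableAccumulator()
for step, norm in enumerate(norms):
    full.record(step, norm)

# Same epoch, interrupted after two steps and resumed from a checkpoint.
before = ResumableAccumulator()
for step, norm in enumerate(norms[:2]):
    before.record(step, norm)
checkpoint = {"callbacks": {"GradientStatsMonitor": before.state_dict()}}  # placeholder key

after = ResumableAccumulator()
after.load_state_dict(checkpoint["callbacks"]["GradientStatsMonitor"])
for step, norm in enumerate(norms[2:], start=2):
    after.record(step, norm)

print(after.state_dict() == full.state_dict())  # True
```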

update_epoch_stats(state, layer_grads)

Update the epoch accumulator in-place with one step’s gradients.

Override to accumulate additional fields introduced in init_epoch_stats.

Return type:

None

property epoch_prefix: str

Metric prefix used for per-epoch stats (default "train_epoch/", e.g. train_epoch/grad_norm).

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.

property step_prefix: str

Metric prefix used for per-step stats (default "train/", e.g. train/grad_norm).
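Since both prefixes are read-only properties, moving all gradient metrics into a separate logger namespace should only require overriding the two getters. This is a sketch under that assumption. PrefixStandIn is a hypothetical stand-in with the documented defaults (not the real class), and name_step_metrics / name_epoch_metrics are hypothetical helpers showing how a prefix combines with a metric key.

```python
class PrefixStandIn:
    """Hypothetical stand-in with the two documented prefix properties."""

    @property
    def step_prefix(self) -> str:
        return "train/"

    @property
    def epoch_prefix(self) -> str:
        return "train_epoch/"

    def name_step_metrics(self, metrics: dict[str, float]) -> dict[str, float]:
        # Hypothetical helper: prepend the per-step namespace to each key.
        return {f"{self.step_prefix}{key}": value for key, value in metrics.items()}

    def name_epoch_metrics(self, metrics: dict[str, float]) -> dict[str, float]:
        # Hypothetical helper: prepend the per-epoch namespace to each key.
        return {f"{self.epoch_prefix}{key}": value for key, value in metrics.items()}


class GradientsNamespace(PrefixStandIn):
    """Moves all metrics under a dedicated "gradients/" group."""

    @property
    def step_prefix(self) -> str:
        return "gradients/step/"

    @property
    def epoch_prefix(self) -> str:
        return "gradients/epoch/"


print(GradientsNamespace().name_step_metrics({"grad_norm": 5.0}))
# {'gradients/step/grad_norm': 5.0}
```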