
Training System

The high-level Trainer and TrainingArguments API orchestrates epoch loops, learning-rate (LR) scheduling, early stopping, gradient clipping, checkpointing, and callbacks for Sequential models.

Trainer

Pml\Training\Trainer is the high-level training orchestrator. It wraps a Sequential model and a TrainingArguments config, runs the epoch loop, steps the LR scheduler, handles early stopping, saves checkpoints, and fires callbacks at each lifecycle event.

__construct (Sequential $model, TrainingArguments $args)

Binds the model and args. If the model's optimizer implements LearningRateAware, the LR is synced to $args->learningRate immediately.

train (Dataset $dataset, ?Dataset $validation = null): TrainingResult

Runs the full training loop. Returns a TrainingResult with loss history and best-epoch metadata.

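use Pml\Training\{Trainer, TrainingArguments};

// Assumes $model (a Sequential), $trainDs and $valDs (Dataset instances) exist,
// and that MyCallback is your own TrainerCallback implementation.
$args = new TrainingArguments(
    epochs:       30,
    batchSize:    128,
    learningRate: 3e-4,
    lrSchedule:   'cosine',
    patience:     5,
    outputDir:    'ckpt/run1',
);

$trainer = new Trainer($model, $args);
$trainer->addCallback(new MyCallback());
$result = $trainer->train($trainDs, $valDs);

addCallback (TrainerCallback $callback): void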

Registers a callback. Multiple callbacks can be added; they fire in registration order.
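To make the dispatch order concrete, here is a minimal sketch of how a trainer can fan lifecycle events out to registered callbacks. MiniTrainer, LifecycleCallback and RecordingCallback are hypothetical names, and the single on() method is a simplification of the five typed hooks in TrainerCallback; the sketch only assumes what this page states: hooks fire at each lifecycle event, in registration order.

```php
<?php
declare(strict_types=1);

// Hypothetical, simplified callback contract: one method receives every event name.
interface LifecycleCallback
{
    public function on(string $event, array $payload = []): void;
}

// Records "<name>:<event>" into a shared log so the firing order can be inspected.
final class RecordingCallback implements LifecycleCallback
{
    public function __construct(private string $name, private ArrayObject $log) {}

    public function on(string $event, array $payload = []): void
    {
        $this->log->append("{$this->name}:{$event}");
    }
}

final class MiniTrainer
{
    /** @var list<LifecycleCallback> */
    private array $callbacks = [];

    public function addCallback(LifecycleCallback $callback): void
    {
        // Appending keeps dispatch order identical to registration order.
        $this->callbacks[] = $callback;
    }

    private function fire(string $event, array $payload = []): void
    {
        foreach ($this->callbacks as $callback) {
            $callback->on($event, $payload);
        }
    }

    public function train(int $epochs, int $batchesPerEpoch): void
    {
        $this->fire('onTrainBegin');
        $step = 0;
        for ($epoch = 0; $epoch < $epochs; $epoch++) {
            $this->fire('onEpochBegin', ['epoch' => $epoch, 'epochs' => $epochs]);
            for ($b = 0; $b < $batchesPerEpoch; $b++) {
                // A real loop would run forward, backward and the optimizer step here.
                $this->fire('onBatchEnd', ['step' => $step++]);
            }
            $this->fire('onEpochEnd', ['epoch' => $epoch]);
        }
        $this->fire('onTrainEnd');
    }
}
```

With callbacks A and B registered in that order, every event reaches A before B, and the per-batch hook fires once per batch between each epoch's begin and end hooks.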

TrainingArguments

Value object holding all training hyperparameters. All properties are readonly.

| Property | Type | Default | Description |
|---|---|---|---|
| epochs | int | 10 | Training epochs |
| batchSize | int | 32 | Mini-batch size |
| patience | int | 0 | Early stopping patience (0 = disabled) |
| minDelta | float | 1e-4 | Minimum improvement to reset the patience counter |
| learningRate | float | 0.001 | Initial learning rate |
| lrSchedule | string | 'none' | 'none', 'cosine', 'step' or 'linear' |
| lrDecay | float | 0.1 | Decay factor for the 'step' schedule |
| lrStepSize | int | 5 | Epochs between LR drops ('step' schedule) |
| warmupEpochs | int | 0 | Linear warm-up epochs from 0 → learningRate |
| mixedPrecision | bool | false | Scaffold for future AMP support |
| outputDir | ?string | null | Directory for checkpoints (null = no auto-save) |

$args = new TrainingArguments(
    epochs:        100,
    batchSize:     64,
    learningRate:  1e-3,
    lrSchedule:    'cosine',
    warmupEpochs:  5,
    patience:      10,
    outputDir:     'ckpt/mymodel',
);
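A readonly value object like this can be sketched with PHP 8.1 promoted readonly properties. TrainingArgsSketch and its with() helper are hypothetical, not part of PML; the defaults are copied from the table above. Because readonly properties cannot be reassigned after construction, the usual way to vary one setting is to build a modified copy.

```php
<?php
declare(strict_types=1);

// Hypothetical mirror of TrainingArguments' documented defaults (not PML's class).
final class TrainingArgsSketch
{
    public function __construct(
        public readonly int $epochs = 10,
        public readonly int $batchSize = 32,
        public readonly int $patience = 0,
        public readonly float $minDelta = 1e-4,
        public readonly float $learningRate = 0.001,
        public readonly string $lrSchedule = 'none',
        public readonly float $lrDecay = 0.1,
        public readonly int $lrStepSize = 5,
        public readonly int $warmupEpochs = 0,
        public readonly bool $mixedPrecision = false,
        public readonly ?string $outputDir = null,
    ) {}

    /** Returns a copy with some values replaced; the original stays untouched. */
    public function with(array $overrides): self
    {
        // String keys are spread as named arguments (PHP 8.1+).
        return new self(...array_merge(get_object_vars($this), $overrides));
    }
}
```

Assigning to a property from outside, such as `$args->epochs = 5`, throws an Error, which is what makes the object safe to share between the trainer, the scheduler and callbacks.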

LRScheduler

Pml\Training\LRScheduler adjusts the optimizer's learning rate each epoch based on the schedule set in TrainingArguments. Used internally by Trainer; you can also use it directly in custom loops.

__construct (Sequential $model, TrainingArguments $args)

Binds to the model's optimizer. Only works if the optimizer implements LearningRateAware.

step (int $epoch, int $totalEpochs): void

Computes the LR for the current epoch and sets it on the optimizer. Call at the start of each epoch.

Schedule formulas

| Schedule | Formula (lr = learningRate) |
|---|---|
| none | lr (unchanged) |
| cosine | lr × ½(1 + cos(π × epoch / totalEpochs)) |
| step | lr × lrDecay^⌊epoch / lrStepSize⌋ |
| linear | lr × (1 − epoch / totalEpochs) |

All schedules are preceded by a linear warm-up if warmupEpochs > 0.
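The formulas above can be sketched as one pure function. scheduledLr() is a hypothetical helper, not LRScheduler's source, and two details are assumptions because this page does not pin them down: epoch is 0-based, and warm-up ramps as learningRate × (epoch + 1) / warmupEpochs, after which the schedule runs over the remaining totalEpochs − warmupEpochs epochs.

```php
<?php
declare(strict_types=1);

// Hypothetical helper implementing the documented schedule formulas.
function scheduledLr(
    string $schedule,
    float $learningRate,
    int $epoch,            // 0-based epoch index (assumption)
    int $totalEpochs,
    float $lrDecay = 0.1,
    int $lrStepSize = 5,
    int $warmupEpochs = 0,
): float {
    if ($epoch < $warmupEpochs) {
        // Assumed linear warm-up: reaches learningRate on the last warm-up epoch.
        return $learningRate * ($epoch + 1) / $warmupEpochs;
    }

    // Assumption: the decay schedule starts counting after warm-up.
    $e = $epoch - $warmupEpochs;
    $n = max(1, $totalEpochs - $warmupEpochs);

    return match ($schedule) {
        'none'   => $learningRate,
        'cosine' => $learningRate * 0.5 * (1 + cos(M_PI * $e / $n)),
        // intdiv() gives the floor, so the LR drops once every lrStepSize epochs.
        'step'   => $learningRate * $lrDecay ** intdiv($e, $lrStepSize),
        'linear' => $learningRate * (1 - $e / $n),
        default  => throw new InvalidArgumentException("Unknown schedule: {$schedule}"),
    };
}
```

Cosine and linear both start at the full learning rate and reach half of it at the midpoint; they differ in that cosine decays slowly at both ends and fastest in the middle.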

GradScaler

Pml\Training\GradScaler implements dynamic loss scaling for mixed-precision training. It is currently a scaffold: it can already be used with fp32 training on systems that benefit from gradient scaling to prevent underflow, but full fp16/bf16 support requires a GPU backend.

__construct (bool $enabled = false, float $initScale = 65536.0, float $growthFactor = 2.0, float $backoffFactor = 0.5, int $growthInterval = 2000)

Loss scaling is disabled by default. Enable for mixed-precision training.

| Method | Description |
|---|---|
| scale(Tensor $lossGrad): Tensor | Multiplies the loss gradient by the current scale factor |
| unscaleAndStep(Optimizer $opt, Layer[] $layers): void | Unscales all gradients, checks for NaN/Inf, then calls the optimizer step only if all values are finite |
| update(): void | Adjusts the scale factor: grows it after growthInterval steps without overflow, backs it off on overflow |
| currentScale(): float | Current scale value |
| isEnabled(): bool | Whether loss scaling is enabled |

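use Pml\Training\GradScaler;

// Assumes $model (a Sequential) and $loader (a batch loader) already exist.
$scaler = new GradScaler(enabled: true);

foreach ($loader->batches() as $batch) {
    $preds = $model->forward($batch->samples());
    // Gradient of the loss w.r.t. the predictions, scaled up to avoid underflow.
    $grad  = $model->getLoss()->differentiate($preds, $batch->labels());
    $grad  = $scaler->scale($grad);
    $model->backward($grad);
    // Unscale; the optimizer step is skipped if any gradient is NaN/Inf.
    $scaler->unscaleAndStep($model->getOptimizer(), $model->getLayers());
    // Grow or back off the scale factor for the next iteration.
    $scaler->update();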
}
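The scale bookkeeping behind scale(), unscaleAndStep() and update() can be sketched on plain floats. LossScaleState and its methods are hypothetical stand-ins: PML operates on Tensor and Layer objects and its update() takes no arguments, whereas here the overflow flag is passed in explicitly. The constructor defaults mirror GradScaler's. Unlike unscaleAndStep(), unscale() here only returns the unscaled gradients, or null to signal that the optimizer step should be skipped.

```php
<?php
declare(strict_types=1);

// Hypothetical float-only stand-in for GradScaler's scale bookkeeping.
final class LossScaleState
{
    private int $stepsSinceOverflow = 0;

    public function __construct(
        private float $scale = 65536.0,
        private readonly float $growthFactor = 2.0,
        private readonly float $backoffFactor = 0.5,
        private readonly int $growthInterval = 2000,
    ) {}

    /** Multiply the loss gradient by the current scale before backprop. */
    public function scale(float $lossGrad): float
    {
        return $lossGrad * $this->scale;
    }

    /** Divide gradients by the scale; null means a NaN/Inf was found (skip the step). */
    public function unscale(array $scaledGrads): ?array
    {
        $grads = [];
        foreach ($scaledGrads as $g) {
            $g /= $this->scale;
            if (!is_finite($g)) {
                return null;
            }
            $grads[] = $g;
        }
        return $grads;
    }

    /** Back off on overflow; grow after growthInterval consecutive clean steps. */
    public function update(bool $overflow): void
    {
        if ($overflow) {
            $this->scale *= $this->backoffFactor;
            $this->stepsSinceOverflow = 0;
            return;
        }
        if (++$this->stepsSinceOverflow >= $this->growthInterval) {
            $this->scale *= $this->growthFactor;
            $this->stepsSinceOverflow = 0;
        }
    }

    public function currentScale(): float
    {
        return $this->scale;
    }
}
```

A step then reads: scale() before backward, unscale() afterwards, apply the optimizer only if the result is not null, and finally call update() with whether an overflow occurred. The scale keeps probing upwards and retreats as soon as values stop being finite.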

EarlyStopping

Pml\NeuralNetwork\EarlyStopping monitors a metric (typically validation loss) and signals when training should stop. Used internally by Sequential::train() and Trainer. Can also be used standalone in custom loops.

__construct (int $patience, string $mode = 'min', float $minDelta = 1e-4)

mode: 'min' when lower is better (e.g. loss), 'max' when higher is better (e.g. accuracy).

update (float $metric): int

Returns one of three constants:

| Constant | Value | Meaning |
|---|---|---|
| EarlyStopping::IMPROVED | 1 | Metric improved; save a checkpoint |
| EarlyStopping::CONTINUE | 0 | No improvement, but still within patience |
| EarlyStopping::STOP | -1 | Patience exhausted; stop training |

| Method | Description |
|---|---|
| getBestMetric(): float | Best metric value seen so far |
| getCounter(): int | Epochs since the last improvement |
| reset(): void | Resets the counter and the best metric |
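The signal logic can be sketched as a small state machine. EarlyStoppingSketch is a hypothetical re-implementation, not PML's class, and its exact comparisons are assumptions: an update counts as an improvement only if it beats the best value by more than minDelta, and STOP is returned once the counter reaches patience.

```php
<?php
declare(strict_types=1);

// Hypothetical re-implementation of the documented EarlyStopping signals.
final class EarlyStoppingSketch
{
    public const IMPROVED = 1;
    public const CONTINUE = 0;
    public const STOP = -1;

    private float $best;
    private int $counter = 0;

    public function __construct(
        private readonly int $patience,
        private readonly string $mode = 'min',
        private readonly float $minDelta = 1e-4,
    ) {
        if ($mode !== 'min' && $mode !== 'max') {
            throw new InvalidArgumentException("mode must be 'min' or 'max'");
        }
        $this->reset();
    }

    public function update(float $metric): int
    {
        // Assumption: an improvement must beat the best value by more than minDelta.
        $improved = $this->mode === 'min'
            ? $metric < $this->best - $this->minDelta
            : $metric > $this->best + $this->minDelta;

        if ($improved) {
            $this->best = $metric;
            $this->counter = 0;
            return self::IMPROVED;
        }

        $this->counter++;
        // Assumption: stop as soon as the counter reaches patience.
        return $this->counter >= $this->patience ? self::STOP : self::CONTINUE;
    }

    public function getBestMetric(): float
    {
        return $this->best;
    }

    public function getCounter(): int
    {
        return $this->counter;
    }

    public function reset(): void
    {
        // Start from the worst possible value so the first update always improves.
        $this->best = $this->mode === 'min' ? INF : -INF;
        $this->counter = 0;
    }
}
```

In a custom loop, IMPROVED is the cue to save a checkpoint and STOP the cue to break out of the epoch loop; CONTINUE needs no action.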

Callbacks

Implement Pml\Training\TrainerCallback to hook into training lifecycle events:

interface TrainerCallback
{
    public function onTrainBegin(TrainingArguments $args, int $steps): void;
    public function onEpochBegin(int $epoch, int $epochs): void;
    public function onBatchEnd(int $step, float $batchLoss): void;
    public function onEpochEnd(int $epoch, float $trainLoss, ?float $valLoss): void;
    public function onTrainEnd(TrainingResult $result): void;
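}

Example: a callback that appends each epoch's losses to a newline-delimited JSON file. Only onEpochEnd does real work; the other hooks are required by the interface and left empty.

use Pml\Training\{TrainerCallback, TrainingArguments, TrainingResult};

class WandbCallback implements TrainerCallback
{
    public function onEpochEnd(int $epoch, float $trainLoss, ?float $valLoss): void
    {
        // Swap this body for W&B, TensorBoard, Redis, etc. Assumes logs/ already exists.
        file_put_contents(
            'logs/loss.ndjson',
            json_encode(['epoch' => $epoch, 'train' => $trainLoss, 'val' => $valLoss]) . "\n",
            FILE_APPEND | LOCK_EX
        );
    }

    public function onTrainBegin(TrainingArguments $args, int $steps): void {}
    public function onEpochBegin(int $epoch, int $epochs): void {}
    public function onBatchEnd(int $step, float $batchLoss): void {}
    public function onTrainEnd(TrainingResult $result): void {}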
}
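Since TrainerCallback requires all five hooks, a common PHP pattern is an abstract base class with empty defaults, so concrete callbacks override only what they need. BaseCallback and BestValLossTracker are hypothetical, and the sketch declares minimal stand-ins for TrainingArguments, TrainingResult and the interface inside a CallbackSketch namespace so it runs without PML; whether PML ships such a base class is not stated on this page.

```php
<?php
declare(strict_types=1);

namespace CallbackSketch;

// Minimal stand-ins so the sketch runs without PML; the real types live in Pml\Training.
final class TrainingArguments {}
final class TrainingResult {}

interface TrainerCallback
{
    public function onTrainBegin(TrainingArguments $args, int $steps): void;
    public function onEpochBegin(int $epoch, int $epochs): void;
    public function onBatchEnd(int $step, float $batchLoss): void;
    public function onEpochEnd(int $epoch, float $trainLoss, ?float $valLoss): void;
    public function onTrainEnd(TrainingResult $result): void;
}

// Hypothetical no-op base: subclasses override only the hooks they care about.
abstract class BaseCallback implements TrainerCallback
{
    public function onTrainBegin(TrainingArguments $args, int $steps): void {}
    public function onEpochBegin(int $epoch, int $epochs): void {}
    public function onBatchEnd(int $step, float $batchLoss): void {}
    public function onEpochEnd(int $epoch, float $trainLoss, ?float $valLoss): void {}
    public function onTrainEnd(TrainingResult $result): void {}
}

// Tracks the lowest validation loss reported so far, ignoring epochs without validation.
final class BestValLossTracker extends BaseCallback
{
    public ?float $best = null;

    public function onEpochEnd(int $epoch, float $trainLoss, ?float $valLoss): void
    {
        if ($valLoss !== null && ($this->best === null || $valLoss < $this->best)) {
            $this->best = $valLoss;
        }
    }
}
```

Because the base class implements the full interface, a tracker like this can be passed anywhere a TrainerCallback is expected, and WandbCallback above could shrink to its single onEpochEnd method.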

TrainingResult

Returned by Trainer::train().

| Property | Type | Description |
|---|---|---|
| trainLossHistory | float[] | Per-epoch average training loss |
| valLossHistory | float[] | Per-epoch validation loss (if a validation set was provided) |
| bestEpoch | int | Epoch with the best validation loss |
| bestValLoss | float | Best validation loss value |
| stoppedEarly | bool | Whether early stopping was triggered |
| totalEpochs | int | Number of epochs actually run |

Full Example

use Pml\NeuralNetwork\Sequential;
use Pml\NeuralNetwork\Layers\{Dense, LayerNorm, Gelu, Dropout};
use Pml\NeuralNetwork\Optimizers\AdamW;
use Pml\Losses\CategoricalCrossEntropy;
use Pml\Training\{Trainer, TrainingArguments};
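use Pml\Dataset;

// $trainDs and $valDs are Dataset instances prepared elsewhere, with 512 input
// features and 10 classes to match the first and last Dense layers below.
// WandbCallback is the example class from the Callbacks section.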
$model = new Sequential(
    layers: [
        new Dense(512, 256),
        new LayerNorm(256),
        new Gelu(),
        new Dropout(0.1),
        new Dense(256, 10),
    ],
    lossFn:    new CategoricalCrossEntropy(),
    optimizer: new AdamW(learningRate: 3e-4, weightDecay: 0.01),
);

$args = new TrainingArguments(
    epochs:       50,
    batchSize:    256,
    learningRate: 3e-4,
    lrSchedule:   'cosine',
    warmupEpochs: 3,
    patience:     8,
    outputDir:    'ckpt/mnist',
);

$trainer = new Trainer($model, $args);
$trainer->addCallback(new WandbCallback());

$result = $trainer->train($trainDs, $valDs);

printf("Best epoch: %d, Best val loss: %.4f, Stopped early: %s\n",
    $result->bestEpoch,
    $result->bestValLoss,
    $result->stoppedEarly ? 'yes' : 'no'
);