Files
Reseaux-de-neurones-artific…/app/Models/NetworksTraining/ADALINEPerceptronTraining.php
Matthias Guillitte 88a932391b
Some checks failed
linter / quality (push) Successful in 7m39s
tests / ci (8.4) (push) Successful in 6m10s
tests / ci (8.5) (push) Has been cancelled
Renamed tests + added some + misc
2026-03-22 16:57:12 +01:00

106 lines
4.0 KiB
PHP

<?php
namespace App\Models\NetworksTraining;
use App\Events\PerceptronTrainingEnded;
use App\Models\ActivationsFunctions;
use App\Models\Perceptrons\GradientDescentPerceptron;
use App\Models\Perceptrons\Perceptron;
use App\Services\DatasetReader\IDataSetReader;
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
/**
 * Trains a single perceptron with a per-example weight correction.
 *
 * For every dataset row the last column is read as the expected output and the
 * remaining columns as inputs. After each example the weights are updated with
 * `w[i] += learningRate * (expected - output) * input[i]`, where `input[0]` is
 * a constant bias input of 1. At the end of each epoch the mean of the half
 * squared errors is recomputed with the final weights of that epoch and
 * compared against the configured minimum error.
 */
class ADALINEPerceptronTraining extends NetworkTraining
{
    /** Perceptron whose synaptic weights are adjusted during training. */
    private Perceptron $perceptron;

    public ActivationsFunctions $activationFunction = ActivationsFunctions::LINEAR;

    /** Mean of `error ** 2 / 2` over the examples of the most recent epoch. */
    private float $epochError;

    /**
     * @param IDataSetReader                  $datasetReader           Source of training rows (inputs followed by the expected output).
     * @param float                           $learningRate            Factor applied to every weight correction.
     * @param int                             $maxEpochs               Upper bound on the number of epochs run by start().
     * @param ISynapticWeightsProvider        $synapticWeightsProvider Generates the initial weights for the dataset's input size.
     * @param IPerceptronIterationEventBuffer $iterationEventBuffer    Buffer forwarded to the parent; flushed at the end of start().
     * @param string                          $sessionId               Forwarded to the parent and to the training-ended event.
     * @param string                          $trainingId              Forwarded to the parent and to the training-ended event.
     * @param float                           $minError                Training stops once the epoch error is at or below this value.
     */
    public function __construct(
        IDataSetReader $datasetReader,
        protected float $learningRate,
        int $maxEpochs,
        protected ISynapticWeightsProvider $synapticWeightsProvider,
        IPerceptronIterationEventBuffer $iterationEventBuffer,
        string $sessionId,
        string $trainingId,
        private float $minError,
    ) {
        parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);

        $this->perceptron = new GradientDescentPerceptron(
            $synapticWeightsProvider->generate($datasetReader->getInputSize()),
        );
    }

    /**
     * Runs the training loop until the stop condition holds or maxEpochs is reached.
     *
     * Each epoch performs one pass that corrects the weights after every
     * example, then a second pass over the same rows that measures the epoch
     * error with the resulting weights.
     */
    public function start(): void
    {
        $this->epoch = 0;

        do {
            $this->epochError = 0.0;
            $this->epoch++;
            $rowsForCurrentEpoch = [];

            // The loop ends on the first falsy value returned by the reader
            // (presumably its end-of-data marker — confirm against IDataSetReader).
            while ($nextRow = $this->datasetReader->getNextLine()) {
                $rowsForCurrentEpoch[] = $nextRow;
                $inputs = array_slice($nextRow, 0, -1);
                $correctOutput = (float) end($nextRow);

                $iterationError = $this->iterationFunction($inputs, $correctOutput);

                // Synaptic weights correction after each example.
                $synapticWeights = $this->perceptron->getSynapticWeights();
                $inputsWithBias = array_merge([1], $inputs); // Weight 0 pairs with the constant bias input.
                $correctionFactor = $this->learningRate * $iterationError; // Same for every weight of this example.
                $newWeights = array_map(
                    static fn ($weight, $weightIndex) => $weight + ($correctionFactor * $inputsWithBias[$weightIndex]),
                    $synapticWeights,
                    array_keys($synapticWeights),
                );
                $this->perceptron->setSynapticWeights($newWeights);

                // Record this iteration (error and updated weights) in the event buffer.
                $this->addIterationToBuffer($iterationError, $this->getSynapticWeights());
            }

            // Calculate the average error for the epoch with the last synaptic weights.
            foreach ($rowsForCurrentEpoch as $rowWithLabel) {
                $inputs = array_slice($rowWithLabel, 0, -1);
                $correctOutput = (float) end($rowWithLabel);
                $exampleError = $correctOutput - $this->perceptron->test($inputs);
                $this->epochError += ($exampleError ** 2) / 2; // Half squared error for the example.
            }
            $this->epochError /= $this->datasetReader->getEpochExamplesCount(); // Average error for the epoch.

            $this->datasetReader->reset(); // Rewind the dataset for the next epoch.
        } while ($this->epoch < $this->maxEpochs && !$this->stopCondition());

        $this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend.
        $this->checkPassedMaxIterations($this->epochError);
    }

    /**
     * Tells whether the last epoch error is at or below the minimum error.
     *
     * Side effect: dispatches a PerceptronTrainingEnded event when it is.
     */
    protected function stopCondition(): bool
    {
        $condition = $this->epochError <= $this->minError;

        if ($condition === true) {
            event(new PerceptronTrainingEnded('Le perceptron à atteint l\'erreur minimale', $this->sessionId, $this->trainingId));
        }

        return $condition;
    }

    /**
     * Computes the signed error (expected minus actual output) for one example.
     *
     * The expected output is typed as float: callers pass a float target, and
     * an int parameter would truncate (or reject) non-integer targets before
     * the error is computed.
     *
     * @param array $inputs        Input values of the example, without the bias input.
     * @param float $correctOutput Expected output for these inputs.
     */
    private function iterationFunction(array $inputs, float $correctOutput): float
    {
        $output = $this->perceptron->test($inputs);

        return $correctOutput - $output;
    }

    /**
     * Returns the perceptron's current weights wrapped in two enclosing arrays.
     *
     * NOTE(review): the nesting presumably mirrors a layer/neuron layout shared
     * with other trainings — confirm against NetworkTraining.
     */
    public function getSynapticWeights(): array
    {
        return [[$this->perceptron->getSynapticWeights()]];
    }
}