Rafactored Perceptrons and network training

This commit is contained in:
2026-03-22 14:58:34 +01:00
parent 47991fe736
commit 42e07de287
9 changed files with 18 additions and 29 deletions

View File

@@ -0,0 +1,101 @@
<?php
namespace App\Models\NetworksTraining;
use App\Events\PerceptronTrainingEnded;
use App\Models\ActivationsFunctions;
use App\Models\Perceptrons\GradientDescentPerceptron;
use App\Models\Perceptrons\Perceptron;
use App\Services\DatasetReader\IDataSetReader;
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
class GradientDescentPerceptronTraining extends NetworkTraining
{
private Perceptron $perceptron;
public ActivationsFunctions $activationFunction = ActivationsFunctions::LINEAR;
private float $epochError;
public function __construct(
IDataSetReader $datasetReader,
protected float $learningRate,
int $maxEpochs,
protected ISynapticWeightsProvider $synapticWeightsProvider,
IPerceptronIterationEventBuffer $iterationEventBuffer,
string $sessionId,
string $trainingId,
private float $minError,
) {
parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);
$this->perceptron = new GradientDescentPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
}
public function start(): void
{
$this->epoch = 0;
do {
$this->epochError = 0;
$epochCorrectorPerWeight = [];
$this->epoch++;
while ($nextRow = $this->datasetReader->getNextLine()) {
$inputs = array_slice($nextRow, 0, -1);
$correctOutput = (float) end($nextRow);
$iterationError = $this->iterationFunction($inputs, $correctOutput);
$this->epochError += ($iterationError ** 2) / 2;
// Store the iteration error for each weight
$inputs_with_bias = array_merge([1], $inputs); // Add bias input
foreach ($inputs_with_bias as $index => $input) {
$epochCorrectorPerWeight[$index][] = $iterationError * $input;
}
// Broadcast the training iteration event
$this->addIterationToBuffer($iterationError, [[$this->perceptron->getSynapticWeights()]]);
}
$this->epochError /= $this->datasetReader->getEpochExamplesCount(); // Average error for the epoch
// Synaptic weights correction after each epoch
$synaptic_weights = $this->perceptron->getSynapticWeights();
$new_weights = array_map(
fn($weight, $weightIndex) => $weight + $this->learningRate * array_sum($epochCorrectorPerWeight[$weightIndex]),
$synaptic_weights,
array_keys($synaptic_weights)
);
$this->perceptron->setSynapticWeights($new_weights);
$this->datasetReader->reset(); // Reset the dataset for the next iteration
} while ($this->epoch < $this->maxEpochs && !$this->stopCondition());
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
$this->checkPassedMaxIterations($this->epochError);
}
protected function stopCondition(): bool
{
$condition = $this->epochError <= $this->minError && $this->perceptron->getSynapticWeights() !== [[0.0, 0.0, 0.0]];
if ($condition === true) {
event(new PerceptronTrainingEnded('Le perceptron à atteint l\'erreur minimale', $this->sessionId, $this->trainingId));
}
return $condition;
}
private function iterationFunction(array $inputs, int $correctOutput)
{
$output = $this->perceptron->test($inputs);
$error = $correctOutput - $output;
return $error;
}
public function getSynapticWeights(): array
{
return [[$this->perceptron->getSynapticWeights()]];
}
}

View File

@@ -0,0 +1,53 @@
<?php
namespace App\Models\NetworksTraining;
use App\Events\PerceptronTrainingEnded;
use App\Services\DatasetReader\IDataSetReader;
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
use App\Models\ActivationsFunctions;
abstract class NetworkTraining
{
protected int $epoch = 0;
/**
* @abstract
* @var ActivationsFunctions
*/
public ActivationsFunctions $activationFunction;
public function __construct(
protected IDataSetReader $datasetReader,
protected int $maxEpochs,
protected IPerceptronIterationEventBuffer $iterationEventBuffer,
protected string $sessionId,
protected string $trainingId,
) {
}
abstract public function start() : void;
abstract protected function stopCondition(): bool;
protected function checkPassedMaxIterations(?float $finalError) {
if ($this->epoch >= $this->maxEpochs) {
$message = 'Le nombre maximal d\'epoch a été atteint';
if ($finalError) {
$message .= " avec une erreur finale de $finalError";
}
event(new PerceptronTrainingEnded($message, $this->sessionId, $this->trainingId));
}
}
protected function addIterationToBuffer(float $error, array $synapticWeights) {
$this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
}
public function getEpoch(): int
{
return $this->epoch;
}
abstract public function getSynapticWeights(): array;
}

View File

@@ -0,0 +1,92 @@
<?php
namespace App\Models\NetworksTraining;
use App\Events\PerceptronTrainingEnded;
use App\Models\ActivationsFunctions;
use App\Models\Perceptrons\Perceptron;
use App\Models\Perceptrons\SimpleBinaryPerceptron;
use App\Services\DatasetReader\IDataSetReader;
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
class SimpleBinaryPerceptronTraining extends NetworkTraining
{
private Perceptron $perceptron;
private int $iterationErrorCounter = 0;
public ActivationsFunctions $activationFunction = ActivationsFunctions::STEP;
public const MIN_ERROR = 0;
public function __construct(
IDataSetReader $datasetReader,
protected float $learningRate,
int $maxEpochs,
protected ISynapticWeightsProvider $synapticWeightsProvider,
IPerceptronIterationEventBuffer $iterationEventBuffer,
string $sessionId,
string $trainingId,
) {
parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);
$this->perceptron = new SimpleBinaryPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
}
public function start(): void
{
$this->epoch = 0;
$error = 0;
do {
$this->iterationErrorCounter = 0;
$this->epoch++;
while ($nextRow = $this->datasetReader->getNextLine()) {
$inputs = array_slice($nextRow, 0, -1);
$correctOutput = (float) end($nextRow);
$correctOutput = $correctOutput > 0 ? 1 : 0; // Modify labels for non binary datasets
$error = $this->iterationFunction($inputs, $correctOutput);
// Broadcast the training iteration event
$this->addIterationToBuffer($error, [[$this->perceptron->getSynapticWeights()]]);
}
$this->datasetReader->reset(); // Reset the dataset for the next iteration
} while ($this->epoch < $this->maxEpochs && !$this->stopCondition());
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
$this->checkPassedMaxIterations(null);
}
protected function stopCondition(): bool
{
$condition = $this->iterationErrorCounter == 0;
if ($condition === true) {
event(new PerceptronTrainingEnded('Le perceptron ne commet plus d\'erreurs sur aucune des données', $this->sessionId, $this->trainingId));
}
return $this->iterationErrorCounter == 0;
}
private function iterationFunction(array $inputs, int $correctOutput)
{
$output = $this->perceptron->test($inputs);
$error = $correctOutput - $output;
if (abs($error) > $this::MIN_ERROR) {
$this->iterationErrorCounter++;
}
if ($error !== 0) { // Update synaptic weights if needed
$synaptic_weights = $this->perceptron->getSynapticWeights();
$inputs_with_bias = array_merge([1], $inputs); // Add bias input
$new_weights = array_map(fn($weight, $input) => $weight + $this->learningRate * $error * $input, $synaptic_weights, $inputs_with_bias);
$this->perceptron->setSynapticWeights($new_weights);
}
return $error;
}
public function getSynapticWeights(): array
{
return [[$this->perceptron->getSynapticWeights()]];
}
}