Gradient descent training + Added all dataset + graphs improvements
Some checks failed
linter / quality (push) Failing after 18s
tests / ci (8.4) (push) Failing after 10s
tests / ci (8.5) (push) Failing after 11s

This commit is contained in:
2026-03-13 22:06:08 +01:00
parent f8d9fbc5b1
commit f0e7be4476
29 changed files with 872 additions and 68 deletions

View File

@@ -5,6 +5,7 @@ namespace App\Models;
enum ActivationsFunctions: string
{
case STEP = 'step';
case LINEAR = 'linear';
case SIGMOID = 'sigmoid';
case RELU = 'relu';
}

View File

@@ -2,7 +2,7 @@
namespace App\Models;
class SimplePerceptron extends Perceptron {
class GradientDescentPerceptron extends Perceptron {
public function __construct(
array $synaptic_weights,
@@ -10,9 +10,9 @@ class SimplePerceptron extends Perceptron {
parent::__construct($synaptic_weights);
}
public function activationFunction(float $weighted_sum): int
public function activationFunction(float $weighted_sum): float
{
return $weighted_sum >= 0 ? 1 : 0;
return $weighted_sum;
}
}

View File

@@ -0,0 +1,91 @@
<?php
namespace App\Models;
use App\Events\PerceptronTrainingEnded;
use App\Services\DataSetReader;
use App\Services\ISynapticWeightsProvider;
use App\Services\PerceptronIterationEventBuffer;
class GradientDescentPerceptronTraining extends NetworkTraining
{
private Perceptron $perceptron;
public ActivationsFunctions $activationFunction = ActivationsFunctions::LINEAR;
private float $epochError;
public function __construct(
DataSetReader $datasetReader,
protected float $learningRate,
int $maxIterations,
protected ISynapticWeightsProvider $synapticWeightsProvider,
PerceptronIterationEventBuffer $iterationEventBuffer,
string $sessionId,
string $trainingId,
private float $minError,
) {
parent::__construct($datasetReader, $maxIterations, $iterationEventBuffer, $sessionId, $trainingId);
$this->perceptron = new GradientDescentPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
}
public function start(): void
{
$this->iteration = 0;
do {
$this->epochError = 0;
$iterationErrorPerWeight = [];
$this->iteration++;
while ($nextRow = $this->datasetReader->getRandomLine()) {
$inputs = array_slice($nextRow, 0, -1);
$correctOutput = (float) end($nextRow);
$iterationError = $this->iterationFunction($inputs, $correctOutput);
$this->epochError += (1 / 2) * (abs($iterationError) ** 2); // TDDO REMOVEME abs()
// Store the iteration error for each weight
$inputs_with_bias = array_merge([1], $inputs); // Add bias input
foreach ($inputs_with_bias as $index => $input) {
$iterationErrorPerWeight[$index][] = $iterationError * $input;
}
// Broadcast the training iteration event
$this->addIterationToBuffer($iterationError, [[$this->perceptron->getSynapticWeights()]]);
}
// Synaptic weights correction after each epoch
$synaptic_weights = $this->perceptron->getSynapticWeights();
$new_weights = array_map(
fn($weight, $weightIndex) => $weight + $this->learningRate * array_sum($iterationErrorPerWeight[$weightIndex]),
$synaptic_weights,
array_keys($synaptic_weights)
);
$this->perceptron->setSynapticWeights($new_weights);
$this->datasetReader->reset(); // Reset the dataset for the next iteration
} while ($this->iteration < $this->maxIterations && !$this->stopCondition());
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
$this->checkPassedMaxIterations($this->epochError);
}
protected function stopCondition(): bool
{
$condition = $this->epochError <= $this->minError && $this->perceptron->getSynapticWeights() !== [[0.0, 0.0, 0.0]];
if ($condition === true) {
event(new PerceptronTrainingEnded('Le perceptron à atteint l\'erreur minimale', $this->sessionId, $this->trainingId));
}
return $condition;
}
private function iterationFunction(array $inputs, int $correctOutput)
{
$output = $this->perceptron->test($inputs);
$error = $correctOutput - $output;
return $error;
}
}

8
app/Models/Network.php Normal file
View File

@@ -0,0 +1,8 @@
<?php
namespace App\Models;
abstract class Network
{
}

View File

@@ -28,9 +28,14 @@ abstract class NetworkTraining
abstract public function start() : void;
abstract protected function stopCondition(): bool;
protected function checkPassedMaxIterations() {
protected function checkPassedMaxIterations(?float $finalError) {
if ($this->iteration >= $this->maxIterations) {
event(new PerceptronTrainingEnded('Le nombre maximal d\'itérations a été atteint', $this->sessionId, $this->trainingId));
$message = 'Le nombre maximal d\'itérations a été atteint';
if ($finalError) {
$message .= " avec une erreur finale de $finalError";
}
event(new PerceptronTrainingEnded($message, $this->sessionId, $this->trainingId));
}
}

View File

@@ -12,7 +12,7 @@ abstract class Perceptron extends Model
$this->synaptic_weights = $synaptic_weights;
}
public function test(array $inputs): int
public function test(array $inputs): float
{
$inputs = array_merge([1], $inputs); // Add bias input
@@ -24,7 +24,7 @@ abstract class Perceptron extends Model
return $this->activationFunction($weighted_sum);
}
abstract public function activationFunction(float $weighted_sum): int;
abstract public function activationFunction(float $weighted_sum): float;
public function getSynapticWeights(): array
{

View File

@@ -0,0 +1,18 @@
<?php
namespace App\Models;
class SimpleBinaryPerceptron extends Perceptron {
public function __construct(
array $synaptic_weights,
) {
parent::__construct($synaptic_weights);
}
public function activationFunction(float $weighted_sum): float
{
return $weighted_sum >= 0.0 ? 1.0 : 0.0;
}
}

View File

@@ -6,9 +6,8 @@ use App\Events\PerceptronTrainingEnded;
use App\Services\DataSetReader;
use App\Services\ISynapticWeightsProvider;
use App\Services\PerceptronIterationEventBuffer;
use Illuminate\Support\Facades\Log;
class SimplePerceptronTraining extends NetworkTraining
class SimpleBinaryPerceptronTraining extends NetworkTraining
{
private Perceptron $perceptron;
private int $iterationErrorCounter = 0;
@@ -27,7 +26,7 @@ class SimplePerceptronTraining extends NetworkTraining
string $trainingId,
) {
parent::__construct($datasetReader, $maxIterations, $iterationEventBuffer, $sessionId, $trainingId);
$this->perceptron = new SimplePerceptron($synapticWeightsProvider->generate(2));
$this->perceptron = new SimpleBinaryPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
}
public function start(): void
@@ -40,13 +39,11 @@ class SimplePerceptronTraining extends NetworkTraining
while ($nextRow = $this->datasetReader->getRandomLine()) {
$inputs = array_slice($nextRow, 0, -1);
$correctOutput = end($nextRow);
$correctOutput = (float) end($nextRow);
$correctOutput = $correctOutput > 0 ? 1 : 0; // Modify labels for non binary datasets
$error = $this->iterationFunction($inputs, $correctOutput);
$error = abs($error); // Use absolute error
// Broadcast the training iteration event
$this->addIterationToBuffer($error, [[$this->perceptron->getSynapticWeights()]]);
}
@@ -55,7 +52,7 @@ class SimplePerceptronTraining extends NetworkTraining
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
$this->checkPassedMaxIterations();
$this->checkPassedMaxIterations(null);
}
protected function stopCondition(): bool