93 lines
3.4 KiB
PHP
93 lines
3.4 KiB
PHP
<?php
|
|
|
|
namespace App\Models;
|
|
|
|
use App\Events\PerceptronTrainingEnded;
|
|
use App\Services\DataSetReader;
|
|
use App\Services\IPerceptronIterationEventBuffer;
|
|
use App\Services\ISynapticWeightsProvider;
|
|
use App\Services\PerceptronIterationEventBuffer;
|
|
|
|
/**
 * Trains a single GradientDescentPerceptron with batch gradient descent.
 *
 * During an epoch every dataset row is presented once. The error
 * (expected - actual) of each sample is accumulated into:
 *   - the epoch error, sum over samples of error^2 / 2;
 *   - one correction sum per weight, sum over samples of error * input[i]
 *     (index 0 is the bias, whose input is the constant 1).
 * At the end of the epoch every weight is updated once:
 *   w[i] += learningRate * correctionSum[i].
 *
 * Training stops when the maximum number of epochs is reached, or when the
 * epoch error is <= $minError while at least one weight is non-zero.
 */
class GradientDescentPerceptronTraining extends NetworkTraining
{
    private Perceptron $perceptron;

    public ActivationsFunctions $activationFunction = ActivationsFunctions::LINEAR;

    /** Sum of (error^2) / 2 over the most recently run epoch. */
    private float $epochError = 0.0;

    public function __construct(
        DataSetReader $datasetReader,
        protected float $learningRate,
        int $maxEpochs,
        protected ISynapticWeightsProvider $synapticWeightsProvider,
        IPerceptronIterationEventBuffer $iterationEventBuffer,
        string $sessionId,
        string $trainingId,
        private float $minError,
    ) {
        parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);

        // NOTE(review): generate() presumably returns one weight per input plus
        // the bias weight at index 0 (start() prepends the bias input) — confirm.
        $this->perceptron = new GradientDescentPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
    }

    /**
     * Runs batch gradient descent until the stop condition holds or
     * maxEpochs epochs have run. It then flushes the buffered iteration
     * events and hands the last epoch error to checkPassedMaxIterations().
     */
    public function start(): void
    {
        $this->epoch = 0;

        do {
            $this->epoch++;

            $correctionSums = $this->runEpoch();

            // Synaptic weights correction after each epoch (batch update)
            $this->applyWeightCorrection($correctionSums);

            $this->datasetReader->reset(); // Reset the dataset for the next epoch
        } while ($this->epoch < $this->maxEpochs && !$this->stopCondition());

        $this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend

        $this->checkPassedMaxIterations($this->epochError);
    }

    /**
     * Returns true, and fires PerceptronTrainingEnded, when the last epoch
     * error is at most $minError and the weights are not all zero.
     */
    protected function stopCondition(): bool
    {
        $condition = $this->epochError <= $this->minError && !$this->hasOnlyZeroWeights();

        if ($condition) {
            event(new PerceptronTrainingEnded('Le perceptron à atteint l\'erreur minimale', $this->sessionId, $this->trainingId));
        }

        return $condition;
    }

    /**
     * Presents every remaining dataset row once. It resets and accumulates
     * $epochError, buffers one iteration event per row, and sums
     * error * input for each weight index.
     *
     * @return array<int, float> correction sums keyed by weight index (0 = bias)
     */
    private function runEpoch(): array
    {
        $this->epochError = 0.0;
        $correctionSums = [];

        while ($nextRow = $this->datasetReader->getRandomLine()) {
            // The last column is the expected output; the others are the inputs.
            $inputs = array_slice($nextRow, 0, -1);
            $correctOutput = (float) end($nextRow);

            $iterationError = $this->iterationFunction($inputs, $correctOutput);
            $this->epochError += ($iterationError ** 2) / 2;

            // Running sum of error * input per weight. A constant 1 is prepended as the bias input.
            foreach (array_merge([1], $inputs) as $index => $input) {
                $correctionSums[$index] = ($correctionSums[$index] ?? 0.0) + $iterationError * $input;
            }

            // Broadcast the training iteration event
            $this->addIterationToBuffer($iterationError, [[$this->perceptron->getSynapticWeights()]]);
        }

        return $correctionSums;
    }

    /**
     * Applies one batch update: w[i] += learningRate * correctionSums[i].
     * A weight with no accumulated correction (e.g. an epoch that read no
     * rows) is left unchanged instead of raising an undefined-index error.
     *
     * @param array<int, float> $correctionSums
     */
    private function applyWeightCorrection(array $correctionSums): void
    {
        $newWeights = [];

        foreach ($this->perceptron->getSynapticWeights() as $index => $weight) {
            $newWeights[] = $weight + $this->learningRate * ($correctionSums[$index] ?? 0.0);
        }

        $this->perceptron->setSynapticWeights($newWeights);
    }

    /**
     * True when every synaptic weight equals zero. Reaching the minimal error
     * with such weights is not treated as convergence by stopCondition().
     * This works for any number of weights, not just three.
     */
    private function hasOnlyZeroWeights(): bool
    {
        foreach ($this->perceptron->getSynapticWeights() as $weight) {
            if ((float) $weight !== 0.0) {
                return false;
            }
        }

        return true;
    }

    /**
     * Returns the signed error (expected - actual) of the perceptron for one sample.
     * The expected output is a float so that continuous targets are not truncated.
     */
    private function iterationFunction(array $inputs, float $correctOutput): float
    {
        return $correctOutput - $this->perceptron->test($inputs);
    }
}
|