Use correct naming for iteration and epoch
This commit is contained in:
@@ -7,7 +7,6 @@ use Illuminate\Broadcasting\InteractsWithSockets;
|
|||||||
use Illuminate\Contracts\Broadcasting\ShouldBroadcast;
|
use Illuminate\Contracts\Broadcasting\ShouldBroadcast;
|
||||||
use Illuminate\Foundation\Events\Dispatchable;
|
use Illuminate\Foundation\Events\Dispatchable;
|
||||||
use Illuminate\Queue\SerializesModels;
|
use Illuminate\Queue\SerializesModels;
|
||||||
use Illuminate\Support\Facades\Log;
|
|
||||||
|
|
||||||
class PerceptronTrainingIteration implements ShouldBroadcast
|
class PerceptronTrainingIteration implements ShouldBroadcast
|
||||||
{
|
{
|
||||||
@@ -17,7 +16,7 @@ class PerceptronTrainingIteration implements ShouldBroadcast
|
|||||||
* Create a new event instance.
|
* Create a new event instance.
|
||||||
*/
|
*/
|
||||||
public function __construct(
|
public function __construct(
|
||||||
public array $iterations, // ["iteration" => int, "exampleIndex" => int, "error" => float, "synaptic_weights" => array]
|
public array $iterations, // ["epoch" => int, "exampleIndex" => int, "error" => float, "synaptic_weights" => array]
|
||||||
public string $sessionId,
|
public string $sessionId,
|
||||||
public string $trainingId,
|
public string $trainingId,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ namespace App\Models;
|
|||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DataSetReader;
|
||||||
|
use App\Services\IPerceptronIterationEventBuffer;
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\ISynapticWeightsProvider;
|
||||||
use App\Services\PerceptronIterationEventBuffer;
|
use App\Services\PerceptronIterationEventBuffer;
|
||||||
|
|
||||||
@@ -18,36 +19,36 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
|||||||
public function __construct(
|
public function __construct(
|
||||||
DataSetReader $datasetReader,
|
DataSetReader $datasetReader,
|
||||||
protected float $learningRate,
|
protected float $learningRate,
|
||||||
int $maxIterations,
|
int $maxEpochs,
|
||||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||||
PerceptronIterationEventBuffer $iterationEventBuffer,
|
IPerceptronIterationEventBuffer $iterationEventBuffer,
|
||||||
string $sessionId,
|
string $sessionId,
|
||||||
string $trainingId,
|
string $trainingId,
|
||||||
private float $minError,
|
private float $minError,
|
||||||
) {
|
) {
|
||||||
parent::__construct($datasetReader, $maxIterations, $iterationEventBuffer, $sessionId, $trainingId);
|
parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);
|
||||||
$this->perceptron = new GradientDescentPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
|
$this->perceptron = new GradientDescentPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public function start(): void
|
public function start(): void
|
||||||
{
|
{
|
||||||
$this->iteration = 0;
|
$this->epoch = 0;
|
||||||
do {
|
do {
|
||||||
$this->epochError = 0;
|
$this->epochError = 0;
|
||||||
$iterationErrorPerWeight = [];
|
$epochCorrectorPerWeight = [];
|
||||||
$this->iteration++;
|
$this->epoch++;
|
||||||
|
|
||||||
while ($nextRow = $this->datasetReader->getRandomLine()) {
|
while ($nextRow = $this->datasetReader->getRandomLine()) {
|
||||||
$inputs = array_slice($nextRow, 0, -1);
|
$inputs = array_slice($nextRow, 0, -1);
|
||||||
$correctOutput = (float) end($nextRow);
|
$correctOutput = (float) end($nextRow);
|
||||||
|
|
||||||
$iterationError = $this->iterationFunction($inputs, $correctOutput);
|
$iterationError = $this->iterationFunction($inputs, $correctOutput);
|
||||||
$this->epochError += (1 / 2) * (abs($iterationError) ** 2); // TDDO REMOVEME abs()
|
$this->epochError += ($iterationError ** 2) / 2;
|
||||||
|
|
||||||
// Store the iteration error for each weight
|
// Store the iteration error for each weight
|
||||||
$inputs_with_bias = array_merge([1], $inputs); // Add bias input
|
$inputs_with_bias = array_merge([1], $inputs); // Add bias input
|
||||||
foreach ($inputs_with_bias as $index => $input) {
|
foreach ($inputs_with_bias as $index => $input) {
|
||||||
$iterationErrorPerWeight[$index][] = $iterationError * $input;
|
$epochCorrectorPerWeight[$index][] = $iterationError * $input;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast the training iteration event
|
// Broadcast the training iteration event
|
||||||
@@ -57,14 +58,14 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
|||||||
// Synaptic weights correction after each epoch
|
// Synaptic weights correction after each epoch
|
||||||
$synaptic_weights = $this->perceptron->getSynapticWeights();
|
$synaptic_weights = $this->perceptron->getSynapticWeights();
|
||||||
$new_weights = array_map(
|
$new_weights = array_map(
|
||||||
fn($weight, $weightIndex) => $weight + $this->learningRate * array_sum($iterationErrorPerWeight[$weightIndex]),
|
fn($weight, $weightIndex) => $weight + $this->learningRate * array_sum($epochCorrectorPerWeight[$weightIndex]),
|
||||||
$synaptic_weights,
|
$synaptic_weights,
|
||||||
array_keys($synaptic_weights)
|
array_keys($synaptic_weights)
|
||||||
);
|
);
|
||||||
$this->perceptron->setSynapticWeights($new_weights);
|
$this->perceptron->setSynapticWeights($new_weights);
|
||||||
|
|
||||||
$this->datasetReader->reset(); // Reset the dataset for the next iteration
|
$this->datasetReader->reset(); // Reset the dataset for the next iteration
|
||||||
} while ($this->iteration < $this->maxIterations && !$this->stopCondition());
|
} while ($this->epoch < $this->maxEpochs && !$this->stopCondition());
|
||||||
|
|
||||||
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
|
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ namespace App\Models;
|
|||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DataSetReader;
|
||||||
use App\Services\PerceptronIterationEventBuffer;
|
use App\Services\IPerceptronIterationEventBuffer;
|
||||||
|
|
||||||
abstract class NetworkTraining
|
abstract class NetworkTraining
|
||||||
{
|
{
|
||||||
protected int $iteration = 0;
|
protected int $epoch = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @abstract
|
* @abstract
|
||||||
@@ -18,8 +18,8 @@ abstract class NetworkTraining
|
|||||||
|
|
||||||
public function __construct(
|
public function __construct(
|
||||||
protected DataSetReader $datasetReader,
|
protected DataSetReader $datasetReader,
|
||||||
protected int $maxIterations,
|
protected int $maxEpochs,
|
||||||
protected PerceptronIterationEventBuffer $iterationEventBuffer,
|
protected IPerceptronIterationEventBuffer $iterationEventBuffer,
|
||||||
protected string $sessionId,
|
protected string $sessionId,
|
||||||
protected string $trainingId,
|
protected string $trainingId,
|
||||||
) {
|
) {
|
||||||
@@ -29,8 +29,8 @@ abstract class NetworkTraining
|
|||||||
abstract protected function stopCondition(): bool;
|
abstract protected function stopCondition(): bool;
|
||||||
|
|
||||||
protected function checkPassedMaxIterations(?float $finalError) {
|
protected function checkPassedMaxIterations(?float $finalError) {
|
||||||
if ($this->iteration >= $this->maxIterations) {
|
if ($this->epoch >= $this->maxEpochs) {
|
||||||
$message = 'Le nombre maximal d\'itérations a été atteint';
|
$message = 'Le nombre maximal d\'epoch a été atteint';
|
||||||
if ($finalError) {
|
if ($finalError) {
|
||||||
$message .= " avec une erreur finale de $finalError";
|
$message .= " avec une erreur finale de $finalError";
|
||||||
}
|
}
|
||||||
@@ -40,6 +40,6 @@ abstract class NetworkTraining
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected function addIterationToBuffer(float $error, array $synapticWeights) {
|
protected function addIterationToBuffer(float $error, array $synapticWeights) {
|
||||||
$this->iterationEventBuffer->addIteration($this->iteration, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
|
$this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ namespace App\Models;
|
|||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DataSetReader;
|
||||||
|
use App\Services\IPerceptronIterationEventBuffer;
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\ISynapticWeightsProvider;
|
||||||
use App\Services\PerceptronIterationEventBuffer;
|
|
||||||
|
|
||||||
class SimpleBinaryPerceptronTraining extends NetworkTraining
|
class SimpleBinaryPerceptronTraining extends NetworkTraining
|
||||||
{
|
{
|
||||||
@@ -19,25 +19,25 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
|||||||
public function __construct(
|
public function __construct(
|
||||||
DataSetReader $datasetReader,
|
DataSetReader $datasetReader,
|
||||||
protected float $learningRate,
|
protected float $learningRate,
|
||||||
int $maxIterations,
|
int $maxEpochs,
|
||||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||||
PerceptronIterationEventBuffer $iterationEventBuffer,
|
IPerceptronIterationEventBuffer $iterationEventBuffer,
|
||||||
string $sessionId,
|
string $sessionId,
|
||||||
string $trainingId,
|
string $trainingId,
|
||||||
) {
|
) {
|
||||||
parent::__construct($datasetReader, $maxIterations, $iterationEventBuffer, $sessionId, $trainingId);
|
parent::__construct($datasetReader, $maxEpochs, $iterationEventBuffer, $sessionId, $trainingId);
|
||||||
$this->perceptron = new SimpleBinaryPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
|
$this->perceptron = new SimpleBinaryPerceptron($synapticWeightsProvider->generate($datasetReader->getInputSize()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public function start(): void
|
public function start(): void
|
||||||
{
|
{
|
||||||
$this->iteration = 0;
|
$this->epoch = 0;
|
||||||
$error = 0;
|
$error = 0;
|
||||||
do {
|
do {
|
||||||
$this->iterationErrorCounter = 0;
|
$this->iterationErrorCounter = 0;
|
||||||
$this->iteration++;
|
$this->epoch++;
|
||||||
|
|
||||||
while ($nextRow = $this->datasetReader->getRandomLine()) {
|
while ($nextRow = $this->datasetReader->getNextLine()) {
|
||||||
$inputs = array_slice($nextRow, 0, -1);
|
$inputs = array_slice($nextRow, 0, -1);
|
||||||
$correctOutput = (float) end($nextRow);
|
$correctOutput = (float) end($nextRow);
|
||||||
$correctOutput = $correctOutput > 0 ? 1 : 0; // Modify labels for non binary datasets
|
$correctOutput = $correctOutput > 0 ? 1 : 0; // Modify labels for non binary datasets
|
||||||
@@ -48,7 +48,7 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
|||||||
$this->addIterationToBuffer($error, [[$this->perceptron->getSynapticWeights()]]);
|
$this->addIterationToBuffer($error, [[$this->perceptron->getSynapticWeights()]]);
|
||||||
}
|
}
|
||||||
$this->datasetReader->reset(); // Reset the dataset for the next iteration
|
$this->datasetReader->reset(); // Reset the dataset for the next iteration
|
||||||
} while ($this->iteration < $this->maxIterations && !$this->stopCondition());
|
} while ($this->epoch < $this->maxEpochs && !$this->stopCondition());
|
||||||
|
|
||||||
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
|
$this->iterationEventBuffer->flush(); // Ensure all iterations are sent to the frontend
|
||||||
|
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ function getPerceptronErrorsPerIteration(): ChartData<
|
|||||||
dataset.data.push(iteration.error);
|
dataset.data.push(iteration.error);
|
||||||
|
|
||||||
// Epoch error
|
// Epoch error
|
||||||
epochAverageError[iteration.iteration - 1] =
|
epochAverageError[iteration.epoch - 1] =
|
||||||
(epochAverageError[iteration.iteration - 1] || 0) +
|
(epochAverageError[iteration.epoch - 1] || 0) +
|
||||||
iteration.error ** 2 / 2;
|
iteration.error ** 2 / 2;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ function getPerceptronErrorsPerIteration(): ChartData<
|
|||||||
plugins: {
|
plugins: {
|
||||||
title: {
|
title: {
|
||||||
display: true,
|
display: true,
|
||||||
text: 'Nombre d\'erreurs par itération',
|
text: 'Nombre d\'erreurs par epoch',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
scales: {
|
scales: {
|
||||||
@@ -96,7 +96,7 @@ function getPerceptronErrorsPerIteration(): ChartData<
|
|||||||
color: function (context) {
|
color: function (context) {
|
||||||
if (context.tick.value == 0) {
|
if (context.tick.value == 0) {
|
||||||
return gridColorBold;
|
return gridColorBold;
|
||||||
}
|
}
|
||||||
|
|
||||||
return gridColor;
|
return gridColor;
|
||||||
},
|
},
|
||||||
@@ -106,8 +106,8 @@ function getPerceptronErrorsPerIteration(): ChartData<
|
|||||||
}"
|
}"
|
||||||
:data="{
|
:data="{
|
||||||
labels: props.iterations.reduce((labels, iteration) => {
|
labels: props.iterations.reduce((labels, iteration) => {
|
||||||
if (!labels.includes(`Itération ${iteration.iteration}`)) {
|
if (!labels.includes(`Époch ${iteration.epoch}`)) {
|
||||||
labels.push(`Itération ${iteration.iteration}`);
|
labels.push(`Époch ${iteration.epoch}`);
|
||||||
}
|
}
|
||||||
return labels;
|
return labels;
|
||||||
}, [] as string[]),
|
}, [] as string[]),
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ watch(selectedDatasetCopy, (newValue) => {
|
|||||||
<!-- MAX ITERATIONS -->
|
<!-- MAX ITERATIONS -->
|
||||||
<FormField name="max_iterations">
|
<FormField name="max_iterations">
|
||||||
<FormItem>
|
<FormItem>
|
||||||
<FormLabel>Nombre maximum d'epoch</FormLabel>
|
<FormLabel>Nombre maximum d'itérations</FormLabel>
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<Input
|
<Input
|
||||||
type="number"
|
type="number"
|
||||||
|
|||||||
@@ -34,4 +34,4 @@ export const colors = [
|
|||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
export const gridColor = '#444';
|
export const gridColor = '#444';
|
||||||
export const gridColorBold = '#999';
|
export const gridColorBold = '#999';
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import type { Point } from "chart.js";
|
import type { Point } from "chart.js";
|
||||||
|
|
||||||
export type Iteration = {
|
export type Iteration = {
|
||||||
iteration: number;
|
epoch: number;
|
||||||
exampleIndex: number;
|
exampleIndex: number;
|
||||||
weights: number[][][];
|
weights: number[][][];
|
||||||
error: number;
|
error: number;
|
||||||
@@ -10,7 +10,8 @@ export type Iteration = {
|
|||||||
export type Dataset = {
|
export type Dataset = {
|
||||||
label: string;
|
label: string;
|
||||||
data: DatasetPoint[];
|
data: DatasetPoint[];
|
||||||
defaultLearningRate: number | undefined;
|
defaultLearningRate?: number;
|
||||||
|
defaultMinError?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type DatasetPoint = {
|
export type DatasetPoint = {
|
||||||
|
|||||||
Reference in New Issue
Block a user