diff --git a/app/Http/Controllers/PerceptronController.php b/app/Http/Controllers/PerceptronController.php index e8d54e3..9832c62 100644 --- a/app/Http/Controllers/PerceptronController.php +++ b/app/Http/Controllers/PerceptronController.php @@ -5,12 +5,14 @@ namespace App\Http\Controllers; use App\Events\PerceptronInitialization; use App\Models\GradientDescentPerceptronTraining; use App\Models\SimpleBinaryPerceptronTraining; -use App\Services\DataSetReader; -use App\Services\ISynapticWeightsProvider; -use App\Services\PerceptronIterationEventBuffer; -use App\Services\PerceptronLimitedEpochEventBuffer; -use App\Services\ZeroSynapticWeights; +use App\Services\DatasetReader\IDataSetReader; +use App\Services\DatasetReader\LinearOrderDataSetReader; +use App\Services\IterationEventBuffer\PerceptronIterationEventBuffer; +use App\Services\IterationEventBuffer\PerceptronLimitedEpochEventBuffer; +use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider; +use App\Services\SynapticWeightsProvider\ZeroSynapticWeights; use Illuminate\Http\Request; +use Tests\Services\IterationEventBuffer\DullIterationEventBuffer; class PerceptronController extends Controller { @@ -23,7 +25,7 @@ class PerceptronController extends Controller $learningRate = 0.01; $maxIterations = 200; - $minError = 0.6; + $minError = 0.1; switch ($perceptronType) { case 'simple': @@ -53,7 +55,7 @@ class PerceptronController extends Controller if (pathinfo($file, PATHINFO_EXTENSION) === 'csv') { $dataset = []; $dataset['label'] = str_replace('.csv', '', $file); - $dataSetReader = new DataSetReader($dataSetsDirectory . '/' . $file); + $dataSetReader = new LinearOrderDataSetReader($dataSetsDirectory . '/' . $file); $dataset['data'] = []; switch (count($dataSetReader->lines[0])) { case 3: @@ -84,6 +86,7 @@ class PerceptronController extends Controller switch ($perceptronType) { case 'gradientdescent': $dataset['defaultLearningRate'] = 0.3; + $dataset['defaultMinError'] = 0.125; break; } break; @@ -94,7 +97,6 @@ class PerceptronController extends Controller break; case 'gradientdescent': $dataset['defaultLearningRate'] = 0.001; - $dataset['defaultMinError'] = 2.0; break; } break; @@ -108,10 +110,10 @@ class PerceptronController extends Controller return $datasets; } - private function getDataSetReader(string $dataSet): DataSetReader + private function getDataSetReader(string $dataSet): IDataSetReader { $dataSetFileName = "data_sets/{$dataSet}.csv"; - return new DataSetReader($dataSetFileName); + return new LinearOrderDataSetReader($dataSetFileName); } public function run(Request $request, ISynapticWeightsProvider $synapticWeightsProvider) @@ -141,6 +143,16 @@ class PerceptronController extends Controller $networkTraining = match ($perceptronType) { 'simple' => new SimpleBinaryPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId), 'gradientdescent' => new GradientDescentPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId, $minError), + 'gradientdescentTest' => new GradientDescentPerceptronTraining( + datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')), + learningRate: 0.2, + maxEpochs: 100, + synapticWeightsProvider: new ZeroSynapticWeights(), + iterationEventBuffer: $iterationEventBuffer, + sessionId: 'test-session', + trainingId: 'test-training', + minError: 0.125001, + ), default => null, }; diff --git a/app/Models/GradientDescentPerceptronTraining.php b/app/Models/GradientDescentPerceptronTraining.php index 952df78..703aca0 100644 --- a/app/Models/GradientDescentPerceptronTraining.php +++ b/app/Models/GradientDescentPerceptronTraining.php @@ -3,10 +3,9 @@ namespace App\Models; use App\Events\PerceptronTrainingEnded; -use App\Services\DataSetReader; -use App\Services\IPerceptronIterationEventBuffer; -use App\Services\ISynapticWeightsProvider; -use App\Services\PerceptronIterationEventBuffer; +use App\Services\DatasetReader\IDataSetReader; +use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer; +use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider; class GradientDescentPerceptronTraining extends NetworkTraining { @@ -17,7 +16,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining private float $epochError; public function __construct( - DataSetReader $datasetReader, + IDataSetReader $datasetReader, protected float $learningRate, int $maxEpochs, protected ISynapticWeightsProvider $synapticWeightsProvider, @@ -38,7 +37,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining $epochCorrectorPerWeight = []; $this->epoch++; - while ($nextRow = $this->datasetReader->getRandomLine()) { + while ($nextRow = $this->datasetReader->getNextLine()) { $inputs = array_slice($nextRow, 0, -1); $correctOutput = (float) end($nextRow); @@ -89,4 +88,9 @@ class GradientDescentPerceptronTraining extends NetworkTraining return $error; } + + public function getSynapticWeights(): array + { + return [[$this->perceptron->getSynapticWeights()]]; + } } diff --git a/app/Models/NetworkTraining.php b/app/Models/NetworkTraining.php index 3f873a0..af13be2 100644 --- a/app/Models/NetworkTraining.php +++ b/app/Models/NetworkTraining.php @@ -3,8 +3,8 @@ namespace App\Models; use App\Events\PerceptronTrainingEnded; -use App\Services\DataSetReader; -use App\Services\IPerceptronIterationEventBuffer; +use App\Services\DatasetReader\IDataSetReader; +use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer; abstract class NetworkTraining { @@ -17,7 +17,7 @@ abstract class NetworkTraining public ActivationsFunctions $activationFunction; public function __construct( - protected DataSetReader $datasetReader, + protected IDataSetReader $datasetReader, protected int $maxEpochs, protected IPerceptronIterationEventBuffer $iterationEventBuffer, protected string $sessionId, @@ -42,4 +42,11 @@ abstract class NetworkTraining protected function addIterationToBuffer(float $error, array $synapticWeights) { $this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights); } + + public function getEpoch(): int + { + return $this->epoch; + } + + abstract public function getSynapticWeights(): array; } diff --git a/app/Models/SimpleBinaryPerceptronTraining.php b/app/Models/SimpleBinaryPerceptronTraining.php index 1570edb..d885f74 100644 --- a/app/Models/SimpleBinaryPerceptronTraining.php +++ b/app/Models/SimpleBinaryPerceptronTraining.php @@ -3,9 +3,9 @@ namespace App\Models; use App\Events\PerceptronTrainingEnded; -use App\Services\DataSetReader; -use App\Services\IPerceptronIterationEventBuffer; -use App\Services\ISynapticWeightsProvider; +use App\Services\DatasetReader\IDataSetReader; +use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer; +use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider; class SimpleBinaryPerceptronTraining extends NetworkTraining { @@ -17,7 +17,7 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining public const MIN_ERROR = 0; public function __construct( - DataSetReader $datasetReader, + IDataSetReader $datasetReader, protected float $learningRate, int $maxEpochs, protected ISynapticWeightsProvider $synapticWeightsProvider, @@ -81,4 +81,9 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining } return $error; } + + public function getSynapticWeights(): array + { + return [[$this->perceptron->getSynapticWeights()]]; + } } diff --git a/app/Providers/InitialSynapticWeightsProvider.php b/app/Providers/InitialSynapticWeightsProvider.php index 6dcb2e7..29b32cc 100644 --- a/app/Providers/InitialSynapticWeightsProvider.php +++ b/app/Providers/InitialSynapticWeightsProvider.php @@ -2,8 +2,8 @@ namespace App\Providers; -use App\Services\ISynapticWeightsProvider; -use App\Services\RandomSynapticWeights; +use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider; +use App\Services\SynapticWeightsProvider\RandomSynapticWeights; use Illuminate\Support\ServiceProvider; class InitialSynapticWeightsProvider extends ServiceProvider diff --git a/app/Services/DatasetReader/IDataSetReader.php b/app/Services/DatasetReader/IDataSetReader.php new file mode 100644 index 0000000..bd5316e --- /dev/null +++ b/app/Services/DatasetReader/IDataSetReader.php @@ -0,0 +1,11 @@ +readEntireFile($csvReader); + $this->reset(); + } + + private function readEntireFile(CsvReader $reader): void + { + while ($line = $reader->readNextLine()) { + $newLine = []; + foreach ($line as $value) { // Transform to float + $newLine[] = (float) $value; + } + + // if the dataset is for regression, we add a fake label of 0 + if (count($newLine) === 2) { + $newLine[] = 0.0; + } + + $this->lines[] = $newLine; + } + } + + public function getNextLine(): array | null { + if (!isset($this->currentLines[0])) { + return null; // No more lines to read + } + + $this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true); + + return array_shift($this->currentLines); + } + + public function getInputSize(): int + { + return count($this->lines[0]) - 1; // Don't count the label + } + + public function reset(): void + { + $this->currentLines = $this->lines; + } + + public function getLastReadLineIndex(): int + { + return $this->lastReadLineIndex; + } + + public function getEpochExamplesCount(): int + { + return count($this->lines); + } +} diff --git a/app/Services/DataSetReader.php b/app/Services/DatasetReader/RandomOrderDataSetReader.php similarity index 81% rename from app/Services/DataSetReader.php rename to app/Services/DatasetReader/RandomOrderDataSetReader.php index ddd93fd..eeaa748 100644 --- a/app/Services/DataSetReader.php +++ b/app/Services/DatasetReader/RandomOrderDataSetReader.php @@ -1,8 +1,10 @@ currentLines)) { return null; // No more lines to read @@ -51,16 +53,6 @@ class DataSetReader { return $randomLine; } - public function getNextLine(): array | null { - if (!isset($this->currentLines[0])) { - return null; // No more lines to read - } - - $this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true); - - return array_shift($this->currentLines); - } - public function getInputSize(): int { return count($this->lines[0]) - 1; // Don't count the label @@ -75,4 +67,9 @@ class DataSetReader { { return $this->lastReadLineIndex; } + + public function getEpochExamplesCount(): int + { + return count($this->lines); + } } diff --git a/app/Services/IPerceptronIterationEventBuffer.php b/app/Services/IterationEventBuffer/IPerceptronIterationEventBuffer.php similarity index 82% rename from app/Services/IPerceptronIterationEventBuffer.php rename to app/Services/IterationEventBuffer/IPerceptronIterationEventBuffer.php index ecc1de3..ea6c80d 100644 --- a/app/Services/IPerceptronIterationEventBuffer.php +++ b/app/Services/IterationEventBuffer/IPerceptronIterationEventBuffer.php @@ -1,6 +1,6 @@ verifyTrainingResults( + training: $training, + expectedWeights: [[[-1.497898, 0.998228, 0.998228]]], + expectedEpochs: 49 + ); + } +} diff --git a/tests/Unit/Training/SimplePerceptronTest.php b/tests/Unit/Training/SimplePerceptronTest.php new file mode 100644 index 0000000..4f838c0 --- /dev/null +++ b/tests/Unit/Training/SimplePerceptronTest.php @@ -0,0 +1,31 @@ +verifyTrainingResults( + training: $training, + expectedWeights: [[[-3.0, 2.0, 1.0]]], + expectedEpochs: 6 + ); + } +} diff --git a/tests/Unit/Training/TrainingTestCase.php b/tests/Unit/Training/TrainingTestCase.php new file mode 100644 index 0000000..7f3c6ed --- /dev/null +++ b/tests/Unit/Training/TrainingTestCase.php @@ -0,0 +1,25 @@ +start(); + + + // Assert that the final synaptic weights are as expected withing the margin of error + $finalWeights = $training->getSynapticWeights(); + $this->assertEqualsWithDelta($expectedWeights, $finalWeights, self::MARGIN_OF_ERROR, "Final synaptic weights do not match expected values."); + + // Assert that the number of epochs taken is as expected + $this->assertEquals($expectedEpochs, $training->getEpoch(), "Expected training to take $expectedEpochs epochs, but it took {$training->getEpoch()} epochs."); + } + +}