Refactored into folders
This commit is contained in:
@@ -5,12 +5,14 @@ namespace App\Http\Controllers;
|
||||
use App\Events\PerceptronInitialization;
|
||||
use App\Models\GradientDescentPerceptronTraining;
|
||||
use App\Models\SimpleBinaryPerceptronTraining;
|
||||
use App\Services\DataSetReader;
|
||||
use App\Services\ISynapticWeightsProvider;
|
||||
use App\Services\PerceptronIterationEventBuffer;
|
||||
use App\Services\PerceptronLimitedEpochEventBuffer;
|
||||
use App\Services\ZeroSynapticWeights;
|
||||
use App\Services\DatasetReader\IDataSetReader;
|
||||
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||
use App\Services\IterationEventBuffer\PerceptronIterationEventBuffer;
|
||||
use App\Services\IterationEventBuffer\PerceptronLimitedEpochEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||
use Illuminate\Http\Request;
|
||||
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||
|
||||
class PerceptronController extends Controller
|
||||
{
|
||||
@@ -23,7 +25,7 @@ class PerceptronController extends Controller
|
||||
|
||||
$learningRate = 0.01;
|
||||
$maxIterations = 200;
|
||||
$minError = 0.6;
|
||||
$minError = 0.1;
|
||||
|
||||
switch ($perceptronType) {
|
||||
case 'simple':
|
||||
@@ -53,7 +55,7 @@ class PerceptronController extends Controller
|
||||
if (pathinfo($file, PATHINFO_EXTENSION) === 'csv') {
|
||||
$dataset = [];
|
||||
$dataset['label'] = str_replace('.csv', '', $file);
|
||||
$dataSetReader = new DataSetReader($dataSetsDirectory . '/' . $file);
|
||||
$dataSetReader = new LinearOrderDataSetReader($dataSetsDirectory . '/' . $file);
|
||||
$dataset['data'] = [];
|
||||
switch (count($dataSetReader->lines[0])) {
|
||||
case 3:
|
||||
@@ -84,6 +86,7 @@ class PerceptronController extends Controller
|
||||
switch ($perceptronType) {
|
||||
case 'gradientdescent':
|
||||
$dataset['defaultLearningRate'] = 0.3;
|
||||
$dataset['defaultMinError'] = 0.125;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
@@ -94,7 +97,6 @@ class PerceptronController extends Controller
|
||||
break;
|
||||
case 'gradientdescent':
|
||||
$dataset['defaultLearningRate'] = 0.001;
|
||||
$dataset['defaultMinError'] = 2.0;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
@@ -108,10 +110,10 @@ class PerceptronController extends Controller
|
||||
return $datasets;
|
||||
}
|
||||
|
||||
private function getDataSetReader(string $dataSet): DataSetReader
|
||||
private function getDataSetReader(string $dataSet): IDataSetReader
|
||||
{
|
||||
$dataSetFileName = "data_sets/{$dataSet}.csv";
|
||||
return new DataSetReader($dataSetFileName);
|
||||
return new LinearOrderDataSetReader($dataSetFileName);
|
||||
}
|
||||
|
||||
public function run(Request $request, ISynapticWeightsProvider $synapticWeightsProvider)
|
||||
@@ -141,6 +143,16 @@ class PerceptronController extends Controller
|
||||
$networkTraining = match ($perceptronType) {
|
||||
'simple' => new SimpleBinaryPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId),
|
||||
'gradientdescent' => new GradientDescentPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId, $minError),
|
||||
'gradientdescentTest' => new GradientDescentPerceptronTraining(
|
||||
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')),
|
||||
learningRate: 0.2,
|
||||
maxEpochs: 100,
|
||||
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||
iterationEventBuffer: $iterationEventBuffer,
|
||||
sessionId: 'test-session',
|
||||
trainingId: 'test-training',
|
||||
minError: 0.125001,
|
||||
),
|
||||
default => null,
|
||||
};
|
||||
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
namespace App\Models;
|
||||
|
||||
use App\Events\PerceptronTrainingEnded;
|
||||
use App\Services\DataSetReader;
|
||||
use App\Services\IPerceptronIterationEventBuffer;
|
||||
use App\Services\ISynapticWeightsProvider;
|
||||
use App\Services\PerceptronIterationEventBuffer;
|
||||
use App\Services\DatasetReader\IDataSetReader;
|
||||
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||
|
||||
class GradientDescentPerceptronTraining extends NetworkTraining
|
||||
{
|
||||
@@ -17,7 +16,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
||||
private float $epochError;
|
||||
|
||||
public function __construct(
|
||||
DataSetReader $datasetReader,
|
||||
IDataSetReader $datasetReader,
|
||||
protected float $learningRate,
|
||||
int $maxEpochs,
|
||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||
@@ -38,7 +37,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
||||
$epochCorrectorPerWeight = [];
|
||||
$this->epoch++;
|
||||
|
||||
while ($nextRow = $this->datasetReader->getRandomLine()) {
|
||||
while ($nextRow = $this->datasetReader->getNextLine()) {
|
||||
$inputs = array_slice($nextRow, 0, -1);
|
||||
$correctOutput = (float) end($nextRow);
|
||||
|
||||
@@ -89,4 +88,9 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
||||
|
||||
return $error;
|
||||
}
|
||||
|
||||
public function getSynapticWeights(): array
|
||||
{
|
||||
return [[$this->perceptron->getSynapticWeights()]];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
namespace App\Models;
|
||||
|
||||
use App\Events\PerceptronTrainingEnded;
|
||||
use App\Services\DataSetReader;
|
||||
use App\Services\IPerceptronIterationEventBuffer;
|
||||
use App\Services\DatasetReader\IDataSetReader;
|
||||
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||
|
||||
abstract class NetworkTraining
|
||||
{
|
||||
@@ -17,7 +17,7 @@ abstract class NetworkTraining
|
||||
public ActivationsFunctions $activationFunction;
|
||||
|
||||
public function __construct(
|
||||
protected DataSetReader $datasetReader,
|
||||
protected IDataSetReader $datasetReader,
|
||||
protected int $maxEpochs,
|
||||
protected IPerceptronIterationEventBuffer $iterationEventBuffer,
|
||||
protected string $sessionId,
|
||||
@@ -42,4 +42,11 @@ abstract class NetworkTraining
|
||||
protected function addIterationToBuffer(float $error, array $synapticWeights) {
|
||||
$this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
|
||||
}
|
||||
|
||||
public function getEpoch(): int
|
||||
{
|
||||
return $this->epoch;
|
||||
}
|
||||
|
||||
abstract public function getSynapticWeights(): array;
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
namespace App\Models;
|
||||
|
||||
use App\Events\PerceptronTrainingEnded;
|
||||
use App\Services\DataSetReader;
|
||||
use App\Services\IPerceptronIterationEventBuffer;
|
||||
use App\Services\ISynapticWeightsProvider;
|
||||
use App\Services\DatasetReader\IDataSetReader;
|
||||
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||
|
||||
class SimpleBinaryPerceptronTraining extends NetworkTraining
|
||||
{
|
||||
@@ -17,7 +17,7 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
||||
public const MIN_ERROR = 0;
|
||||
|
||||
public function __construct(
|
||||
DataSetReader $datasetReader,
|
||||
IDataSetReader $datasetReader,
|
||||
protected float $learningRate,
|
||||
int $maxEpochs,
|
||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||
@@ -81,4 +81,9 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
||||
}
|
||||
return $error;
|
||||
}
|
||||
|
||||
public function getSynapticWeights(): array
|
||||
{
|
||||
return [[$this->perceptron->getSynapticWeights()]];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
namespace App\Providers;
|
||||
|
||||
use App\Services\ISynapticWeightsProvider;
|
||||
use App\Services\RandomSynapticWeights;
|
||||
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||
use App\Services\SynapticWeightsProvider\RandomSynapticWeights;
|
||||
use Illuminate\Support\ServiceProvider;
|
||||
|
||||
class InitialSynapticWeightsProvider extends ServiceProvider
|
||||
|
||||
11
app/Services/DatasetReader/IDataSetReader.php
Normal file
11
app/Services/DatasetReader/IDataSetReader.php
Normal file
@@ -0,0 +1,11 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services\DatasetReader;
|
||||
|
||||
interface IDataSetReader {
|
||||
public function getNextLine(): array | null;
|
||||
public function getInputSize(): int;
|
||||
public function reset(): void;
|
||||
public function getLastReadLineIndex(): int;
|
||||
public function getEpochExamplesCount(): int;
|
||||
}
|
||||
68
app/Services/DatasetReader/LinearOrderDataSetReader.php
Normal file
68
app/Services/DatasetReader/LinearOrderDataSetReader.php
Normal file
@@ -0,0 +1,68 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services\DatasetReader;
|
||||
|
||||
use App\Services\CsvReader;
|
||||
|
||||
class LinearOrderDataSetReader implements IDataSetReader {
|
||||
public array $lines = [];
|
||||
private array $currentLines = [];
|
||||
|
||||
private int $lastReadLineIndex = -1;
|
||||
|
||||
public function __construct(
|
||||
public string $filename,
|
||||
) {
|
||||
// For now, we only support CSV files, so we can delegate to CsvReader
|
||||
$csvReader = new CsvReader($filename);
|
||||
$this->readEntireFile($csvReader);
|
||||
$this->reset();
|
||||
}
|
||||
|
||||
private function readEntireFile(CsvReader $reader): void
|
||||
{
|
||||
while ($line = $reader->readNextLine()) {
|
||||
$newLine = [];
|
||||
foreach ($line as $value) { // Transform to float
|
||||
$newLine[] = (float) $value;
|
||||
}
|
||||
|
||||
// if the dataset is for regression, we add a fake label of 0
|
||||
if (count($newLine) === 2) {
|
||||
$newLine[] = 0.0;
|
||||
}
|
||||
|
||||
$this->lines[] = $newLine;
|
||||
}
|
||||
}
|
||||
|
||||
public function getNextLine(): array | null {
|
||||
if (!isset($this->currentLines[0])) {
|
||||
return null; // No more lines to read
|
||||
}
|
||||
|
||||
$this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true);
|
||||
|
||||
return array_shift($this->currentLines);
|
||||
}
|
||||
|
||||
public function getInputSize(): int
|
||||
{
|
||||
return count($this->lines[0]) - 1; // Don't count the label
|
||||
}
|
||||
|
||||
public function reset(): void
|
||||
{
|
||||
$this->currentLines = $this->lines;
|
||||
}
|
||||
|
||||
public function getLastReadLineIndex(): int
|
||||
{
|
||||
return $this->lastReadLineIndex;
|
||||
}
|
||||
|
||||
public function getEpochExamplesCount(): int
|
||||
{
|
||||
return count($this->lines);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\DatasetReader;
|
||||
|
||||
class DataSetReader {
|
||||
use App\Services\CsvReader;
|
||||
|
||||
class RandomOrderDataSetReaders implements IDataSetReader {
|
||||
public array $lines = [];
|
||||
private array $currentLines = [];
|
||||
|
||||
@@ -34,7 +36,7 @@ class DataSetReader {
|
||||
}
|
||||
}
|
||||
|
||||
public function getRandomLine(): array | null
|
||||
public function getNextLine(): array | null
|
||||
{
|
||||
if (empty($this->currentLines)) {
|
||||
return null; // No more lines to read
|
||||
@@ -51,16 +53,6 @@ class DataSetReader {
|
||||
return $randomLine;
|
||||
}
|
||||
|
||||
public function getNextLine(): array | null {
|
||||
if (!isset($this->currentLines[0])) {
|
||||
return null; // No more lines to read
|
||||
}
|
||||
|
||||
$this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true);
|
||||
|
||||
return array_shift($this->currentLines);
|
||||
}
|
||||
|
||||
public function getInputSize(): int
|
||||
{
|
||||
return count($this->lines[0]) - 1; // Don't count the label
|
||||
@@ -75,4 +67,9 @@ class DataSetReader {
|
||||
{
|
||||
return $this->lastReadLineIndex;
|
||||
}
|
||||
|
||||
public function getEpochExamplesCount(): int
|
||||
{
|
||||
return count($this->lines);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\IterationEventBuffer;
|
||||
|
||||
interface IPerceptronIterationEventBuffer {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\IterationEventBuffer;
|
||||
|
||||
class PerceptronIterationEventBuffer implements IPerceptronIterationEventBuffer {
|
||||
private $data;
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\IterationEventBuffer;
|
||||
|
||||
class PerceptronLimitedEpochEventBuffer implements IPerceptronIterationEventBuffer {
|
||||
private array $data;
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\SynapticWeightsProvider;
|
||||
|
||||
interface ISynapticWeightsProvider {
|
||||
public function generate(int $input_size): array;
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\SynapticWeightsProvider;
|
||||
|
||||
class RandomSynapticWeights implements ISynapticWeightsProvider {
|
||||
public function generate(int $input_size): array
|
||||
@@ -1,6 +1,6 @@
|
||||
<?php
|
||||
|
||||
namespace App\Services;
|
||||
namespace App\Services\SynapticWeightsProvider;
|
||||
|
||||
class ZeroSynapticWeights implements ISynapticWeightsProvider {
|
||||
public function generate(int $input_size): array
|
||||
@@ -0,0 +1,22 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Services\IterationEventBuffer;
|
||||
|
||||
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||
|
||||
class DullIterationEventBuffer implements IPerceptronIterationEventBuffer {
|
||||
|
||||
public function __construct(
|
||||
|
||||
) {
|
||||
|
||||
}
|
||||
|
||||
public function flush(): void {
|
||||
return;
|
||||
}
|
||||
|
||||
public function addIteration(int $epoch, int $exampleIndex, float $error, array $synaptic_weights): void {
|
||||
return;
|
||||
}
|
||||
}
|
||||
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
@@ -0,0 +1,32 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\GradientDescentPerceptronTraining;
|
||||
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||
|
||||
class GradientDescentPerceptronTest extends TrainingTestCase
|
||||
{
|
||||
|
||||
public function test_simple_perceptron_training_logic_and()
|
||||
{
|
||||
$training = new GradientDescentPerceptronTraining(
|
||||
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')),
|
||||
learningRate: 0.2,
|
||||
maxEpochs: 100,
|
||||
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||
sessionId: 'test-session',
|
||||
trainingId: 'test-training',
|
||||
minError: 0.125001,
|
||||
);
|
||||
|
||||
$this->verifyTrainingResults(
|
||||
training: $training,
|
||||
expectedWeights: [[[-1.497898, 0.998228, 0.998228]]],
|
||||
expectedEpochs: 49
|
||||
);
|
||||
}
|
||||
}
|
||||
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
@@ -0,0 +1,31 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\SimpleBinaryPerceptronTraining;
|
||||
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||
|
||||
class SimplePerceptronTest extends TrainingTestCase
|
||||
{
|
||||
|
||||
public function test_simple_perceptron_training_logic_and()
|
||||
{
|
||||
$training = new SimpleBinaryPerceptronTraining(
|
||||
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and.csv')),
|
||||
learningRate: 1.0,
|
||||
maxEpochs: 100,
|
||||
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||
sessionId: 'test-session',
|
||||
trainingId: 'test-training',
|
||||
);
|
||||
|
||||
$this->verifyTrainingResults(
|
||||
training: $training,
|
||||
expectedWeights: [[[-3.0, 2.0, 1.0]]],
|
||||
expectedEpochs: 6
|
||||
);
|
||||
}
|
||||
}
|
||||
25
tests/Unit/Training/TrainingTestCase.php
Normal file
25
tests/Unit/Training/TrainingTestCase.php
Normal file
@@ -0,0 +1,25 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\NetworkTraining;
|
||||
use Tests\TestCase;
|
||||
|
||||
class TrainingTestCase extends TestCase
|
||||
{
|
||||
public const MARGIN_OF_ERROR = 0.001;
|
||||
|
||||
public function verifyTrainingResults(NetworkTraining $training, array $expectedWeights, int $expectedEpochs): void
|
||||
{
|
||||
$training->start();
|
||||
|
||||
|
||||
// Assert that the final synaptic weights are as expected withing the margin of error
|
||||
$finalWeights = $training->getSynapticWeights();
|
||||
$this->assertEqualsWithDelta($expectedWeights, $finalWeights, self::MARGIN_OF_ERROR, "Final synaptic weights do not match expected values.");
|
||||
|
||||
// Assert that the number of epochs taken is as expected
|
||||
$this->assertEquals($expectedEpochs, $training->getEpoch(), "Expected training to take $expectedEpochs epochs, but it took {$training->getEpoch()} epochs.");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user