Refactored into folders
This commit is contained in:
@@ -5,12 +5,14 @@ namespace App\Http\Controllers;
|
|||||||
use App\Events\PerceptronInitialization;
|
use App\Events\PerceptronInitialization;
|
||||||
use App\Models\GradientDescentPerceptronTraining;
|
use App\Models\GradientDescentPerceptronTraining;
|
||||||
use App\Models\SimpleBinaryPerceptronTraining;
|
use App\Models\SimpleBinaryPerceptronTraining;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DatasetReader\IDataSetReader;
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||||
use App\Services\PerceptronIterationEventBuffer;
|
use App\Services\IterationEventBuffer\PerceptronIterationEventBuffer;
|
||||||
use App\Services\PerceptronLimitedEpochEventBuffer;
|
use App\Services\IterationEventBuffer\PerceptronLimitedEpochEventBuffer;
|
||||||
use App\Services\ZeroSynapticWeights;
|
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||||
|
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||||
use Illuminate\Http\Request;
|
use Illuminate\Http\Request;
|
||||||
|
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||||
|
|
||||||
class PerceptronController extends Controller
|
class PerceptronController extends Controller
|
||||||
{
|
{
|
||||||
@@ -23,7 +25,7 @@ class PerceptronController extends Controller
|
|||||||
|
|
||||||
$learningRate = 0.01;
|
$learningRate = 0.01;
|
||||||
$maxIterations = 200;
|
$maxIterations = 200;
|
||||||
$minError = 0.6;
|
$minError = 0.1;
|
||||||
|
|
||||||
switch ($perceptronType) {
|
switch ($perceptronType) {
|
||||||
case 'simple':
|
case 'simple':
|
||||||
@@ -53,7 +55,7 @@ class PerceptronController extends Controller
|
|||||||
if (pathinfo($file, PATHINFO_EXTENSION) === 'csv') {
|
if (pathinfo($file, PATHINFO_EXTENSION) === 'csv') {
|
||||||
$dataset = [];
|
$dataset = [];
|
||||||
$dataset['label'] = str_replace('.csv', '', $file);
|
$dataset['label'] = str_replace('.csv', '', $file);
|
||||||
$dataSetReader = new DataSetReader($dataSetsDirectory . '/' . $file);
|
$dataSetReader = new LinearOrderDataSetReader($dataSetsDirectory . '/' . $file);
|
||||||
$dataset['data'] = [];
|
$dataset['data'] = [];
|
||||||
switch (count($dataSetReader->lines[0])) {
|
switch (count($dataSetReader->lines[0])) {
|
||||||
case 3:
|
case 3:
|
||||||
@@ -84,6 +86,7 @@ class PerceptronController extends Controller
|
|||||||
switch ($perceptronType) {
|
switch ($perceptronType) {
|
||||||
case 'gradientdescent':
|
case 'gradientdescent':
|
||||||
$dataset['defaultLearningRate'] = 0.3;
|
$dataset['defaultLearningRate'] = 0.3;
|
||||||
|
$dataset['defaultMinError'] = 0.125;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -94,7 +97,6 @@ class PerceptronController extends Controller
|
|||||||
break;
|
break;
|
||||||
case 'gradientdescent':
|
case 'gradientdescent':
|
||||||
$dataset['defaultLearningRate'] = 0.001;
|
$dataset['defaultLearningRate'] = 0.001;
|
||||||
$dataset['defaultMinError'] = 2.0;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -108,10 +110,10 @@ class PerceptronController extends Controller
|
|||||||
return $datasets;
|
return $datasets;
|
||||||
}
|
}
|
||||||
|
|
||||||
private function getDataSetReader(string $dataSet): DataSetReader
|
private function getDataSetReader(string $dataSet): IDataSetReader
|
||||||
{
|
{
|
||||||
$dataSetFileName = "data_sets/{$dataSet}.csv";
|
$dataSetFileName = "data_sets/{$dataSet}.csv";
|
||||||
return new DataSetReader($dataSetFileName);
|
return new LinearOrderDataSetReader($dataSetFileName);
|
||||||
}
|
}
|
||||||
|
|
||||||
public function run(Request $request, ISynapticWeightsProvider $synapticWeightsProvider)
|
public function run(Request $request, ISynapticWeightsProvider $synapticWeightsProvider)
|
||||||
@@ -141,6 +143,16 @@ class PerceptronController extends Controller
|
|||||||
$networkTraining = match ($perceptronType) {
|
$networkTraining = match ($perceptronType) {
|
||||||
'simple' => new SimpleBinaryPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId),
|
'simple' => new SimpleBinaryPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId),
|
||||||
'gradientdescent' => new GradientDescentPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId, $minError),
|
'gradientdescent' => new GradientDescentPerceptronTraining($dataSetReader, $learningRate, $maxIterations, $synapticWeightsProvider, $iterationEventBuffer, $sessionId, $trainingId, $minError),
|
||||||
|
'gradientdescentTest' => new GradientDescentPerceptronTraining(
|
||||||
|
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')),
|
||||||
|
learningRate: 0.2,
|
||||||
|
maxEpochs: 100,
|
||||||
|
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||||
|
iterationEventBuffer: $iterationEventBuffer,
|
||||||
|
sessionId: 'test-session',
|
||||||
|
trainingId: 'test-training',
|
||||||
|
minError: 0.125001,
|
||||||
|
),
|
||||||
default => null,
|
default => null,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,9 @@
|
|||||||
namespace App\Models;
|
namespace App\Models;
|
||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DatasetReader\IDataSetReader;
|
||||||
use App\Services\IPerceptronIterationEventBuffer;
|
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||||
use App\Services\PerceptronIterationEventBuffer;
|
|
||||||
|
|
||||||
class GradientDescentPerceptronTraining extends NetworkTraining
|
class GradientDescentPerceptronTraining extends NetworkTraining
|
||||||
{
|
{
|
||||||
@@ -17,7 +16,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
|||||||
private float $epochError;
|
private float $epochError;
|
||||||
|
|
||||||
public function __construct(
|
public function __construct(
|
||||||
DataSetReader $datasetReader,
|
IDataSetReader $datasetReader,
|
||||||
protected float $learningRate,
|
protected float $learningRate,
|
||||||
int $maxEpochs,
|
int $maxEpochs,
|
||||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||||
@@ -38,7 +37,7 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
|||||||
$epochCorrectorPerWeight = [];
|
$epochCorrectorPerWeight = [];
|
||||||
$this->epoch++;
|
$this->epoch++;
|
||||||
|
|
||||||
while ($nextRow = $this->datasetReader->getRandomLine()) {
|
while ($nextRow = $this->datasetReader->getNextLine()) {
|
||||||
$inputs = array_slice($nextRow, 0, -1);
|
$inputs = array_slice($nextRow, 0, -1);
|
||||||
$correctOutput = (float) end($nextRow);
|
$correctOutput = (float) end($nextRow);
|
||||||
|
|
||||||
@@ -89,4 +88,9 @@ class GradientDescentPerceptronTraining extends NetworkTraining
|
|||||||
|
|
||||||
return $error;
|
return $error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function getSynapticWeights(): array
|
||||||
|
{
|
||||||
|
return [[$this->perceptron->getSynapticWeights()]];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
namespace App\Models;
|
namespace App\Models;
|
||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DatasetReader\IDataSetReader;
|
||||||
use App\Services\IPerceptronIterationEventBuffer;
|
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||||
|
|
||||||
abstract class NetworkTraining
|
abstract class NetworkTraining
|
||||||
{
|
{
|
||||||
@@ -17,7 +17,7 @@ abstract class NetworkTraining
|
|||||||
public ActivationsFunctions $activationFunction;
|
public ActivationsFunctions $activationFunction;
|
||||||
|
|
||||||
public function __construct(
|
public function __construct(
|
||||||
protected DataSetReader $datasetReader,
|
protected IDataSetReader $datasetReader,
|
||||||
protected int $maxEpochs,
|
protected int $maxEpochs,
|
||||||
protected IPerceptronIterationEventBuffer $iterationEventBuffer,
|
protected IPerceptronIterationEventBuffer $iterationEventBuffer,
|
||||||
protected string $sessionId,
|
protected string $sessionId,
|
||||||
@@ -42,4 +42,11 @@ abstract class NetworkTraining
|
|||||||
protected function addIterationToBuffer(float $error, array $synapticWeights) {
|
protected function addIterationToBuffer(float $error, array $synapticWeights) {
|
||||||
$this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
|
$this->iterationEventBuffer->addIteration($this->epoch, $this->datasetReader->getLastReadLineIndex(), $error, $synapticWeights);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function getEpoch(): int
|
||||||
|
{
|
||||||
|
return $this->epoch;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract public function getSynapticWeights(): array;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
namespace App\Models;
|
namespace App\Models;
|
||||||
|
|
||||||
use App\Events\PerceptronTrainingEnded;
|
use App\Events\PerceptronTrainingEnded;
|
||||||
use App\Services\DataSetReader;
|
use App\Services\DatasetReader\IDataSetReader;
|
||||||
use App\Services\IPerceptronIterationEventBuffer;
|
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||||
|
|
||||||
class SimpleBinaryPerceptronTraining extends NetworkTraining
|
class SimpleBinaryPerceptronTraining extends NetworkTraining
|
||||||
{
|
{
|
||||||
@@ -17,7 +17,7 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
|||||||
public const MIN_ERROR = 0;
|
public const MIN_ERROR = 0;
|
||||||
|
|
||||||
public function __construct(
|
public function __construct(
|
||||||
DataSetReader $datasetReader,
|
IDataSetReader $datasetReader,
|
||||||
protected float $learningRate,
|
protected float $learningRate,
|
||||||
int $maxEpochs,
|
int $maxEpochs,
|
||||||
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
protected ISynapticWeightsProvider $synapticWeightsProvider,
|
||||||
@@ -81,4 +81,9 @@ class SimpleBinaryPerceptronTraining extends NetworkTraining
|
|||||||
}
|
}
|
||||||
return $error;
|
return $error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function getSynapticWeights(): array
|
||||||
|
{
|
||||||
|
return [[$this->perceptron->getSynapticWeights()]];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
namespace App\Providers;
|
namespace App\Providers;
|
||||||
|
|
||||||
use App\Services\ISynapticWeightsProvider;
|
use App\Services\SynapticWeightsProvider\ISynapticWeightsProvider;
|
||||||
use App\Services\RandomSynapticWeights;
|
use App\Services\SynapticWeightsProvider\RandomSynapticWeights;
|
||||||
use Illuminate\Support\ServiceProvider;
|
use Illuminate\Support\ServiceProvider;
|
||||||
|
|
||||||
class InitialSynapticWeightsProvider extends ServiceProvider
|
class InitialSynapticWeightsProvider extends ServiceProvider
|
||||||
|
|||||||
11
app/Services/DatasetReader/IDataSetReader.php
Normal file
11
app/Services/DatasetReader/IDataSetReader.php
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace App\Services\DatasetReader;
|
||||||
|
|
||||||
|
interface IDataSetReader {
|
||||||
|
public function getNextLine(): array | null;
|
||||||
|
public function getInputSize(): int;
|
||||||
|
public function reset(): void;
|
||||||
|
public function getLastReadLineIndex(): int;
|
||||||
|
public function getEpochExamplesCount(): int;
|
||||||
|
}
|
||||||
68
app/Services/DatasetReader/LinearOrderDataSetReader.php
Normal file
68
app/Services/DatasetReader/LinearOrderDataSetReader.php
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace App\Services\DatasetReader;
|
||||||
|
|
||||||
|
use App\Services\CsvReader;
|
||||||
|
|
||||||
|
class LinearOrderDataSetReader implements IDataSetReader {
|
||||||
|
public array $lines = [];
|
||||||
|
private array $currentLines = [];
|
||||||
|
|
||||||
|
private int $lastReadLineIndex = -1;
|
||||||
|
|
||||||
|
public function __construct(
|
||||||
|
public string $filename,
|
||||||
|
) {
|
||||||
|
// For now, we only support CSV files, so we can delegate to CsvReader
|
||||||
|
$csvReader = new CsvReader($filename);
|
||||||
|
$this->readEntireFile($csvReader);
|
||||||
|
$this->reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
private function readEntireFile(CsvReader $reader): void
|
||||||
|
{
|
||||||
|
while ($line = $reader->readNextLine()) {
|
||||||
|
$newLine = [];
|
||||||
|
foreach ($line as $value) { // Transform to float
|
||||||
|
$newLine[] = (float) $value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the dataset is for regression, we add a fake label of 0
|
||||||
|
if (count($newLine) === 2) {
|
||||||
|
$newLine[] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->lines[] = $newLine;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public function getNextLine(): array | null {
|
||||||
|
if (!isset($this->currentLines[0])) {
|
||||||
|
return null; // No more lines to read
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true);
|
||||||
|
|
||||||
|
return array_shift($this->currentLines);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function getInputSize(): int
|
||||||
|
{
|
||||||
|
return count($this->lines[0]) - 1; // Don't count the label
|
||||||
|
}
|
||||||
|
|
||||||
|
public function reset(): void
|
||||||
|
{
|
||||||
|
$this->currentLines = $this->lines;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function getLastReadLineIndex(): int
|
||||||
|
{
|
||||||
|
return $this->lastReadLineIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function getEpochExamplesCount(): int
|
||||||
|
{
|
||||||
|
return count($this->lines);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\DatasetReader;
|
||||||
|
|
||||||
class DataSetReader {
|
use App\Services\CsvReader;
|
||||||
|
|
||||||
|
class RandomOrderDataSetReaders implements IDataSetReader {
|
||||||
public array $lines = [];
|
public array $lines = [];
|
||||||
private array $currentLines = [];
|
private array $currentLines = [];
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ class DataSetReader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public function getRandomLine(): array | null
|
public function getNextLine(): array | null
|
||||||
{
|
{
|
||||||
if (empty($this->currentLines)) {
|
if (empty($this->currentLines)) {
|
||||||
return null; // No more lines to read
|
return null; // No more lines to read
|
||||||
@@ -51,16 +53,6 @@ class DataSetReader {
|
|||||||
return $randomLine;
|
return $randomLine;
|
||||||
}
|
}
|
||||||
|
|
||||||
public function getNextLine(): array | null {
|
|
||||||
if (!isset($this->currentLines[0])) {
|
|
||||||
return null; // No more lines to read
|
|
||||||
}
|
|
||||||
|
|
||||||
$this->lastReadLineIndex = array_search($this->currentLines[0], $this->lines, true);
|
|
||||||
|
|
||||||
return array_shift($this->currentLines);
|
|
||||||
}
|
|
||||||
|
|
||||||
public function getInputSize(): int
|
public function getInputSize(): int
|
||||||
{
|
{
|
||||||
return count($this->lines[0]) - 1; // Don't count the label
|
return count($this->lines[0]) - 1; // Don't count the label
|
||||||
@@ -75,4 +67,9 @@ class DataSetReader {
|
|||||||
{
|
{
|
||||||
return $this->lastReadLineIndex;
|
return $this->lastReadLineIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function getEpochExamplesCount(): int
|
||||||
|
{
|
||||||
|
return count($this->lines);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\IterationEventBuffer;
|
||||||
|
|
||||||
interface IPerceptronIterationEventBuffer {
|
interface IPerceptronIterationEventBuffer {
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\IterationEventBuffer;
|
||||||
|
|
||||||
class PerceptronIterationEventBuffer implements IPerceptronIterationEventBuffer {
|
class PerceptronIterationEventBuffer implements IPerceptronIterationEventBuffer {
|
||||||
private $data;
|
private $data;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\IterationEventBuffer;
|
||||||
|
|
||||||
class PerceptronLimitedEpochEventBuffer implements IPerceptronIterationEventBuffer {
|
class PerceptronLimitedEpochEventBuffer implements IPerceptronIterationEventBuffer {
|
||||||
private array $data;
|
private array $data;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\SynapticWeightsProvider;
|
||||||
|
|
||||||
interface ISynapticWeightsProvider {
|
interface ISynapticWeightsProvider {
|
||||||
public function generate(int $input_size): array;
|
public function generate(int $input_size): array;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\SynapticWeightsProvider;
|
||||||
|
|
||||||
class RandomSynapticWeights implements ISynapticWeightsProvider {
|
class RandomSynapticWeights implements ISynapticWeightsProvider {
|
||||||
public function generate(int $input_size): array
|
public function generate(int $input_size): array
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
<?php
|
<?php
|
||||||
|
|
||||||
namespace App\Services;
|
namespace App\Services\SynapticWeightsProvider;
|
||||||
|
|
||||||
class ZeroSynapticWeights implements ISynapticWeightsProvider {
|
class ZeroSynapticWeights implements ISynapticWeightsProvider {
|
||||||
public function generate(int $input_size): array
|
public function generate(int $input_size): array
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace Tests\Services\IterationEventBuffer;
|
||||||
|
|
||||||
|
use App\Services\IterationEventBuffer\IPerceptronIterationEventBuffer;
|
||||||
|
|
||||||
|
class DullIterationEventBuffer implements IPerceptronIterationEventBuffer {
|
||||||
|
|
||||||
|
public function __construct(
|
||||||
|
|
||||||
|
) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public function flush(): void {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function addIteration(int $epoch, int $exampleIndex, float $error, array $synaptic_weights): void {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace Tests\Unit\Training;
|
||||||
|
|
||||||
|
use App\Models\GradientDescentPerceptronTraining;
|
||||||
|
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||||
|
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||||
|
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||||
|
|
||||||
|
class GradientDescentPerceptronTest extends TrainingTestCase
|
||||||
|
{
|
||||||
|
|
||||||
|
public function test_simple_perceptron_training_logic_and()
|
||||||
|
{
|
||||||
|
$training = new GradientDescentPerceptronTraining(
|
||||||
|
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')),
|
||||||
|
learningRate: 0.2,
|
||||||
|
maxEpochs: 100,
|
||||||
|
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||||
|
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||||
|
sessionId: 'test-session',
|
||||||
|
trainingId: 'test-training',
|
||||||
|
minError: 0.125001,
|
||||||
|
);
|
||||||
|
|
||||||
|
$this->verifyTrainingResults(
|
||||||
|
training: $training,
|
||||||
|
expectedWeights: [[[-1.497898, 0.998228, 0.998228]]],
|
||||||
|
expectedEpochs: 49
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace Tests\Unit\Training;
|
||||||
|
|
||||||
|
use App\Models\SimpleBinaryPerceptronTraining;
|
||||||
|
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||||
|
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||||
|
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||||
|
|
||||||
|
class SimplePerceptronTest extends TrainingTestCase
|
||||||
|
{
|
||||||
|
|
||||||
|
public function test_simple_perceptron_training_logic_and()
|
||||||
|
{
|
||||||
|
$training = new SimpleBinaryPerceptronTraining(
|
||||||
|
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and.csv')),
|
||||||
|
learningRate: 1.0,
|
||||||
|
maxEpochs: 100,
|
||||||
|
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||||
|
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||||
|
sessionId: 'test-session',
|
||||||
|
trainingId: 'test-training',
|
||||||
|
);
|
||||||
|
|
||||||
|
$this->verifyTrainingResults(
|
||||||
|
training: $training,
|
||||||
|
expectedWeights: [[[-3.0, 2.0, 1.0]]],
|
||||||
|
expectedEpochs: 6
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
25
tests/Unit/Training/TrainingTestCase.php
Normal file
25
tests/Unit/Training/TrainingTestCase.php
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
namespace Tests\Unit\Training;
|
||||||
|
|
||||||
|
use App\Models\NetworkTraining;
|
||||||
|
use Tests\TestCase;
|
||||||
|
|
||||||
|
class TrainingTestCase extends TestCase
|
||||||
|
{
|
||||||
|
public const MARGIN_OF_ERROR = 0.001;
|
||||||
|
|
||||||
|
public function verifyTrainingResults(NetworkTraining $training, array $expectedWeights, int $expectedEpochs): void
|
||||||
|
{
|
||||||
|
$training->start();
|
||||||
|
|
||||||
|
|
||||||
|
// Assert that the final synaptic weights are as expected withing the margin of error
|
||||||
|
$finalWeights = $training->getSynapticWeights();
|
||||||
|
$this->assertEqualsWithDelta($expectedWeights, $finalWeights, self::MARGIN_OF_ERROR, "Final synaptic weights do not match expected values.");
|
||||||
|
|
||||||
|
// Assert that the number of epochs taken is as expected
|
||||||
|
$this->assertEquals($expectedEpochs, $training->getEpoch(), "Expected training to take $expectedEpochs epochs, but it took {$training->getEpoch()} epochs.");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user