Refactored into folders
This commit is contained in:
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
32
tests/Unit/Training/GradientDescentPerceptronTest.php
Normal file
@@ -0,0 +1,32 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\GradientDescentPerceptronTraining;
|
||||
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||
|
||||
class GradientDescentPerceptronTest extends TrainingTestCase
|
||||
{
|
||||
|
||||
public function test_simple_perceptron_training_logic_and()
|
||||
{
|
||||
$training = new GradientDescentPerceptronTraining(
|
||||
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and_gradient.csv')),
|
||||
learningRate: 0.2,
|
||||
maxEpochs: 100,
|
||||
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||
sessionId: 'test-session',
|
||||
trainingId: 'test-training',
|
||||
minError: 0.125001,
|
||||
);
|
||||
|
||||
$this->verifyTrainingResults(
|
||||
training: $training,
|
||||
expectedWeights: [[[-1.497898, 0.998228, 0.998228]]],
|
||||
expectedEpochs: 49
|
||||
);
|
||||
}
|
||||
}
|
||||
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
31
tests/Unit/Training/SimplePerceptronTest.php
Normal file
@@ -0,0 +1,31 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\SimpleBinaryPerceptronTraining;
|
||||
use App\Services\DatasetReader\LinearOrderDataSetReader;
|
||||
use Tests\Services\IterationEventBuffer\DullIterationEventBuffer;
|
||||
use App\Services\SynapticWeightsProvider\ZeroSynapticWeights;
|
||||
|
||||
class SimplePerceptronTest extends TrainingTestCase
|
||||
{
|
||||
|
||||
public function test_simple_perceptron_training_logic_and()
|
||||
{
|
||||
$training = new SimpleBinaryPerceptronTraining(
|
||||
datasetReader: new LinearOrderDataSetReader(public_path('data_sets/logic_and.csv')),
|
||||
learningRate: 1.0,
|
||||
maxEpochs: 100,
|
||||
synapticWeightsProvider: new ZeroSynapticWeights(),
|
||||
iterationEventBuffer: new DullIterationEventBuffer(),
|
||||
sessionId: 'test-session',
|
||||
trainingId: 'test-training',
|
||||
);
|
||||
|
||||
$this->verifyTrainingResults(
|
||||
training: $training,
|
||||
expectedWeights: [[[-3.0, 2.0, 1.0]]],
|
||||
expectedEpochs: 6
|
||||
);
|
||||
}
|
||||
}
|
||||
25
tests/Unit/Training/TrainingTestCase.php
Normal file
25
tests/Unit/Training/TrainingTestCase.php
Normal file
@@ -0,0 +1,25 @@
|
||||
<?php
|
||||
|
||||
namespace Tests\Unit\Training;
|
||||
|
||||
use App\Models\NetworkTraining;
|
||||
use Tests\TestCase;
|
||||
|
||||
class TrainingTestCase extends TestCase
|
||||
{
|
||||
public const MARGIN_OF_ERROR = 0.001;
|
||||
|
||||
public function verifyTrainingResults(NetworkTraining $training, array $expectedWeights, int $expectedEpochs): void
|
||||
{
|
||||
$training->start();
|
||||
|
||||
|
||||
// Assert that the final synaptic weights are as expected withing the margin of error
|
||||
$finalWeights = $training->getSynapticWeights();
|
||||
$this->assertEqualsWithDelta($expectedWeights, $finalWeights, self::MARGIN_OF_ERROR, "Final synaptic weights do not match expected values.");
|
||||
|
||||
// Assert that the number of epochs taken is as expected
|
||||
$this->assertEquals($expectedEpochs, $training->getEpoch(), "Expected training to take $expectedEpochs epochs, but it took {$training->getEpoch()} epochs.");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user