limboai/tests/test_probability_selector.h

189 lines
6.4 KiB
C++

/**
* test_probability_selector.h
* =============================================================================
* Copyright (c) 2023-present Serhii Snitsaruk and the LimboAI contributors.
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
#ifndef TEST_PROBABILITY_SELECTOR_H
#define TEST_PROBABILITY_SELECTOR_H
#include "limbo_test.h"
#include "modules/limboai/bt/tasks/bt_task.h"
#include "modules/limboai/bt/tasks/composites/bt_probability_selector.h"
namespace TestProbabilitySelector {

// Tests for BTProbabilitySelector: a selector with no children, zero-weight children,
// single-winner weights, RUNNING continuation, weighted distribution, and abort_on_failure.
//
// Every execute() call uses a 0.01666 delta. The statements outside the SUBCASE blocks are the
// shared setup/teardown for each subcase (doctest re-enters the test case body once per subcase).
TEST_CASE("[Modules][LimboAI] BTProbabilitySelector") {
	Ref<BTProbabilitySelector> sel = memnew(BTProbabilitySelector);

	SUBCASE("When empty") {
		// Error printing is muted around the call; presumably executing a childless
		// selector reports an error that would otherwise clutter the test output.
		ERR_PRINT_OFF;
		CHECK(sel->execute(0.01666) == BTTask::FAILURE);
		ERR_PRINT_ON;
	}

	Ref<BTTestAction> task1 = memnew(BTTestAction);
	Ref<BTTestAction> task2 = memnew(BTTestAction);
	Ref<BTTestAction> task3 = memnew(BTTestAction);

	sel->add_child(task1);
	sel->add_child(task2);
	sel->add_child(task3);

	// NOTE(review): this presumably reseeds the RNG the selector draws from, which makes the
	// distribution subcases below non-deterministic from run to run -- confirm.
	Math::randomize();

	SUBCASE("With zero weight") {
		sel->set_weight(0, 0.0);
		sel->set_weight(1, 0.0);
		sel->set_weight(2, 0.0);

		CHECK(sel->execute(0.01666) == BTTask::FAILURE);

		// Repeat many times: no child may ever be entered, regardless of the random draw.
		for (int i = 0; i < 100; i++) {
			sel->execute(0.01666);
		}

		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0);
	}

	SUBCASE("When a child task returns SUCCESS") {
		// Only the first child has a non-zero weight, so it must be picked every time.
		sel->set_weight(0, 1.0);
		sel->set_weight(1, 0.0);
		sel->set_weight(2, 0.0);
		task1->ret_status = BTTask::SUCCESS;

		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0);

		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 2, 2, 2);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0);
	}

	SUBCASE("With a RUNNING status and a low-weight remaining child") {
		// Only the second child can be selected at first.
		sel->set_weight(0, 0.0);
		sel->set_weight(1, 1.0);
		sel->set_weight(2, 0.0);
		task1->ret_status = BTTask::FAILURE;
		task2->ret_status = BTTask::RUNNING;
		task3->ret_status = BTTask::FAILURE;

		CHECK(sel->execute(0.01666) == BTTask::RUNNING);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); // * ignored
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 1, 0); // * running
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored

		CHECK(sel->execute(0.01666) == BTTask::RUNNING);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 2, 0); // * continued
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0);

		// Let the running child fail and give the first child a tiny, non-zero weight.
		task2->ret_status = BTTask::FAILURE;
		task1->ret_status = BTTask::SUCCESS;
		sel->set_weight(0, 0.000000000001); // * extremely low weight, however, when it is the only child to evaluate, it should have 100% probability of being chosen.

		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1); // * started & succeeded (2)
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FAILURE, 1, 3, 1); // * continued & failed (1)
		CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored
	}

	SUBCASE("When all return SUCCESS status") {
		task1->ret_status = BTTask::SUCCESS;
		task2->ret_status = BTTask::SUCCESS;
		task3->ret_status = BTTask::SUCCESS;

		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);
		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);
		CHECK(sel->execute(0.01666) == BTTask::SUCCESS);

		// Which child gets picked is random, so only the totals across all children are
		// checked: three executions must yield exactly three entries, ticks and exits.
		int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks;
		CHECK(num_ticks == 3);
		int num_entries = task1->num_entries + task2->num_entries + task3->num_entries;
		CHECK(num_entries == 3);
		int num_exits = task1->num_exits + task2->num_exits + task3->num_exits;
		CHECK(num_exits == 3);

		// A child is either untouched or has succeeded.
		CHECK(task1->is_status_either(BTTask::SUCCESS, BTTask::FRESH));
		CHECK(task2->is_status_either(BTTask::SUCCESS, BTTask::FRESH));
		CHECK(task3->is_status_either(BTTask::SUCCESS, BTTask::FRESH));
	}

	SUBCASE("With balanced weights") {
		task1->ret_status = BTTask::SUCCESS;
		task2->ret_status = BTTask::SUCCESS;
		task3->ret_status = BTTask::SUCCESS;

		int sample_size = 1000;
		sel->set_weight(0, 1.0);
		sel->set_weight(1, 1.0);
		sel->set_weight(2, 1.0);

		for (int i = 0; i < sample_size; i++) {
			sel->execute(0.01666);
		}

		// Expected share is sample_size / 3 ~ 333 ticks per child. Assuming each pick is an
		// independent draw with probability 1/3, the standard deviation of a child's tick
		// count is sqrt(1000 * 1/3 * 2/3) ~ 15. The accepted band is 333 +/- 75, i.e. about
		// five standard deviations, so a randomly seeded run does not fail spuriously
		// (a band of only ~2 standard deviations fails in several percent of runs).
		CHECK(task1->num_ticks > 258);
		CHECK(task1->num_ticks < 408);
		CHECK(task2->num_ticks > 258);
		CHECK(task2->num_ticks < 408);
		CHECK(task3->num_ticks > 258);
		CHECK(task3->num_ticks < 408);
	}

	SUBCASE("With imbalanced weights") {
		task1->ret_status = BTTask::SUCCESS;
		task2->ret_status = BTTask::SUCCESS;
		task3->ret_status = BTTask::SUCCESS;

		int sample_size = 10000;
		sel->set_weight(0, 1.0); // * ~1250
		sel->set_weight(1, 2.0); // * ~2500
		sel->set_weight(2, 5.0); // * ~6250

		for (int i = 0; i < sample_size; i++) {
			sel->execute(0.01666);
		}

		// Under the same independence assumption the standard deviations are ~33, ~43 and ~48
		// ticks respectively. The first band is 1250 +/- 165 (about five standard deviations);
		// the other two bands are already wider than that.
		CHECK(task1->num_ticks > 1085);
		CHECK(task1->num_ticks < 1415);
		CHECK(task2->num_ticks > 2250);
		CHECK(task2->num_ticks < 2750);
		CHECK(task3->num_ticks > 5750);
		CHECK(task3->num_ticks < 6750);
	}

	SUBCASE("Test abort_on_failure") {
		task1->ret_status = BTTask::FAILURE;
		task2->ret_status = BTTask::FAILURE;
		task3->ret_status = BTTask::FAILURE;

		int expected_child_executions = 0;

		// Exactly one of the nested subcases runs per pass; each configures the selector and
		// the number of child executions the shared assertions below should observe.
		SUBCASE("When abort_on_failure == false") {
			sel->set_abort_on_failure(false);
			expected_child_executions = 3;
		}
		SUBCASE("When abort_on_failure == true") {
			sel->set_abort_on_failure(true);
			expected_child_executions = 1;
		}

		sel->execute(0.01666);

		int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks;
		CHECK(num_ticks == expected_child_executions);
		int num_entries = task1->num_entries + task2->num_entries + task3->num_entries;
		CHECK(num_entries == expected_child_executions);
		int num_exits = task1->num_exits + task2->num_exits + task3->num_exits;
		CHECK(num_exits == expected_child_executions);
	}
}

} //namespace TestProbabilitySelector
#endif // TEST_PROBABILITY_SELECTOR_H