/** * test_probability_selector.h * ============================================================================= * Copyright (c) 2023-present Serhii Snitsaruk and the LimboAI contributors. * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ #ifndef TEST_PROBABILITY_SELECTOR_H #define TEST_PROBABILITY_SELECTOR_H #include "limbo_test.h" #include "modules/limboai/bt/tasks/bt_task.h" #include "modules/limboai/bt/tasks/composites/bt_probability_selector.h" namespace TestProbabilitySelector { TEST_CASE("[Modules][LimboAI] BTProbabilitySelector") { Ref sel = memnew(BTProbabilitySelector); SUBCASE("When empty") { ERR_PRINT_OFF; CHECK(sel->execute(0.01666) == BTTask::FAILURE); ERR_PRINT_ON; } Ref task1 = memnew(BTTestAction); Ref task2 = memnew(BTTestAction); Ref task3 = memnew(BTTestAction); sel->add_child(task1); sel->add_child(task2); sel->add_child(task3); Math::randomize(); SUBCASE("With zero weight") { sel->set_weight(0, 0.0); sel->set_weight(1, 0.0); sel->set_weight(2, 0.0); CHECK(sel->execute(0.01666) == BTTask::FAILURE); for (int i = 0; i < 100; i++) { sel->execute(0.01666); } CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); } SUBCASE("When a child task returns SUCCESS") { sel->set_weight(0, 1.0); sel->set_weight(1, 0.0); sel->set_weight(2, 0.0); task1->ret_status = BTTask::SUCCESS; CHECK(sel->execute(0.01666) == BTTask::SUCCESS); CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1); CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); CHECK(sel->execute(0.01666) == BTTask::SUCCESS); CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 2, 2, 2); CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); } SUBCASE("With a RUNNING status and a low-weight remaining child") { sel->set_weight(0, 0.0); sel->set_weight(1, 1.0); sel->set_weight(2, 0.0); task1->ret_status = BTTask::FAILURE; task2->ret_status = BTTask::RUNNING; task3->ret_status = BTTask::FAILURE; CHECK(sel->execute(0.01666) == BTTask::RUNNING); CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); // * ignored CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 1, 0); // * running CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored CHECK(sel->execute(0.01666) == BTTask::RUNNING); CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 2, 0); // * continued CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); task2->ret_status = BTTask::FAILURE; task1->ret_status = BTTask::SUCCESS; sel->set_weight(0, 0.000000000001); // * extremely low weight, however, when it is the only child to evaluate, it should have 100% probability of being chosen. CHECK(sel->execute(0.01666) == BTTask::SUCCESS); CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1); // * started & succeeded (2) CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FAILURE, 1, 3, 1); // * continued & failed (1) CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored } SUBCASE("When all return SUCCESS status") { task1->ret_status = BTTask::SUCCESS; task2->ret_status = BTTask::SUCCESS; task3->ret_status = BTTask::SUCCESS; CHECK(sel->execute(0.01666) == BTTask::SUCCESS); CHECK(sel->execute(0.01666) == BTTask::SUCCESS); CHECK(sel->execute(0.01666) == BTTask::SUCCESS); int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks; CHECK(num_ticks == 3); int num_entries = task1->num_entries + task2->num_entries + task3->num_entries; CHECK(num_entries == 3); int num_exits = task1->num_exits + task2->num_exits + task3->num_exits; CHECK(num_exits == 3); CHECK(task1->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); CHECK(task2->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); CHECK(task3->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); } SUBCASE("With balanced weights") { task1->ret_status = BTTask::SUCCESS; task2->ret_status = BTTask::SUCCESS; task3->ret_status = BTTask::SUCCESS; int sample_size = 1000; sel->set_weight(0, 1.0); sel->set_weight(1, 1.0); sel->set_weight(2, 1.0); for (int i = 0; i < sample_size; i++) { sel->execute(0.01666); } CHECK(task1->num_ticks > 300); CHECK(task1->num_ticks < 366); CHECK(task2->num_ticks > 300); CHECK(task2->num_ticks < 366); CHECK(task3->num_ticks > 300); CHECK(task3->num_ticks < 366); } SUBCASE("With imbalanced weights") { task1->ret_status = BTTask::SUCCESS; task2->ret_status = BTTask::SUCCESS; task3->ret_status = BTTask::SUCCESS; int sample_size = 10000; sel->set_weight(0, 1.0); // * ~1250 sel->set_weight(1, 2.0); // * ~2500 sel->set_weight(2, 5.0); // * ~6250 for (int i = 0; i < sample_size; i++) { sel->execute(0.01666); } CHECK(task1->num_ticks > 1150); CHECK(task1->num_ticks < 1350); CHECK(task2->num_ticks > 2250); CHECK(task2->num_ticks < 2750); CHECK(task3->num_ticks > 5750); CHECK(task3->num_ticks < 6750); } SUBCASE("Test abort_on_failure") { task1->ret_status = BTTask::FAILURE; task2->ret_status = BTTask::FAILURE; task3->ret_status = BTTask::FAILURE; int expected_child_executions = 0; SUBCASE("When abort_on_failure == false") { sel->set_abort_on_failure(false); expected_child_executions = 3; } SUBCASE("When abort_on_failure == true") { sel->set_abort_on_failure(true); expected_child_executions = 1; } sel->execute(0.01666); int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks; CHECK(num_ticks == expected_child_executions); int num_entries = task1->num_entries + task2->num_entries + task3->num_entries; CHECK(num_entries == expected_child_executions); int num_exits = task1->num_exits + task2->num_exits + task3->num_exits; CHECK(num_exits == expected_child_executions); } } } //namespace TestProbabilitySelector #endif // TEST_PROBABILITY_SELECTOR_H