From b33c1cae319359f12f0c7fc690722faaad0d1b5c Mon Sep 17 00:00:00 2001 From: Serhii Snitsaruk Date: Sun, 24 Sep 2023 14:12:50 +0200 Subject: [PATCH] Add tests for BTProbabilitySelector --- tests/test_probability_selector.h | 164 ++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 tests/test_probability_selector.h diff --git a/tests/test_probability_selector.h b/tests/test_probability_selector.h new file mode 100644 index 0000000..fefe3eb --- /dev/null +++ b/tests/test_probability_selector.h @@ -0,0 +1,164 @@ +/** + * test_probability_selector.h + * ============================================================================= + * Copyright 2021-2023 Serhii Snitsaruk + * + * Use of this source code is governed by an MIT-style + * license that can be found in the LICENSE file or at + * https://opensource.org/licenses/MIT. + * ============================================================================= + */ + +#ifndef TEST_PROBABILITY_SELECTOR_H +#define TEST_PROBABILITY_SELECTOR_H + +#include "limbo_test.h" + +#include "modules/limboai/bt/tasks/bt_task.h" +#include "modules/limboai/bt/tasks/composites/bt_probability_selector.h" + +namespace TestProbabilitySelector { + +TEST_CASE("[Modules][LimboAI] BTProbabilitySelector") { + Ref sel = memnew(BTProbabilitySelector); + + SUBCASE("When empty") { + ERR_PRINT_OFF; + CHECK(sel->execute(0.01666) == BTTask::FAILURE); + ERR_PRINT_ON; + } + + Ref task1 = memnew(BTTestAction); + Ref task2 = memnew(BTTestAction); + Ref task3 = memnew(BTTestAction); + sel->add_child(task1); + sel->add_child(task2); + sel->add_child(task3); + + Math::randomize(); + + SUBCASE("With zero weight") { + sel->set_weight(0, 0.0); + sel->set_weight(1, 0.0); + sel->set_weight(2, 0.0); + + CHECK(sel->execute(0.01666) == BTTask::FAILURE); + + for (int i = 0; i < 100; i++) { + sel->execute(0.01666); + } + + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); + } + SUBCASE("When a child task returns SUCCESS") { + sel->set_weight(0, 1.0); + sel->set_weight(1, 0.0); + sel->set_weight(2, 0.0); + task1->ret_status = BTTask::SUCCESS; + + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 2, 2, 2); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FRESH, 0, 0, 0); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); + } + SUBCASE("With a RUNNING status and a low-weight remaining child") { + sel->set_weight(0, 0.0); + sel->set_weight(1, 1.0); + sel->set_weight(2, 0.0); + task1->ret_status = BTTask::FAILURE; + task2->ret_status = BTTask::RUNNING; + task3->ret_status = BTTask::FAILURE; + + CHECK(sel->execute(0.01666) == BTTask::RUNNING); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); // * ignored + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 1, 0); // * running + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored + + CHECK(sel->execute(0.01666) == BTTask::RUNNING); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::FRESH, 0, 0, 0); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::RUNNING, 1, 2, 0); // * continued + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); + + task2->ret_status = BTTask::FAILURE; + task1->ret_status = BTTask::SUCCESS; + sel->set_weight(0, 0.000000000001); // * extremely low weight, however, when it is the only child to evaluate, it should have 100% probability of being chosen. + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + CHECK_STATUS_ENTRIES_TICKS_EXITS(task1, BTTask::SUCCESS, 1, 1, 1); // * started & succeeded (2) + CHECK_STATUS_ENTRIES_TICKS_EXITS(task2, BTTask::FAILURE, 1, 3, 1); // * continued & failed (1) + CHECK_STATUS_ENTRIES_TICKS_EXITS(task3, BTTask::FRESH, 0, 0, 0); // * ignored + } + SUBCASE("When all return SUCCESS status") { + task1->ret_status = BTTask::SUCCESS; + task2->ret_status = BTTask::SUCCESS; + task3->ret_status = BTTask::SUCCESS; + + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + CHECK(sel->execute(0.01666) == BTTask::SUCCESS); + + int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks; + CHECK(num_ticks == 3); + + int num_entries = task1->num_entries + task2->num_entries + task3->num_entries; + CHECK(num_entries == 3); + + int num_exits = task1->num_exits + task2->num_exits + task3->num_exits; + CHECK(num_exits == 3); + + CHECK(task1->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); + CHECK(task2->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); + CHECK(task3->is_status_either(BTTask::SUCCESS, BTTask::FRESH)); + } + SUBCASE("With balanced weights") { + task1->ret_status = BTTask::SUCCESS; + task2->ret_status = BTTask::SUCCESS; + task3->ret_status = BTTask::SUCCESS; + + int sample_size = 1000; + sel->set_weight(0, 1.0); + sel->set_weight(1, 1.0); + sel->set_weight(2, 1.0); + + for (int i = 0; i < sample_size; i++) { + sel->execute(0.01666); + } + + CHECK(task1->num_ticks > 300); + CHECK(task1->num_ticks < 366); + CHECK(task2->num_ticks > 300); + CHECK(task2->num_ticks < 366); + CHECK(task3->num_ticks > 300); + CHECK(task3->num_ticks < 366); + } + SUBCASE("With imbalanced weights") { + task1->ret_status = BTTask::SUCCESS; + task2->ret_status = BTTask::SUCCESS; + task3->ret_status = BTTask::SUCCESS; + + int sample_size = 10000; + sel->set_weight(0, 1.0); // * ~1250 + sel->set_weight(1, 2.0); // * ~2500 + sel->set_weight(2, 5.0); // * ~6250 + + for (int i = 0; i < sample_size; i++) { + sel->execute(0.01666); + } + + CHECK(task1->num_ticks > 1150); + CHECK(task1->num_ticks < 1350); + CHECK(task2->num_ticks > 2250); + CHECK(task2->num_ticks < 2750); + CHECK(task3->num_ticks > 5750); + CHECK(task3->num_ticks < 6750); + } +} + +} //namespace TestProbabilitySelector + +#endif // TEST_PROBABILITY_SELECTOR_H