diff --git a/bt/tasks/composites/bt_probability_selector.cpp b/bt/tasks/composites/bt_probability_selector.cpp index 74e12fd..36a9416 100644 --- a/bt/tasks/composites/bt_probability_selector.cpp +++ b/bt/tasks/composites/bt_probability_selector.cpp @@ -40,6 +40,15 @@ void BTProbabilitySelector::set_probability(int p_index, double p_probability) { _set_weight(p_index, new_total - others_total); } +void BTProbabilitySelector::set_abort_on_failure(bool p_abort_on_failure) { + abort_on_failure = p_abort_on_failure; + emit_changed(); +} + +bool BTProbabilitySelector::get_abort_on_failure() const { + return abort_on_failure; +} + void BTProbabilitySelector::_enter() { _select_task(); } @@ -53,6 +62,9 @@ BT::Status BTProbabilitySelector::_tick(double p_delta) { while (selected_task.is_valid()) { Status status = selected_task->execute(p_delta); if (status == FAILURE) { + if (abort_on_failure) { + return FAILURE; + } failed_tasks.insert(selected_task); _select_task(); } else { // RUNNING or SUCCESS @@ -98,4 +110,8 @@ void BTProbabilitySelector::_bind_methods() { ClassDB::bind_method(D_METHOD("set_weight", "p_index", "p_weight"), &BTProbabilitySelector::set_weight); ClassDB::bind_method(D_METHOD("get_probability", "p_index"), &BTProbabilitySelector::get_probability); ClassDB::bind_method(D_METHOD("set_probability", "p_index", "p_probability"), &BTProbabilitySelector::set_probability); + ClassDB::bind_method(D_METHOD("get_abort_on_failure"), &BTProbabilitySelector::get_abort_on_failure); + ClassDB::bind_method(D_METHOD("set_abort_on_failure", "p_value"), &BTProbabilitySelector::set_abort_on_failure); + + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "abort_on_failure"), "set_abort_on_failure", "get_abort_on_failure"); } diff --git a/bt/tasks/composites/bt_probability_selector.h b/bt/tasks/composites/bt_probability_selector.h index c45157d..744a926 100644 --- a/bt/tasks/composites/bt_probability_selector.h +++ b/bt/tasks/composites/bt_probability_selector.h @@ -23,6 +23,7 @@ class BTProbabilitySelector : public BTComposite { private: HashSet> failed_tasks; Ref selected_task; + bool abort_on_failure = false; void _select_task(); @@ -50,6 +51,9 @@ public: double get_probability(int p_index) const; void set_probability(int p_index, double p_probability); + + void set_abort_on_failure(bool p_abort_on_failure); + bool get_abort_on_failure() const; }; #endif // BT_PROBABILITY_SELECTOR_H diff --git a/tests/test_probability_selector.h b/tests/test_probability_selector.h index fefe3eb..13337b5 100644 --- a/tests/test_probability_selector.h +++ b/tests/test_probability_selector.h @@ -157,6 +157,30 @@ TEST_CASE("[Modules][LimboAI] BTProbabilitySelector") { CHECK(task3->num_ticks > 5750); CHECK(task3->num_ticks < 6750); } + SUBCASE("Test abort_on_failure") { + task1->ret_status = BTTask::FAILURE; + task2->ret_status = BTTask::FAILURE; + task3->ret_status = BTTask::FAILURE; + + int expected_child_executions = 0; + + SUBCASE("When abort_on_failure == false") { + sel->set_abort_on_failure(false); + expected_child_executions = 3; + } + SUBCASE("When abort_on_failure == true") { + sel->set_abort_on_failure(true); + expected_child_executions = 1; + } + + sel->execute(0.01666); + int num_ticks = task1->num_ticks + task2->num_ticks + task3->num_ticks; + CHECK(num_ticks == expected_child_executions); + int num_entries = task1->num_entries + task2->num_entries + task3->num_entries; + CHECK(num_entries == expected_child_executions); + int num_exits = task1->num_exits + task2->num_exits + task3->num_exits; + CHECK(num_exits == expected_child_executions); + } } } //namespace TestProbabilitySelector