diff --git a/rhubarb/src/audio/voiceActivityDetection.cpp b/rhubarb/src/audio/voiceActivityDetection.cpp index b89b0a7..990ae88 100644 --- a/rhubarb/src/audio/voiceActivityDetection.cpp +++ b/rhubarb/src/audio/voiceActivityDetection.cpp @@ -33,8 +33,8 @@ JoiningBoundedTimeline webRtcDetectVoiceActivity( if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness."); ProgressMerger progressMerger(progressSink); - ProgressSink& pass1ProgressSink = progressMerger.addSink(1.0); - ProgressSink& pass2ProgressSink = progressMerger.addSink(0.3); + ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0); + ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3); // Detect activity JoiningBoundedTimeline activity(audioClip.getTruncatedRange()); @@ -100,17 +100,23 @@ JoiningBoundedTimeline detectVoiceActivity( ); audioSegments.push_back(segmentRange); } - runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) { - const unique_ptr audioSegment = audioClip->clone() | segment(segmentRange); - JoiningBoundedTimeline activitySegment = - webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink); + runParallel( + "VAD", + [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) { + const unique_ptr audioSegment = audioClip->clone() | segment(segmentRange); + JoiningBoundedTimeline activitySegment = + webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink); - std::lock_guard lock(activityMutex); - for (auto activityRange : activitySegment) { - activityRange.getTimeRange().shift(segmentRange.getStart()); - activity.set(activityRange); - } - }, audioSegments, segmentCount, progressSink); + std::lock_guard lock(activityMutex); + for (auto activityRange : activitySegment) { + activityRange.getTimeRange().shift(segmentRange.getStart()); + activity.set(activityRange); + } + }, + audioSegments, + segmentCount, + progressSink + ); // Fill small gaps in activity const centiseconds maxGap(5); diff --git a/rhubarb/src/recognition/PocketSphinxRecognizer.cpp b/rhubarb/src/recognition/PocketSphinxRecognizer.cpp index c44e465..ca41050 100644 --- a/rhubarb/src/recognition/PocketSphinxRecognizer.cpp +++ b/rhubarb/src/recognition/PocketSphinxRecognizer.cpp @@ -154,6 +154,8 @@ optional> getPhoneAlignment( const vector& audioBuffer, ps_decoder_t& decoder) { + if (wordIds.empty()) return boost::none; + // Create alignment list lambda_unique_ptr alignment( ps_alignment_init(decoder.d2p), @@ -252,8 +254,10 @@ static Timeline utteranceToPhones( ProgressSink& utteranceProgressSink ) { ProgressMerger utteranceProgressMerger(utteranceProgressSink); - ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0); - ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5); + ProgressSink& wordRecognitionProgressSink = + utteranceProgressMerger.addSource("word recognition (PocketSphinx recognizer)", 1.0); + ProgressSink& alignmentProgressSink = + utteranceProgressMerger.addSource("alignment (PocketSphinx recognizer)", 0.5); // Pad time range to give PocketSphinx some breathing room TimeRange paddedTimeRange = utteranceTimeRange; @@ -298,7 +302,6 @@ static Timeline utteranceToPhones( const string fixedWord = fixPronunciation(timedWord.getValue()); wordIds.push_back(getWordId(fixedWord, *decoder.dict)); } - if (wordIds.empty()) return {}; // Align the words' phones with speech #if BOOST_VERSION < 105600 // Support legacy syntax diff --git a/rhubarb/src/recognition/pocketSphinxTools.cpp b/rhubarb/src/recognition/pocketSphinxTools.cpp index 27c6ee1..4e9c117 100644 --- a/rhubarb/src/recognition/pocketSphinxTools.cpp +++ b/rhubarb/src/recognition/pocketSphinxTools.cpp @@ -91,8 +91,10 @@ BoundedTimeline recognizePhones( ProgressSink& progressSink ) { ProgressMerger totalProgressMerger(progressSink); - ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0); - ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15); + ProgressSink& voiceActivationProgressSink = + totalProgressMerger.addSource("VAD (PocketSphinx tools)", 1.0); + ProgressSink& dialogProgressSink = + totalProgressMerger.addSource("recognition (PocketSphinx tools)", 15.0); // Make sure audio stream has no DC offset const unique_ptr audioClip = inputAudioClip.clone() | removeDcOffset(); @@ -151,6 +153,7 @@ BoundedTimeline recognizePhones( } logging::debugFormat("Speech recognition using {} threads -- start", threadCount); runParallel( + "speech recognition (PocketSphinx tools)", processUtterance, utterances, threadCount, @@ -159,7 +162,7 @@ BoundedTimeline recognizePhones( ); logging::debug("Speech recognition -- end"); } catch (...) { - std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx.")); + std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx tools.")); } return phones; diff --git a/rhubarb/src/tools/parallel.h b/rhubarb/src/tools/parallel.h index 1caef68..7486b01 100644 --- a/rhubarb/src/tools/parallel.h +++ b/rhubarb/src/tools/parallel.h @@ -82,6 +82,7 @@ void runParallel( template void runParallel( + const std::string& description, std::function processElement, TCollection& collection, int maxThreadCount, @@ -92,9 +93,15 @@ void runParallel( // Create a collection of wrapper functions that take care of progress handling ProgressMerger progressMerger(progressSink); std::vector> functions; + int elementIndex = 0; for (auto& element : collection) { - auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element)); + auto& elementProgressSink = progressMerger.addSource( + fmt::format("runParallel ({}) #{}", description, elementIndex), + getElementProgressWeight(element) + ); functions.push_back([&]() { processElement(element, elementProgressSink); }); + + ++elementIndex; } // Run wrapper function diff --git a/rhubarb/src/tools/progress.cpp b/rhubarb/src/tools/progress.cpp index 81e54a7..c313092 100644 --- a/rhubarb/src/tools/progress.cpp +++ b/rhubarb/src/tools/progress.cpp @@ -1,6 +1,7 @@ #include "progress.h" #include +#include "logging/logging.h" using std::string; @@ -16,17 +17,36 @@ ProgressMerger::ProgressMerger(ProgressSink& sink) : sink(sink) {} -ProgressSink& ProgressMerger::addSink(double weight) { +ProgressMerger::~ProgressMerger() { + for (const auto& source : sources) { + if (source.progress < 1.0) { + logging::debugFormat( + "Progress merger source '{}' never reached 1.0, but stopped at {}.", + source.description, + source.progress + ); + } + } +} + +ProgressSink& ProgressMerger::addSource(const std::string& description, double weight) { std::lock_guard lock(mutex); totalWeight += weight; - int sinkIndex = weightedValues.size(); - weightedValues.push_back(0); - forwarders.emplace_back([weight, sinkIndex, this](double progress) { - weightedValues[sinkIndex] = progress * weight; - report(); + + const int sourceIndex = sources.size(); + sources.push_back({ + description, + weight, + std::make_unique( + [sourceIndex, this](double progress) { + sources[sourceIndex].progress = progress; + report(); + } + ), + 0.0 }); - return forwarders.back(); + return *sources[sourceIndex].forwarder; } void ProgressMerger::report() { @@ -34,8 +54,8 @@ void ProgressMerger::report() { if (totalWeight != 0) { double weightedSum = 0; - for (double weightedValue : weightedValues) { - weightedSum += weightedValue; + for (const auto& source : sources) { + weightedSum += source.weight * source.progress; } const double progress = weightedSum / totalWeight; sink.reportProgress(progress); diff --git a/rhubarb/src/tools/progress.h b/rhubarb/src/tools/progress.h index 23f4965..6a66cb0 100644 --- a/rhubarb/src/tools/progress.h +++ b/rhubarb/src/tools/progress.h @@ -1,9 +1,10 @@ #pragma once -#include #include #include #include +#include +#include class ProgressSink { public: @@ -24,16 +25,25 @@ private: std::function callback; }; +struct MergerSource { + std::string description; + double weight; + // Needs to be a pointer because we give away references to the forwarder + // itself which would become invalid if the MergerSource is moved. + std::unique_ptr forwarder; + double progress; +}; + class ProgressMerger { public: ProgressMerger(ProgressSink& sink); - ProgressSink& addSink(double weight); + ~ProgressMerger(); + ProgressSink& addSource(const std::string& description, double weight); private: void report(); ProgressSink& sink; std::mutex mutex; double totalWeight = 0; - std::list forwarders; - std::vector weightedValues; + std::vector sources; };