Merge pull request #51 from DanielSWolf/bugfix/#46-progress

Prevent incomplete progress reporting
This commit is contained in:
Daniel Wolf 2019-01-03 17:02:27 +01:00 committed by GitHub
commit 10632576bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 81 additions and 32 deletions

View File

@ -33,8 +33,8 @@ JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness."); if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
ProgressMerger progressMerger(progressSink); ProgressMerger progressMerger(progressSink);
ProgressSink& pass1ProgressSink = progressMerger.addSink(1.0); ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
ProgressSink& pass2ProgressSink = progressMerger.addSink(0.3); ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
// Detect activity // Detect activity
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange()); JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
@ -100,17 +100,23 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
); );
audioSegments.push_back(segmentRange); audioSegments.push_back(segmentRange);
} }
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) { runParallel(
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange); "VAD",
JoiningBoundedTimeline<void> activitySegment = [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink); const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
JoiningBoundedTimeline<void> activitySegment =
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
std::lock_guard<std::mutex> lock(activityMutex); std::lock_guard<std::mutex> lock(activityMutex);
for (auto activityRange : activitySegment) { for (auto activityRange : activitySegment) {
activityRange.getTimeRange().shift(segmentRange.getStart()); activityRange.getTimeRange().shift(segmentRange.getStart());
activity.set(activityRange); activity.set(activityRange);
} }
}, audioSegments, segmentCount, progressSink); },
audioSegments,
segmentCount,
progressSink
);
// Fill small gaps in activity // Fill small gaps in activity
const centiseconds maxGap(5); const centiseconds maxGap(5);

View File

@ -154,6 +154,8 @@ optional<Timeline<Phone>> getPhoneAlignment(
const vector<int16_t>& audioBuffer, const vector<int16_t>& audioBuffer,
ps_decoder_t& decoder) ps_decoder_t& decoder)
{ {
if (wordIds.empty()) return boost::none;
// Create alignment list // Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment( lambda_unique_ptr<ps_alignment_t> alignment(
ps_alignment_init(decoder.d2p), ps_alignment_init(decoder.d2p),
@ -252,8 +254,10 @@ static Timeline<Phone> utteranceToPhones(
ProgressSink& utteranceProgressSink ProgressSink& utteranceProgressSink
) { ) {
ProgressMerger utteranceProgressMerger(utteranceProgressSink); ProgressMerger utteranceProgressMerger(utteranceProgressSink);
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0); ProgressSink& wordRecognitionProgressSink =
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5); utteranceProgressMerger.addSource("word recognition (PocketSphinx recognizer)", 1.0);
ProgressSink& alignmentProgressSink =
utteranceProgressMerger.addSource("alignment (PocketSphinx recognizer)", 0.5);
// Pad time range to give PocketSphinx some breathing room // Pad time range to give PocketSphinx some breathing room
TimeRange paddedTimeRange = utteranceTimeRange; TimeRange paddedTimeRange = utteranceTimeRange;
@ -298,7 +302,6 @@ static Timeline<Phone> utteranceToPhones(
const string fixedWord = fixPronunciation(timedWord.getValue()); const string fixedWord = fixPronunciation(timedWord.getValue());
wordIds.push_back(getWordId(fixedWord, *decoder.dict)); wordIds.push_back(getWordId(fixedWord, *decoder.dict));
} }
if (wordIds.empty()) return {};
// Align the words' phones with speech // Align the words' phones with speech
#if BOOST_VERSION < 105600 // Support legacy syntax #if BOOST_VERSION < 105600 // Support legacy syntax

View File

@ -91,8 +91,10 @@ BoundedTimeline<Phone> recognizePhones(
ProgressSink& progressSink ProgressSink& progressSink
) { ) {
ProgressMerger totalProgressMerger(progressSink); ProgressMerger totalProgressMerger(progressSink);
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0); ProgressSink& voiceActivationProgressSink =
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15); totalProgressMerger.addSource("VAD (PocketSphinx tools)", 1.0);
ProgressSink& dialogProgressSink =
totalProgressMerger.addSource("recognition (PocketSphinx tools)", 15.0);
// Make sure audio stream has no DC offset // Make sure audio stream has no DC offset
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset(); const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset();
@ -151,6 +153,7 @@ BoundedTimeline<Phone> recognizePhones(
} }
logging::debugFormat("Speech recognition using {} threads -- start", threadCount); logging::debugFormat("Speech recognition using {} threads -- start", threadCount);
runParallel( runParallel(
"speech recognition (PocketSphinx tools)",
processUtterance, processUtterance,
utterances, utterances,
threadCount, threadCount,
@ -159,7 +162,7 @@ BoundedTimeline<Phone> recognizePhones(
); );
logging::debug("Speech recognition -- end"); logging::debug("Speech recognition -- end");
} catch (...) { } catch (...) {
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx.")); std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx tools."));
} }
return phones; return phones;

View File

@ -82,6 +82,7 @@ void runParallel(
template<typename TCollection> template<typename TCollection>
void runParallel( void runParallel(
const std::string& description,
std::function<void(typename TCollection::reference, ProgressSink&)> processElement, std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
TCollection& collection, TCollection& collection,
int maxThreadCount, int maxThreadCount,
@ -92,9 +93,15 @@ void runParallel(
// Create a collection of wrapper functions that take care of progress handling // Create a collection of wrapper functions that take care of progress handling
ProgressMerger progressMerger(progressSink); ProgressMerger progressMerger(progressSink);
std::vector<std::function<void()>> functions; std::vector<std::function<void()>> functions;
int elementIndex = 0;
for (auto& element : collection) { for (auto& element : collection) {
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element)); auto& elementProgressSink = progressMerger.addSource(
fmt::format("runParallel ({}) #{}", description, elementIndex),
getElementProgressWeight(element)
);
functions.push_back([&]() { processElement(element, elementProgressSink); }); functions.push_back([&]() { processElement(element, elementProgressSink); });
++elementIndex;
} }
// Run wrapper function // Run wrapper function

View File

@ -1,6 +1,7 @@
#include "progress.h" #include "progress.h"
#include <mutex> #include <mutex>
#include "logging/logging.h"
using std::string; using std::string;
@ -16,17 +17,36 @@ ProgressMerger::ProgressMerger(ProgressSink& sink) :
sink(sink) sink(sink)
{} {}
ProgressSink& ProgressMerger::addSink(double weight) { ProgressMerger::~ProgressMerger() {
for (const auto& source : sources) {
if (source.progress < 1.0) {
logging::debugFormat(
"Progress merger source '{}' never reached 1.0, but stopped at {}.",
source.description,
source.progress
);
}
}
}
ProgressSink& ProgressMerger::addSource(const std::string& description, double weight) {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
totalWeight += weight; totalWeight += weight;
int sinkIndex = weightedValues.size();
weightedValues.push_back(0); const int sourceIndex = sources.size();
forwarders.emplace_back([weight, sinkIndex, this](double progress) { sources.push_back({
weightedValues[sinkIndex] = progress * weight; description,
report(); weight,
std::make_unique<ProgressForwarder>(
[sourceIndex, this](double progress) {
sources[sourceIndex].progress = progress;
report();
}
),
0.0
}); });
return forwarders.back(); return *sources[sourceIndex].forwarder;
} }
void ProgressMerger::report() { void ProgressMerger::report() {
@ -34,8 +54,8 @@ void ProgressMerger::report() {
if (totalWeight != 0) { if (totalWeight != 0) {
double weightedSum = 0; double weightedSum = 0;
for (double weightedValue : weightedValues) { for (const auto& source : sources) {
weightedSum += weightedValue; weightedSum += source.weight * source.progress;
} }
const double progress = weightedSum / totalWeight; const double progress = weightedSum / totalWeight;
sink.reportProgress(progress); sink.reportProgress(progress);

View File

@ -1,9 +1,10 @@
#pragma once #pragma once
#include <list>
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#include <memory>
#include <string>
class ProgressSink { class ProgressSink {
public: public:
@ -24,16 +25,25 @@ private:
std::function<void(double progress)> callback; std::function<void(double progress)> callback;
}; };
struct MergerSource {
std::string description;
double weight;
// Needs to be a pointer because we give away references to the forwarder
// itself which would become invalid if the MergerSource is moved.
std::unique_ptr<ProgressForwarder> forwarder;
double progress;
};
class ProgressMerger { class ProgressMerger {
public: public:
ProgressMerger(ProgressSink& sink); ProgressMerger(ProgressSink& sink);
ProgressSink& addSink(double weight); ~ProgressMerger();
ProgressSink& addSource(const std::string& description, double weight);
private: private:
void report(); void report();
ProgressSink& sink; ProgressSink& sink;
std::mutex mutex; std::mutex mutex;
double totalWeight = 0; double totalWeight = 0;
std::list<ProgressForwarder> forwarders; std::vector<MergerSource> sources;
std::vector<double> weightedValues;
}; };