Merge pull request #51 from DanielSWolf/bugfix/#46-progress

Prevent incomplete progress reporting
This commit is contained in:
Daniel Wolf 2019-01-03 17:02:27 +01:00 committed by GitHub
commit 10632576bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 81 additions and 32 deletions

View File

@ -33,8 +33,8 @@ JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
ProgressMerger progressMerger(progressSink);
ProgressSink& pass1ProgressSink = progressMerger.addSink(1.0);
ProgressSink& pass2ProgressSink = progressMerger.addSink(0.3);
ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
// Detect activity
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
@ -100,7 +100,9 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
);
audioSegments.push_back(segmentRange);
}
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
runParallel(
"VAD",
[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
JoiningBoundedTimeline<void> activitySegment =
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
@ -110,7 +112,11 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
activityRange.getTimeRange().shift(segmentRange.getStart());
activity.set(activityRange);
}
}, audioSegments, segmentCount, progressSink);
},
audioSegments,
segmentCount,
progressSink
);
// Fill small gaps in activity
const centiseconds maxGap(5);

View File

@ -154,6 +154,8 @@ optional<Timeline<Phone>> getPhoneAlignment(
const vector<int16_t>& audioBuffer,
ps_decoder_t& decoder)
{
if (wordIds.empty()) return boost::none;
// Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment(
ps_alignment_init(decoder.d2p),
@ -252,8 +254,10 @@ static Timeline<Phone> utteranceToPhones(
ProgressSink& utteranceProgressSink
) {
ProgressMerger utteranceProgressMerger(utteranceProgressSink);
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
ProgressSink& wordRecognitionProgressSink =
utteranceProgressMerger.addSource("word recognition (PocketSphinx recognizer)", 1.0);
ProgressSink& alignmentProgressSink =
utteranceProgressMerger.addSource("alignment (PocketSphinx recognizer)", 0.5);
// Pad time range to give PocketSphinx some breathing room
TimeRange paddedTimeRange = utteranceTimeRange;
@ -298,7 +302,6 @@ static Timeline<Phone> utteranceToPhones(
const string fixedWord = fixPronunciation(timedWord.getValue());
wordIds.push_back(getWordId(fixedWord, *decoder.dict));
}
if (wordIds.empty()) return {};
// Align the words' phones with speech
#if BOOST_VERSION < 105600 // Support legacy syntax

View File

@ -91,8 +91,10 @@ BoundedTimeline<Phone> recognizePhones(
ProgressSink& progressSink
) {
ProgressMerger totalProgressMerger(progressSink);
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
ProgressSink& voiceActivationProgressSink =
totalProgressMerger.addSource("VAD (PocketSphinx tools)", 1.0);
ProgressSink& dialogProgressSink =
totalProgressMerger.addSource("recognition (PocketSphinx tools)", 15.0);
// Make sure audio stream has no DC offset
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset();
@ -151,6 +153,7 @@ BoundedTimeline<Phone> recognizePhones(
}
logging::debugFormat("Speech recognition using {} threads -- start", threadCount);
runParallel(
"speech recognition (PocketSphinx tools)",
processUtterance,
utterances,
threadCount,
@ -159,7 +162,7 @@ BoundedTimeline<Phone> recognizePhones(
);
logging::debug("Speech recognition -- end");
} catch (...) {
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx."));
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx tools."));
}
return phones;

View File

@ -82,6 +82,7 @@ void runParallel(
template<typename TCollection>
void runParallel(
const std::string& description,
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
TCollection& collection,
int maxThreadCount,
@ -92,9 +93,15 @@ void runParallel(
// Create a collection of wrapper functions that take care of progress handling
ProgressMerger progressMerger(progressSink);
std::vector<std::function<void()>> functions;
int elementIndex = 0;
for (auto& element : collection) {
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
auto& elementProgressSink = progressMerger.addSource(
fmt::format("runParallel ({}) #{}", description, elementIndex),
getElementProgressWeight(element)
);
functions.push_back([&]() { processElement(element, elementProgressSink); });
++elementIndex;
}
// Run wrapper function

View File

@ -1,6 +1,7 @@
#include "progress.h"
#include <mutex>
#include "logging/logging.h"
using std::string;
@ -16,17 +17,36 @@ ProgressMerger::ProgressMerger(ProgressSink& sink) :
sink(sink)
{}
ProgressSink& ProgressMerger::addSink(double weight) {
// On destruction, write a debug log entry for every registered source whose
// last recorded progress value is still below 1.0.
ProgressMerger::~ProgressMerger() {
    for (const auto& candidate : sources) {
        const bool stoppedEarly = candidate.progress < 1.0;
        if (!stoppedEarly) continue;

        logging::debugFormat(
            "Progress merger source '{}' never reached 1.0, but stopped at {}.",
            candidate.description,
            candidate.progress
        );
    }
}
ProgressSink& ProgressMerger::addSource(const std::string& description, double weight) {
std::lock_guard<std::mutex> lock(mutex);
totalWeight += weight;
int sinkIndex = weightedValues.size();
weightedValues.push_back(0);
forwarders.emplace_back([weight, sinkIndex, this](double progress) {
weightedValues[sinkIndex] = progress * weight;
const int sourceIndex = sources.size();
sources.push_back({
description,
weight,
std::make_unique<ProgressForwarder>(
[sourceIndex, this](double progress) {
sources[sourceIndex].progress = progress;
report();
}
),
0.0
});
return forwarders.back();
return *sources[sourceIndex].forwarder;
}
void ProgressMerger::report() {
@ -34,8 +54,8 @@ void ProgressMerger::report() {
if (totalWeight != 0) {
double weightedSum = 0;
for (double weightedValue : weightedValues) {
weightedSum += weightedValue;
for (const auto& source : sources) {
weightedSum += source.weight * source.progress;
}
const double progress = weightedSum / totalWeight;
sink.reportProgress(progress);

View File

@ -1,9 +1,10 @@
#pragma once
#include <list>
#include <functional>
#include <mutex>
#include <vector>
#include <memory>
#include <string>
class ProgressSink {
public:
@ -24,16 +25,25 @@ private:
std::function<void(double progress)> callback;
};
// Bookkeeping record for one progress source registered with a ProgressMerger.
// Field order matters: instances are created by positional brace-initialization.
struct MergerSource {
// Label for this source; it is included in the debug message that is logged
// when the source's progress is still below 1.0 at merger destruction.
std::string description;
// Factor by which this source's progress is multiplied when the merged
// (weighted) progress is computed.
double weight;
// Needs to be a pointer because we give away references to the forwarder
// itself which would become invalid if the MergerSource is moved.
std::unique_ptr<ProgressForwarder> forwarder;
// Most recent progress value received through the forwarder's callback;
// initialized to 0.0 when the source is added.
double progress;
};
class ProgressMerger {
public:
ProgressMerger(ProgressSink& sink);
ProgressSink& addSink(double weight);
~ProgressMerger();
ProgressSink& addSource(const std::string& description, double weight);
private:
void report();
ProgressSink& sink;
std::mutex mutex;
double totalWeight = 0;
std::list<ProgressForwarder> forwarders;
std::vector<double> weightedValues;
std::vector<MergerSource> sources;
};