Merge pull request #51 from DanielSWolf/bugfix/#46-progress
Prevent incomplete progress reporting
This commit is contained in:
commit
10632576bd
|
@ -33,8 +33,8 @@ JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
|
|||
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
||||
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
ProgressSink& pass1ProgressSink = progressMerger.addSink(1.0);
|
||||
ProgressSink& pass2ProgressSink = progressMerger.addSink(0.3);
|
||||
ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
|
||||
ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
|
||||
|
||||
// Detect activity
|
||||
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
|
||||
|
@ -100,17 +100,23 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
|
|||
);
|
||||
audioSegments.push_back(segmentRange);
|
||||
}
|
||||
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
|
||||
JoiningBoundedTimeline<void> activitySegment =
|
||||
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||
runParallel(
|
||||
"VAD",
|
||||
[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
|
||||
JoiningBoundedTimeline<void> activitySegment =
|
||||
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||
|
||||
std::lock_guard<std::mutex> lock(activityMutex);
|
||||
for (auto activityRange : activitySegment) {
|
||||
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||
activity.set(activityRange);
|
||||
}
|
||||
}, audioSegments, segmentCount, progressSink);
|
||||
std::lock_guard<std::mutex> lock(activityMutex);
|
||||
for (auto activityRange : activitySegment) {
|
||||
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||
activity.set(activityRange);
|
||||
}
|
||||
},
|
||||
audioSegments,
|
||||
segmentCount,
|
||||
progressSink
|
||||
);
|
||||
|
||||
// Fill small gaps in activity
|
||||
const centiseconds maxGap(5);
|
||||
|
|
|
@ -154,6 +154,8 @@ optional<Timeline<Phone>> getPhoneAlignment(
|
|||
const vector<int16_t>& audioBuffer,
|
||||
ps_decoder_t& decoder)
|
||||
{
|
||||
if (wordIds.empty()) return boost::none;
|
||||
|
||||
// Create alignment list
|
||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||
ps_alignment_init(decoder.d2p),
|
||||
|
@ -252,8 +254,10 @@ static Timeline<Phone> utteranceToPhones(
|
|||
ProgressSink& utteranceProgressSink
|
||||
) {
|
||||
ProgressMerger utteranceProgressMerger(utteranceProgressSink);
|
||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
||||
ProgressSink& wordRecognitionProgressSink =
|
||||
utteranceProgressMerger.addSource("word recognition (PocketSphinx recognizer)", 1.0);
|
||||
ProgressSink& alignmentProgressSink =
|
||||
utteranceProgressMerger.addSource("alignment (PocketSphinx recognizer)", 0.5);
|
||||
|
||||
// Pad time range to give PocketSphinx some breathing room
|
||||
TimeRange paddedTimeRange = utteranceTimeRange;
|
||||
|
@ -298,7 +302,6 @@ static Timeline<Phone> utteranceToPhones(
|
|||
const string fixedWord = fixPronunciation(timedWord.getValue());
|
||||
wordIds.push_back(getWordId(fixedWord, *decoder.dict));
|
||||
}
|
||||
if (wordIds.empty()) return {};
|
||||
|
||||
// Align the words' phones with speech
|
||||
#if BOOST_VERSION < 105600 // Support legacy syntax
|
||||
|
|
|
@ -91,8 +91,10 @@ BoundedTimeline<Phone> recognizePhones(
|
|||
ProgressSink& progressSink
|
||||
) {
|
||||
ProgressMerger totalProgressMerger(progressSink);
|
||||
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
|
||||
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
|
||||
ProgressSink& voiceActivationProgressSink =
|
||||
totalProgressMerger.addSource("VAD (PocketSphinx tools)", 1.0);
|
||||
ProgressSink& dialogProgressSink =
|
||||
totalProgressMerger.addSource("recognition (PocketSphinx tools)", 15.0);
|
||||
|
||||
// Make sure audio stream has no DC offset
|
||||
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset();
|
||||
|
@ -151,6 +153,7 @@ BoundedTimeline<Phone> recognizePhones(
|
|||
}
|
||||
logging::debugFormat("Speech recognition using {} threads -- start", threadCount);
|
||||
runParallel(
|
||||
"speech recognition (PocketSphinx tools)",
|
||||
processUtterance,
|
||||
utterances,
|
||||
threadCount,
|
||||
|
@ -159,7 +162,7 @@ BoundedTimeline<Phone> recognizePhones(
|
|||
);
|
||||
logging::debug("Speech recognition -- end");
|
||||
} catch (...) {
|
||||
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx."));
|
||||
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx tools."));
|
||||
}
|
||||
|
||||
return phones;
|
||||
|
|
|
@ -82,6 +82,7 @@ void runParallel(
|
|||
|
||||
template<typename TCollection>
|
||||
void runParallel(
|
||||
const std::string& description,
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
TCollection& collection,
|
||||
int maxThreadCount,
|
||||
|
@ -92,9 +93,15 @@ void runParallel(
|
|||
// Create a collection of wrapper functions that take care of progress handling
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
std::vector<std::function<void()>> functions;
|
||||
int elementIndex = 0;
|
||||
for (auto& element : collection) {
|
||||
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
|
||||
auto& elementProgressSink = progressMerger.addSource(
|
||||
fmt::format("runParallel ({}) #{}", description, elementIndex),
|
||||
getElementProgressWeight(element)
|
||||
);
|
||||
functions.push_back([&]() { processElement(element, elementProgressSink); });
|
||||
|
||||
++elementIndex;
|
||||
}
|
||||
|
||||
// Run wrapper function
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "progress.h"
|
||||
|
||||
#include <mutex>
|
||||
#include "logging/logging.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
|
@ -16,17 +17,36 @@ ProgressMerger::ProgressMerger(ProgressSink& sink) :
|
|||
sink(sink)
|
||||
{}
|
||||
|
||||
ProgressSink& ProgressMerger::addSink(double weight) {
|
||||
// Destructor: walks every registered source and writes a debug log entry for
// each one whose last recorded progress is still below 1.0. It only logs; the
// stored progress values are not changed and nothing is reported to the sink.
// NOTE(review): `sources` is read here without taking `mutex` -- presumably no
// forwarder is still reporting when the merger is destroyed; confirm.
ProgressMerger::~ProgressMerger() {
|
||||
for (const auto& source : sources) {
|
||||
// A source counts as complete only once it has reported at least 1.0.
if (source.progress < 1.0) {
|
||||
logging::debugFormat(
|
||||
"Progress merger source '{}' never reached 1.0, but stopped at {}.",
|
||||
source.description,
|
||||
source.progress
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProgressSink& ProgressMerger::addSource(const std::string& description, double weight) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
totalWeight += weight;
|
||||
int sinkIndex = weightedValues.size();
|
||||
weightedValues.push_back(0);
|
||||
forwarders.emplace_back([weight, sinkIndex, this](double progress) {
|
||||
weightedValues[sinkIndex] = progress * weight;
|
||||
report();
|
||||
|
||||
const int sourceIndex = sources.size();
|
||||
sources.push_back({
|
||||
description,
|
||||
weight,
|
||||
std::make_unique<ProgressForwarder>(
|
||||
[sourceIndex, this](double progress) {
|
||||
sources[sourceIndex].progress = progress;
|
||||
report();
|
||||
}
|
||||
),
|
||||
0.0
|
||||
});
|
||||
return forwarders.back();
|
||||
return *sources[sourceIndex].forwarder;
|
||||
}
|
||||
|
||||
void ProgressMerger::report() {
|
||||
|
@ -34,8 +54,8 @@ void ProgressMerger::report() {
|
|||
|
||||
if (totalWeight != 0) {
|
||||
double weightedSum = 0;
|
||||
for (double weightedValue : weightedValues) {
|
||||
weightedSum += weightedValue;
|
||||
for (const auto& source : sources) {
|
||||
weightedSum += source.weight * source.progress;
|
||||
}
|
||||
const double progress = weightedSum / totalWeight;
|
||||
sink.reportProgress(progress);
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
class ProgressSink {
|
||||
public:
|
||||
|
@ -24,16 +25,25 @@ private:
|
|||
std::function<void(double progress)> callback;
|
||||
};
|
||||
|
||||
// Per-source record kept by ProgressMerger: one entry per call to addSource().
// The members are filled positionally by brace initialization in addSource(),
// so their declaration order matters.
struct MergerSource {
|
||||
// Human-readable name of the source; used in the destructor's debug message.
std::string description;
|
||||
// Relative weight of this source; report() multiplies it with `progress`
// and divides the sum over all sources by the merger's total weight.
double weight;
|
||||
// Needs to be a pointer because we give away references to the forwarder
|
||||
// itself which would become invalid if the MergerSource is moved.
|
||||
std::unique_ptr<ProgressForwarder> forwarder;
|
||||
// Most recent value passed to the forwarder's callback. There is no default
// member initializer; addSource() initializes it to 0.0.
double progress;
|
||||
};
|
||||
|
||||
// Merges the progress of several weighted sources into a single value that is
// forwarded to one downstream ProgressSink.
// NOTE(review): this listing is a diff rendering that shows removed and added
// declarations side by side. addSink(), `forwarders` and `weightedValues`
// appear to be the pre-change members superseded by addSource() and `sources`
// -- confirm against the actual header before relying on them.
class ProgressMerger {
|
||||
public:
|
||||
// Stores a reference to `sink`; the merged progress is reported to it.
ProgressMerger(ProgressSink& sink);
|
||||
ProgressSink& addSink(double weight);
|
||||
// Logs (debug level) every source that never reached a progress of 1.0.
~ProgressMerger();
|
||||
// Registers a new source with the given description and weight and returns
// the sink that this source should report its own progress to.
ProgressSink& addSource(const std::string& description, double weight);
|
||||
private:
|
||||
// Recomputes the weighted overall progress and passes it to `sink`.
void report();
|
||||
|
||||
ProgressSink& sink;
|
||||
// Locked by addSource() while the bookkeeping below is modified.
std::mutex mutex;
|
||||
// Sum of the weights of all registered sources.
double totalWeight = 0;
|
||||
std::list<ProgressForwarder> forwarders;
|
||||
std::vector<double> weightedValues;
|
||||
// One record per registered source; see MergerSource.
std::vector<MergerSource> sources;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue