Merge pull request #51 from DanielSWolf/bugfix/#46-progress
Prevent incomplete progress reporting
This commit is contained in:
commit
10632576bd
|
@ -33,8 +33,8 @@ JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
|
||||||
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
||||||
|
|
||||||
ProgressMerger progressMerger(progressSink);
|
ProgressMerger progressMerger(progressSink);
|
||||||
ProgressSink& pass1ProgressSink = progressMerger.addSink(1.0);
|
ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
|
||||||
ProgressSink& pass2ProgressSink = progressMerger.addSink(0.3);
|
ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
|
||||||
|
|
||||||
// Detect activity
|
// Detect activity
|
||||||
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
|
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
|
||||||
|
@ -100,7 +100,9 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
|
||||||
);
|
);
|
||||||
audioSegments.push_back(segmentRange);
|
audioSegments.push_back(segmentRange);
|
||||||
}
|
}
|
||||||
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
runParallel(
|
||||||
|
"VAD",
|
||||||
|
[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||||
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
|
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
|
||||||
JoiningBoundedTimeline<void> activitySegment =
|
JoiningBoundedTimeline<void> activitySegment =
|
||||||
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||||
|
@ -110,7 +112,11 @@ JoiningBoundedTimeline<void> detectVoiceActivity(
|
||||||
activityRange.getTimeRange().shift(segmentRange.getStart());
|
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||||
activity.set(activityRange);
|
activity.set(activityRange);
|
||||||
}
|
}
|
||||||
}, audioSegments, segmentCount, progressSink);
|
},
|
||||||
|
audioSegments,
|
||||||
|
segmentCount,
|
||||||
|
progressSink
|
||||||
|
);
|
||||||
|
|
||||||
// Fill small gaps in activity
|
// Fill small gaps in activity
|
||||||
const centiseconds maxGap(5);
|
const centiseconds maxGap(5);
|
||||||
|
|
|
@ -154,6 +154,8 @@ optional<Timeline<Phone>> getPhoneAlignment(
|
||||||
const vector<int16_t>& audioBuffer,
|
const vector<int16_t>& audioBuffer,
|
||||||
ps_decoder_t& decoder)
|
ps_decoder_t& decoder)
|
||||||
{
|
{
|
||||||
|
if (wordIds.empty()) return boost::none;
|
||||||
|
|
||||||
// Create alignment list
|
// Create alignment list
|
||||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||||
ps_alignment_init(decoder.d2p),
|
ps_alignment_init(decoder.d2p),
|
||||||
|
@ -252,8 +254,10 @@ static Timeline<Phone> utteranceToPhones(
|
||||||
ProgressSink& utteranceProgressSink
|
ProgressSink& utteranceProgressSink
|
||||||
) {
|
) {
|
||||||
ProgressMerger utteranceProgressMerger(utteranceProgressSink);
|
ProgressMerger utteranceProgressMerger(utteranceProgressSink);
|
||||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
ProgressSink& wordRecognitionProgressSink =
|
||||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
utteranceProgressMerger.addSource("word recognition (PocketSphinx recognizer)", 1.0);
|
||||||
|
ProgressSink& alignmentProgressSink =
|
||||||
|
utteranceProgressMerger.addSource("alignment (PocketSphinx recognizer)", 0.5);
|
||||||
|
|
||||||
// Pad time range to give PocketSphinx some breathing room
|
// Pad time range to give PocketSphinx some breathing room
|
||||||
TimeRange paddedTimeRange = utteranceTimeRange;
|
TimeRange paddedTimeRange = utteranceTimeRange;
|
||||||
|
@ -298,7 +302,6 @@ static Timeline<Phone> utteranceToPhones(
|
||||||
const string fixedWord = fixPronunciation(timedWord.getValue());
|
const string fixedWord = fixPronunciation(timedWord.getValue());
|
||||||
wordIds.push_back(getWordId(fixedWord, *decoder.dict));
|
wordIds.push_back(getWordId(fixedWord, *decoder.dict));
|
||||||
}
|
}
|
||||||
if (wordIds.empty()) return {};
|
|
||||||
|
|
||||||
// Align the words' phones with speech
|
// Align the words' phones with speech
|
||||||
#if BOOST_VERSION < 105600 // Support legacy syntax
|
#if BOOST_VERSION < 105600 // Support legacy syntax
|
||||||
|
|
|
@ -91,8 +91,10 @@ BoundedTimeline<Phone> recognizePhones(
|
||||||
ProgressSink& progressSink
|
ProgressSink& progressSink
|
||||||
) {
|
) {
|
||||||
ProgressMerger totalProgressMerger(progressSink);
|
ProgressMerger totalProgressMerger(progressSink);
|
||||||
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
|
ProgressSink& voiceActivationProgressSink =
|
||||||
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
|
totalProgressMerger.addSource("VAD (PocketSphinx tools)", 1.0);
|
||||||
|
ProgressSink& dialogProgressSink =
|
||||||
|
totalProgressMerger.addSource("recognition (PocketSphinx tools)", 15.0);
|
||||||
|
|
||||||
// Make sure audio stream has no DC offset
|
// Make sure audio stream has no DC offset
|
||||||
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset();
|
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone() | removeDcOffset();
|
||||||
|
@ -151,6 +153,7 @@ BoundedTimeline<Phone> recognizePhones(
|
||||||
}
|
}
|
||||||
logging::debugFormat("Speech recognition using {} threads -- start", threadCount);
|
logging::debugFormat("Speech recognition using {} threads -- start", threadCount);
|
||||||
runParallel(
|
runParallel(
|
||||||
|
"speech recognition (PocketSphinx tools)",
|
||||||
processUtterance,
|
processUtterance,
|
||||||
utterances,
|
utterances,
|
||||||
threadCount,
|
threadCount,
|
||||||
|
@ -159,7 +162,7 @@ BoundedTimeline<Phone> recognizePhones(
|
||||||
);
|
);
|
||||||
logging::debug("Speech recognition -- end");
|
logging::debug("Speech recognition -- end");
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx."));
|
std::throw_with_nested(runtime_error("Error performing speech recognition via PocketSphinx tools."));
|
||||||
}
|
}
|
||||||
|
|
||||||
return phones;
|
return phones;
|
||||||
|
|
|
@ -82,6 +82,7 @@ void runParallel(
|
||||||
|
|
||||||
template<typename TCollection>
|
template<typename TCollection>
|
||||||
void runParallel(
|
void runParallel(
|
||||||
|
const std::string& description,
|
||||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||||
TCollection& collection,
|
TCollection& collection,
|
||||||
int maxThreadCount,
|
int maxThreadCount,
|
||||||
|
@ -92,9 +93,15 @@ void runParallel(
|
||||||
// Create a collection of wrapper functions that take care of progress handling
|
// Create a collection of wrapper functions that take care of progress handling
|
||||||
ProgressMerger progressMerger(progressSink);
|
ProgressMerger progressMerger(progressSink);
|
||||||
std::vector<std::function<void()>> functions;
|
std::vector<std::function<void()>> functions;
|
||||||
|
int elementIndex = 0;
|
||||||
for (auto& element : collection) {
|
for (auto& element : collection) {
|
||||||
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
|
auto& elementProgressSink = progressMerger.addSource(
|
||||||
|
fmt::format("runParallel ({}) #{}", description, elementIndex),
|
||||||
|
getElementProgressWeight(element)
|
||||||
|
);
|
||||||
functions.push_back([&]() { processElement(element, elementProgressSink); });
|
functions.push_back([&]() { processElement(element, elementProgressSink); });
|
||||||
|
|
||||||
|
++elementIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run wrapper function
|
// Run wrapper function
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "progress.h"
|
#include "progress.h"
|
||||||
|
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include "logging/logging.h"
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
|
|
||||||
|
@ -16,17 +17,36 @@ ProgressMerger::ProgressMerger(ProgressSink& sink) :
|
||||||
sink(sink)
|
sink(sink)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
ProgressSink& ProgressMerger::addSink(double weight) {
|
ProgressMerger::~ProgressMerger() {
|
||||||
|
for (const auto& source : sources) {
|
||||||
|
if (source.progress < 1.0) {
|
||||||
|
logging::debugFormat(
|
||||||
|
"Progress merger source '{}' never reached 1.0, but stopped at {}.",
|
||||||
|
source.description,
|
||||||
|
source.progress
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ProgressSink& ProgressMerger::addSource(const std::string& description, double weight) {
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
|
||||||
totalWeight += weight;
|
totalWeight += weight;
|
||||||
int sinkIndex = weightedValues.size();
|
|
||||||
weightedValues.push_back(0);
|
const int sourceIndex = sources.size();
|
||||||
forwarders.emplace_back([weight, sinkIndex, this](double progress) {
|
sources.push_back({
|
||||||
weightedValues[sinkIndex] = progress * weight;
|
description,
|
||||||
|
weight,
|
||||||
|
std::make_unique<ProgressForwarder>(
|
||||||
|
[sourceIndex, this](double progress) {
|
||||||
|
sources[sourceIndex].progress = progress;
|
||||||
report();
|
report();
|
||||||
|
}
|
||||||
|
),
|
||||||
|
0.0
|
||||||
});
|
});
|
||||||
return forwarders.back();
|
return *sources[sourceIndex].forwarder;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProgressMerger::report() {
|
void ProgressMerger::report() {
|
||||||
|
@ -34,8 +54,8 @@ void ProgressMerger::report() {
|
||||||
|
|
||||||
if (totalWeight != 0) {
|
if (totalWeight != 0) {
|
||||||
double weightedSum = 0;
|
double weightedSum = 0;
|
||||||
for (double weightedValue : weightedValues) {
|
for (const auto& source : sources) {
|
||||||
weightedSum += weightedValue;
|
weightedSum += source.weight * source.progress;
|
||||||
}
|
}
|
||||||
const double progress = weightedSum / totalWeight;
|
const double progress = weightedSum / totalWeight;
|
||||||
sink.reportProgress(progress);
|
sink.reportProgress(progress);
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <list>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
class ProgressSink {
|
class ProgressSink {
|
||||||
public:
|
public:
|
||||||
|
@ -24,16 +25,25 @@ private:
|
||||||
std::function<void(double progress)> callback;
|
std::function<void(double progress)> callback;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MergerSource {
|
||||||
|
std::string description;
|
||||||
|
double weight;
|
||||||
|
// Needs to be a pointer because we give away references to the forwarder
|
||||||
|
// itself which would become invalid if the MergerSource is moved.
|
||||||
|
std::unique_ptr<ProgressForwarder> forwarder;
|
||||||
|
double progress;
|
||||||
|
};
|
||||||
|
|
||||||
class ProgressMerger {
|
class ProgressMerger {
|
||||||
public:
|
public:
|
||||||
ProgressMerger(ProgressSink& sink);
|
ProgressMerger(ProgressSink& sink);
|
||||||
ProgressSink& addSink(double weight);
|
~ProgressMerger();
|
||||||
|
ProgressSink& addSource(const std::string& description, double weight);
|
||||||
private:
|
private:
|
||||||
void report();
|
void report();
|
||||||
|
|
||||||
ProgressSink& sink;
|
ProgressSink& sink;
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
double totalWeight = 0;
|
double totalWeight = 0;
|
||||||
std::list<ProgressForwarder> forwarders;
|
std::vector<MergerSource> sources;
|
||||||
std::vector<double> weightedValues;
|
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue