From c42588592967c4d5c760621efaa3cde728fd0404 Mon Sep 17 00:00:00 2001 From: Daniel Wolf Date: Thu, 28 Jan 2016 19:13:40 +0100 Subject: [PATCH] Showing combined progress for entire task --- src/ProgressBar.cpp | 40 ++++++++++++++++++++++++++++++++++++++++ src/ProgressBar.h | 38 ++++++++++++++++++++++++++++++++++---- src/main.cpp | 4 +--- src/phoneExtraction.cpp | 23 +++++++++++++---------- src/phoneExtraction.h | 3 ++- 5 files changed, 90 insertions(+), 18 deletions(-) diff --git a/src/ProgressBar.cpp b/src/ProgressBar.cpp index fb92ec4..3b0d726 100644 --- a/src/ProgressBar.cpp +++ b/src/ProgressBar.cpp @@ -7,6 +7,46 @@ using std::string; +ProgressForwarder::ProgressForwarder(std::function callback) : + callback(callback) +{} + +void ProgressForwarder::reportProgress(double value) { + callback(value); +} + +ProgressMerger::ProgressMerger(ProgressSink& sink) : + sink(sink) +{} + +ProgressSink& ProgressMerger::addSink(double weight) { + std::lock_guard lock(mutex); + + totalWeight += weight; + int sinkIndex = weightedValues.size(); + weightedValues.push_back(0); + forwarders.push_back(ProgressForwarder([weight, sinkIndex, this](double progress) { + weightedValues[sinkIndex] = progress * weight; + report(); + })); + return forwarders.back(); +} + +void ProgressMerger::report() { + std::lock_guard lock(mutex); + + if (totalWeight != 0) { + double weightedSum = 0; + for (double weightedValue : weightedValues) { + weightedSum += weightedValue; + } + double progress = weightedSum / totalWeight; + sink.reportProgress(progress); + } else { + sink.reportProgress(0); + } +} + ProgressBar::ProgressBar() { updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this); } diff --git a/src/ProgressBar.h b/src/ProgressBar.h index 152ae5b..59f9964 100644 --- a/src/ProgressBar.h +++ b/src/ProgressBar.h @@ -3,12 +3,44 @@ #include #include #include +#include +#include +#include +#include -class ProgressBar { +class ProgressSink { +public: + virtual ~ProgressSink() {} + virtual void reportProgress(double value) = 0; +}; + +class ProgressForwarder : public ProgressSink { +public: + ProgressForwarder(std::function callback); + void reportProgress(double value) override; +private: + std::function callback; +}; + +class ProgressMerger { +public: + ProgressMerger(ProgressSink& sink); + ProgressSink& addSink(double weight); +private: + void report(); + + ProgressSink& sink; + std::mutex mutex; + double totalWeight = 0; + std::list forwarders; + std::vector weightedValues; +}; + +class ProgressBar : public ProgressSink { public: ProgressBar(); ~ProgressBar(); - void reportProgress(double value); + void reportProgress(double value) override; private: void updateLoop(); @@ -21,5 +53,3 @@ private: std::string currentText; int animationIndex = 0; }; - - diff --git a/src/main.cpp b/src/main.cpp index d793272..2eab1c2 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -88,9 +88,7 @@ int main(int argc, char *argv[]) { map phones; { ProgressBar progressBar; - phones = detectPhones( - [&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, - [&progressBar](double progress) { progressBar.reportProgress(progress); }); + phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar); } std::cerr << "Done" << std::endl; diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index 007c324..5389dd0 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -75,13 +75,12 @@ int16_t floatSampleToInt16(float sample) { return static_cast(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN); } -void processAudioStream(AudioStream& audioStream16kHzMono, function&)> processBuffer, function reportProgress) { +void processAudioStream(AudioStream& audioStream16kHzMono, function&)> processBuffer, ProgressSink& progressSink) { // Process entire sound file vector buffer; const int capacity = 1600; // 0.1 second capacity buffer.reserve(capacity); int sampleCount = 0; - reportProgress(0); do { // Read to buffer buffer.clear(); @@ -105,7 +104,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function(sampleCount) / audioStream16kHzMono.getFrameCount()); + progressSink.reportProgress(static_cast(sampleCount) / audioStream16kHzMono.getFrameCount()); } while (buffer.size()); } @@ -137,7 +136,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form *errorString += message; } -vector recognizeWords(unique_ptr audioStream, ps_decoder_t& recognizer, function reportProgress) { +vector recognizeWords(unique_ptr audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { // Convert audio stream to the exact format PocketSphinx requires audioStream = to16kHzMono(std::move(audioStream)); @@ -150,7 +149,7 @@ vector recognizeWords(unique_ptr audioStream, ps_decoder_t int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false); if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition."); }; - processAudioStream(*audioStream.get(), processBuffer, reportProgress); + processAudioStream(*audioStream.get(), processBuffer, progressSink); // End recognition error = ps_end_utt(&recognizer); @@ -170,7 +169,7 @@ vector recognizeWords(unique_ptr audioStream, ps_decoder_t return result; } -map getPhoneAlignment(const vector& wordIds, unique_ptr audioStream, ps_decoder_t& recognizer, function reportProgress) { +map getPhoneAlignment(const vector& wordIds, unique_ptr audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { // Create alignment list lambda_unique_ptr alignment( ps_alignment_init(recognizer.d2p), @@ -211,7 +210,7 @@ map getPhoneAlignment(const vector& wordIds, uniqu } } }; - processAudioStream(*audioStream.get(), processBuffer, reportProgress); + processAudioStream(*audioStream.get(), processBuffer, progressSink); // End search ps_search_finish(search.get()); @@ -240,7 +239,7 @@ map getPhoneAlignment(const vector& wordIds, uniqu return result; } -map detectPhones(std::function(void)> createAudioStream, function reportProgress) { +map detectPhones(std::function(void)> createAudioStream, ProgressSink& progressSink) { // Discard Pocketsphinx output err_set_logfp(nullptr); @@ -256,11 +255,15 @@ map detectPhones(std::function // Create speech recognizer auto recognizer = createSpeechRecognizer(*config.get()); + ProgressMerger progressMerger(progressSink); + ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0); + ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5); + // Recognize words - vector wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress); + vector wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink); // Align the word's phones with speech - map result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress); + map result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink); return result; } catch (...) { diff --git a/src/phoneExtraction.h b/src/phoneExtraction.h index 28969ab..c2fee5f 100644 --- a/src/phoneExtraction.h +++ b/src/phoneExtraction.h @@ -6,5 +6,6 @@ #include "audioInput/AudioStream.h" #include "Phone.h" #include "centiseconds.h" +#include "progressBar.h" -std::map detectPhones(std::function(void)> createAudioStream, std::function reportProgress); +std::map detectPhones(std::function(void)> createAudioStream, ProgressSink& progressSink);