Showing combined progress for entire task

This commit is contained in:
Daniel Wolf 2016-01-28 19:13:40 +01:00
parent 8e7fcc4efe
commit c425885929
5 changed files with 90 additions and 18 deletions

View File

@ -7,6 +7,46 @@
using std::string; using std::string;
ProgressForwarder::ProgressForwarder(std::function<void(double progress)> callback) :
callback(callback)
{}
// Passes the reported value straight through to the stored callback.
void ProgressForwarder::reportProgress(double value)
{
	callback(value);
}
// Creates a merger that reports its combined progress to the given sink.
// The sink is held by reference, so it must outlive this merger.
ProgressMerger::ProgressMerger(ProgressSink& sink)
	: sink(sink)
{
}
// Registers a new sub-task and returns the sink it should report its progress to.
//
// weight: relative share of this sub-task in the combined progress. The merged
//         value is the sum of (progress * weight) over all sub-tasks divided by
//         the sum of all weights.
//
// Returns a reference to a forwarder owned by this merger. The forwarders are
// kept in a std::list, so the reference stays valid when further sinks are added.
ProgressSink& ProgressMerger::addSink(double weight) {
	std::lock_guard<std::mutex> lock(mutex);

	totalWeight += weight;

	// Slot in weightedValues that belongs to the new sub-task.
	// Kept in the container's own size type to avoid a narrowing conversion.
	const auto sinkIndex = weightedValues.size();
	weightedValues.push_back(0);

	forwarders.push_back(ProgressForwarder([weight, sinkIndex, this](double progress) {
		{
			// weightedValues is shared with addSink() (which may reallocate it)
			// and report(), both of which access it under the mutex. The write
			// must therefore be guarded as well.
			std::lock_guard<std::mutex> lock(mutex);
			weightedValues[sinkIndex] = progress * weight;
		}
		// The lock is released before calling report(): report() acquires the
		// same non-recursive mutex itself.
		report();
	}));
	return forwarders.back();
}
void ProgressMerger::report() {
std::lock_guard<std::mutex> lock(mutex);
if (totalWeight != 0) {
double weightedSum = 0;
for (double weightedValue : weightedValues) {
weightedSum += weightedValue;
}
double progress = weightedSum / totalWeight;
sink.reportProgress(progress);
} else {
sink.reportProgress(0);
}
}
ProgressBar::ProgressBar() { ProgressBar::ProgressBar() {
updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this); updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this);
} }

View File

@ -3,12 +3,44 @@
#include <string> #include <string>
#include <atomic> #include <atomic>
#include <future> #include <future>
#include <functional>
#include <list>
#include <vector>
#include <mutex>
// Abstract interface for anything that can receive progress updates.
class ProgressBar { class ProgressSink {
public:
// Virtual so that implementations can be destroyed through a ProgressSink.
virtual ~ProgressSink() {}
// Receives the current progress value.
// NOTE(review): presumably a fraction in [0, 1] (the visible caller passes
// processed samples / total frame count) — confirm.
virtual void reportProgress(double value) = 0;
};
// ProgressSink implementation that forwards every reported value to a callback.
class ProgressForwarder : public ProgressSink {
public:
// callback: function invoked with each value passed to reportProgress().
ProgressForwarder(std::function<void(double progress)> callback);
// Invokes the stored callback with the given value.
void reportProgress(double value) override;
private:
// Copy of the callback supplied at construction.
std::function<void(double progress)> callback;
};
// Combines the progress of several weighted sub-tasks into a single progress
// value and reports that value to one target sink.
class ProgressMerger {
public:
// sink: target that receives the combined progress. Stored by reference.
ProgressMerger(ProgressSink& sink);
// Adds a sub-task with the given relative weight and returns the sink that
// sub-task should report to. The returned sink is owned by this merger.
ProgressSink& addSink(double weight);
private:
// Recomputes the combined progress and passes it to the target sink.
void report();
// Target for the combined progress.
ProgressSink& sink;
// Held by addSink() and report() while they access the members below.
std::mutex mutex;
// Sum of the weights of all sinks added so far.
double totalWeight = 0;
// One forwarder per added sink; addSink() hands out references to these elements.
std::list<ProgressForwarder> forwarders;
// Per sink: most recently reported progress multiplied by that sink's weight.
std::vector<double> weightedValues;
};
class ProgressBar : public ProgressSink {
public: public:
ProgressBar(); ProgressBar();
~ProgressBar(); ~ProgressBar();
void reportProgress(double value); void reportProgress(double value) override;
private: private:
void updateLoop(); void updateLoop();
@ -21,5 +53,3 @@ private:
std::string currentText; std::string currentText;
int animationIndex = 0; int animationIndex = 0;
}; };

View File

@ -88,9 +88,7 @@ int main(int argc, char *argv[]) {
map<centiseconds, Phone> phones; map<centiseconds, Phone> phones;
{ {
ProgressBar progressBar; ProgressBar progressBar;
phones = detectPhones( phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
[&progressBar](double progress) { progressBar.reportProgress(progress); });
} }
std::cerr << "Done" << std::endl; std::cerr << "Done" << std::endl;

View File

@ -75,13 +75,12 @@ int16_t floatSampleToInt16(float sample) {
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN); return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
} }
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, function<void(double)> reportProgress) { void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
// Process entire sound file // Process entire sound file
vector<int16_t> buffer; vector<int16_t> buffer;
const int capacity = 1600; // 0.1 second capacity const int capacity = 1600; // 0.1 second capacity
buffer.reserve(capacity); buffer.reserve(capacity);
int sampleCount = 0; int sampleCount = 0;
reportProgress(0);
do { do {
// Read to buffer // Read to buffer
buffer.clear(); buffer.clear();
@ -105,7 +104,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
processBuffer(buffer); processBuffer(buffer);
sampleCount += buffer.size(); sampleCount += buffer.size();
reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount()); progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
} while (buffer.size()); } while (buffer.size());
} }
@ -137,7 +136,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
*errorString += message; *errorString += message;
} }
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) { vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires // Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream)); audioStream = to16kHzMono(std::move(audioStream));
@ -150,7 +149,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false); int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition."); if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
}; };
processAudioStream(*audioStream.get(), processBuffer, reportProgress); processAudioStream(*audioStream.get(), processBuffer, progressSink);
// End recognition // End recognition
error = ps_end_utt(&recognizer); error = ps_end_utt(&recognizer);
@ -170,7 +169,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
return result; return result;
} }
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) { map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Create alignment list // Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment( lambda_unique_ptr<ps_alignment_t> alignment(
ps_alignment_init(recognizer.d2p), ps_alignment_init(recognizer.d2p),
@ -211,7 +210,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
} }
} }
}; };
processAudioStream(*audioStream.get(), processBuffer, reportProgress); processAudioStream(*audioStream.get(), processBuffer, progressSink);
// End search // End search
ps_search_finish(search.get()); ps_search_finish(search.get());
@ -240,7 +239,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
return result; return result;
} }
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, function<void(double)> reportProgress) { map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
// Discard Pocketsphinx output // Discard Pocketsphinx output
err_set_logfp(nullptr); err_set_logfp(nullptr);
@ -256,11 +255,15 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
// Create speech recognizer // Create speech recognizer
auto recognizer = createSpeechRecognizer(*config.get()); auto recognizer = createSpeechRecognizer(*config.get());
ProgressMerger progressMerger(progressSink);
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
// Recognize words // Recognize words
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress); vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
// Align the word's phones with speech // Align the word's phones with speech
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress); map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
return result; return result;
} }
catch (...) { catch (...) {

View File

@ -6,5 +6,6 @@
#include "audioInput/AudioStream.h" #include "audioInput/AudioStream.h"
#include "Phone.h" #include "Phone.h"
#include "centiseconds.h" #include "centiseconds.h"
#include "progressBar.h"
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, std::function<void(double)> reportProgress); std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);