Showing combined progress for entire task

This commit is contained in:
Daniel Wolf 2016-01-28 19:13:40 +01:00
parent 8e7fcc4efe
commit c425885929
5 changed files with 90 additions and 18 deletions

View File

@ -7,6 +7,46 @@
using std::string;
ProgressForwarder::ProgressForwarder(std::function<void(double progress)> callback) :
callback(callback)
{}
// Passes the reported progress value straight through to the stored callback.
void ProgressForwarder::reportProgress(double value)
{
    callback(value);
}
// Creates a merger that sends its combined progress to the given sink.
// Only a reference to the sink is stored, so the sink has to outlive the merger.
ProgressMerger::ProgressMerger(ProgressSink& sink)
    : sink(sink)
{
}
// Registers a new sub-task and returns the sink through which that sub-task
// reports its own progress.
//
// weight: Relative share of this sub-task in the overall progress. The merged
//         value is the sum of (progress * weight) over all sub-tasks divided by
//         the sum of all weights.
// Returns a reference to a forwarder owned by this merger; it stays valid only
// as long as the merger is alive.
ProgressSink& ProgressMerger::addSink(double weight) {
    std::lock_guard<std::mutex> lock(mutex);

    totalWeight += weight;

    // Keep the index in the container's own size type instead of narrowing it to int.
    const auto sinkIndex = weightedValues.size();
    weightedValues.push_back(0);

    forwarders.emplace_back([weight, sinkIndex, this](double progress) {
        {
            // weightedValues is shared with addSink() and report(), which both access
            // it under the mutex, so the write must be guarded as well. The lock is
            // released before report() is called because report() acquires the same
            // non-recursive mutex itself.
            std::lock_guard<std::mutex> valueLock(mutex);
            weightedValues[sinkIndex] = progress * weight;
        }
        report();
    });

    // std::list never relocates its elements, so the returned reference is not
    // invalidated by later addSink() calls.
    return forwarders.back();
}
void ProgressMerger::report() {
std::lock_guard<std::mutex> lock(mutex);
if (totalWeight != 0) {
double weightedSum = 0;
for (double weightedValue : weightedValues) {
weightedSum += weightedValue;
}
double progress = weightedSum / totalWeight;
sink.reportProgress(progress);
} else {
sink.reportProgress(0);
}
}
// Launches updateLoop() asynchronously and keeps the resulting future in
// updateLoopFuture.
ProgressBar::ProgressBar()
{
    updateLoopFuture = std::async(std::launch::async, [this] { updateLoop(); });
}

View File

@ -3,12 +3,44 @@
#include <string>
#include <atomic>
#include <future>
#include <functional>
#include <list>
#include <vector>
#include <mutex>
class ProgressBar {
// Interface for anything that can receive progress updates.
class ProgressSink {
public:
    // Virtual so that implementations can be destroyed through a ProgressSink pointer.
    virtual ~ProgressSink() = default;

    // Called with the current progress value.
    // NOTE(review): callers appear to pass a fraction (processed / total), so the
    // value is presumably in the range [0, 1] - confirm.
    virtual void reportProgress(double value) = 0;
};
// ProgressSink implementation that hands every reported value to a callback.
class ProgressForwarder : public ProgressSink {
public:
    // callback: Function that receives each value passed to reportProgress().
    ProgressForwarder(std::function<void(double progress)> callback);

    // Forwards the value to the callback.
    void reportProgress(double value) override;

private:
    // Receiver of all reported progress values.
    std::function<void(double progress)> callback;
};
// Combines the progress of several weighted sub-tasks into one overall progress
// value and reports that value to a single target sink.
class ProgressMerger {
public:
    // sink: Target that receives the combined progress. Stored by reference.
    ProgressMerger(ProgressSink& sink);

    // Adds a sub-task with the given relative weight and returns the sink through
    // which that sub-task reports its progress.
    ProgressSink& addSink(double weight);

private:
    // Computes the combined progress and sends it to the target sink.
    void report();

    // Target for the combined progress.
    ProgressSink& sink;

    // Locked by addSink() and report().
    std::mutex mutex;

    // Sum of the weights of all sub-tasks added so far.
    double totalWeight = 0;

    // One forwarder per sub-task. A list is used because addSink() returns
    // references to its elements, which must stay valid as more are added.
    std::list<ProgressForwarder> forwarders;

    // Latest (progress * weight) of each sub-task, indexed in the order of addition.
    std::vector<double> weightedValues;
};
class ProgressBar : public ProgressSink {
public:
ProgressBar();
~ProgressBar();
void reportProgress(double value);
void reportProgress(double value) override;
private:
void updateLoop();
@ -21,5 +53,3 @@ private:
std::string currentText;
int animationIndex = 0;
};

View File

@ -88,9 +88,7 @@ int main(int argc, char *argv[]) {
map<centiseconds, Phone> phones;
{
ProgressBar progressBar;
phones = detectPhones(
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
[&progressBar](double progress) { progressBar.reportProgress(progress); });
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
}
std::cerr << "Done" << std::endl;

View File

@ -75,13 +75,12 @@ int16_t floatSampleToInt16(float sample) {
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
}
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, function<void(double)> reportProgress) {
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
// Process entire sound file
vector<int16_t> buffer;
const int capacity = 1600; // 0.1 second capacity
buffer.reserve(capacity);
int sampleCount = 0;
reportProgress(0);
do {
// Read to buffer
buffer.clear();
@ -105,7 +104,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
processBuffer(buffer);
sampleCount += buffer.size();
reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
} while (buffer.size());
}
@ -137,7 +136,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
*errorString += message;
}
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream));
@ -150,7 +149,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
};
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
processAudioStream(*audioStream.get(), processBuffer, progressSink);
// End recognition
error = ps_end_utt(&recognizer);
@ -170,7 +169,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
return result;
}
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment(
ps_alignment_init(recognizer.d2p),
@ -211,7 +210,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
}
}
};
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
processAudioStream(*audioStream.get(), processBuffer, progressSink);
// End search
ps_search_finish(search.get());
@ -240,7 +239,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
return result;
}
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, function<void(double)> reportProgress) {
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
// Discard Pocketsphinx output
err_set_logfp(nullptr);
@ -256,11 +255,15 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
// Create speech recognizer
auto recognizer = createSpeechRecognizer(*config.get());
ProgressMerger progressMerger(progressSink);
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
// Recognize words
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress);
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
// Align the word's phones with speech
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress);
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
return result;
}
catch (...) {

View File

@ -6,5 +6,6 @@
#include "audioInput/AudioStream.h"
#include "Phone.h"
#include "centiseconds.h"
#include "progressBar.h"
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, std::function<void(double)> reportProgress);
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);