Showing combined progress for entire task
This commit is contained in:
parent
8e7fcc4efe
commit
c425885929
|
@ -7,6 +7,46 @@
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
|
|
||||||
|
ProgressForwarder::ProgressForwarder(std::function<void(double progress)> callback) :
|
||||||
|
callback(callback)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void ProgressForwarder::reportProgress(double value) {
|
||||||
|
callback(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
ProgressMerger::ProgressMerger(ProgressSink& sink) :
|
||||||
|
sink(sink)
|
||||||
|
{}
|
||||||
|
|
||||||
|
ProgressSink& ProgressMerger::addSink(double weight) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
|
||||||
|
totalWeight += weight;
|
||||||
|
int sinkIndex = weightedValues.size();
|
||||||
|
weightedValues.push_back(0);
|
||||||
|
forwarders.push_back(ProgressForwarder([weight, sinkIndex, this](double progress) {
|
||||||
|
weightedValues[sinkIndex] = progress * weight;
|
||||||
|
report();
|
||||||
|
}));
|
||||||
|
return forwarders.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgressMerger::report() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
|
||||||
|
if (totalWeight != 0) {
|
||||||
|
double weightedSum = 0;
|
||||||
|
for (double weightedValue : weightedValues) {
|
||||||
|
weightedSum += weightedValue;
|
||||||
|
}
|
||||||
|
double progress = weightedSum / totalWeight;
|
||||||
|
sink.reportProgress(progress);
|
||||||
|
} else {
|
||||||
|
sink.reportProgress(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ProgressBar::ProgressBar() {
|
ProgressBar::ProgressBar() {
|
||||||
updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this);
|
updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,12 +3,44 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <future>
|
#include <future>
|
||||||
|
#include <functional>
|
||||||
|
#include <list>
|
||||||
|
#include <vector>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
class ProgressBar {
|
class ProgressSink {
|
||||||
|
public:
|
||||||
|
virtual ~ProgressSink() {}
|
||||||
|
virtual void reportProgress(double value) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ProgressForwarder : public ProgressSink {
|
||||||
|
public:
|
||||||
|
ProgressForwarder(std::function<void(double progress)> callback);
|
||||||
|
void reportProgress(double value) override;
|
||||||
|
private:
|
||||||
|
std::function<void(double progress)> callback;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ProgressMerger {
|
||||||
|
public:
|
||||||
|
ProgressMerger(ProgressSink& sink);
|
||||||
|
ProgressSink& addSink(double weight);
|
||||||
|
private:
|
||||||
|
void report();
|
||||||
|
|
||||||
|
ProgressSink& sink;
|
||||||
|
std::mutex mutex;
|
||||||
|
double totalWeight = 0;
|
||||||
|
std::list<ProgressForwarder> forwarders;
|
||||||
|
std::vector<double> weightedValues;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ProgressBar : public ProgressSink {
|
||||||
public:
|
public:
|
||||||
ProgressBar();
|
ProgressBar();
|
||||||
~ProgressBar();
|
~ProgressBar();
|
||||||
void reportProgress(double value);
|
void reportProgress(double value) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void updateLoop();
|
void updateLoop();
|
||||||
|
@ -21,5 +53,3 @@ private:
|
||||||
std::string currentText;
|
std::string currentText;
|
||||||
int animationIndex = 0;
|
int animationIndex = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -88,9 +88,7 @@ int main(int argc, char *argv[]) {
|
||||||
map<centiseconds, Phone> phones;
|
map<centiseconds, Phone> phones;
|
||||||
{
|
{
|
||||||
ProgressBar progressBar;
|
ProgressBar progressBar;
|
||||||
phones = detectPhones(
|
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
|
||||||
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
|
|
||||||
[&progressBar](double progress) { progressBar.reportProgress(progress); });
|
|
||||||
}
|
}
|
||||||
std::cerr << "Done" << std::endl;
|
std::cerr << "Done" << std::endl;
|
||||||
|
|
||||||
|
|
|
@ -75,13 +75,12 @@ int16_t floatSampleToInt16(float sample) {
|
||||||
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
|
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
|
||||||
}
|
}
|
||||||
|
|
||||||
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, function<void(double)> reportProgress) {
|
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
|
||||||
// Process entire sound file
|
// Process entire sound file
|
||||||
vector<int16_t> buffer;
|
vector<int16_t> buffer;
|
||||||
const int capacity = 1600; // 0.1 second capacity
|
const int capacity = 1600; // 0.1 second capacity
|
||||||
buffer.reserve(capacity);
|
buffer.reserve(capacity);
|
||||||
int sampleCount = 0;
|
int sampleCount = 0;
|
||||||
reportProgress(0);
|
|
||||||
do {
|
do {
|
||||||
// Read to buffer
|
// Read to buffer
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
|
@ -105,7 +104,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
|
||||||
processBuffer(buffer);
|
processBuffer(buffer);
|
||||||
|
|
||||||
sampleCount += buffer.size();
|
sampleCount += buffer.size();
|
||||||
reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
|
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
|
||||||
} while (buffer.size());
|
} while (buffer.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,7 +136,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
|
||||||
*errorString += message;
|
*errorString += message;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
|
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||||
// Convert audio stream to the exact format PocketSphinx requires
|
// Convert audio stream to the exact format PocketSphinx requires
|
||||||
audioStream = to16kHzMono(std::move(audioStream));
|
audioStream = to16kHzMono(std::move(audioStream));
|
||||||
|
|
||||||
|
@ -150,7 +149,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
||||||
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
|
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
|
||||||
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
|
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
|
||||||
};
|
};
|
||||||
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
|
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||||
|
|
||||||
// End recognition
|
// End recognition
|
||||||
error = ps_end_utt(&recognizer);
|
error = ps_end_utt(&recognizer);
|
||||||
|
@ -170,7 +169,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
|
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||||
// Create alignment list
|
// Create alignment list
|
||||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||||
ps_alignment_init(recognizer.d2p),
|
ps_alignment_init(recognizer.d2p),
|
||||||
|
@ -211,7 +210,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
|
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||||
|
|
||||||
// End search
|
// End search
|
||||||
ps_search_finish(search.get());
|
ps_search_finish(search.get());
|
||||||
|
@ -240,7 +239,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, function<void(double)> reportProgress) {
|
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
|
||||||
// Discard Pocketsphinx output
|
// Discard Pocketsphinx output
|
||||||
err_set_logfp(nullptr);
|
err_set_logfp(nullptr);
|
||||||
|
|
||||||
|
@ -256,11 +255,15 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
|
||||||
// Create speech recognizer
|
// Create speech recognizer
|
||||||
auto recognizer = createSpeechRecognizer(*config.get());
|
auto recognizer = createSpeechRecognizer(*config.get());
|
||||||
|
|
||||||
|
ProgressMerger progressMerger(progressSink);
|
||||||
|
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
|
||||||
|
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
|
||||||
|
|
||||||
// Recognize words
|
// Recognize words
|
||||||
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress);
|
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
||||||
|
|
||||||
// Align the word's phones with speech
|
// Align the word's phones with speech
|
||||||
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress);
|
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
catch (...) {
|
catch (...) {
|
||||||
|
|
|
@ -6,5 +6,6 @@
|
||||||
#include "audioInput/AudioStream.h"
|
#include "audioInput/AudioStream.h"
|
||||||
#include "Phone.h"
|
#include "Phone.h"
|
||||||
#include "centiseconds.h"
|
#include "centiseconds.h"
|
||||||
|
#include "progressBar.h"
|
||||||
|
|
||||||
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, std::function<void(double)> reportProgress);
|
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);
|
||||||
|
|
Loading…
Reference in New Issue