Showing combined progress for entire task
This commit is contained in:
parent
8e7fcc4efe
commit
c425885929
|
@ -7,6 +7,46 @@
|
|||
|
||||
using std::string;
|
||||
|
||||
ProgressForwarder::ProgressForwarder(std::function<void(double progress)> callback) :
|
||||
callback(callback)
|
||||
{}
|
||||
|
||||
void ProgressForwarder::reportProgress(double value) {
|
||||
callback(value);
|
||||
}
|
||||
|
||||
ProgressMerger::ProgressMerger(ProgressSink& sink) :
|
||||
sink(sink)
|
||||
{}
|
||||
|
||||
ProgressSink& ProgressMerger::addSink(double weight) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
totalWeight += weight;
|
||||
int sinkIndex = weightedValues.size();
|
||||
weightedValues.push_back(0);
|
||||
forwarders.push_back(ProgressForwarder([weight, sinkIndex, this](double progress) {
|
||||
weightedValues[sinkIndex] = progress * weight;
|
||||
report();
|
||||
}));
|
||||
return forwarders.back();
|
||||
}
|
||||
|
||||
void ProgressMerger::report() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
if (totalWeight != 0) {
|
||||
double weightedSum = 0;
|
||||
for (double weightedValue : weightedValues) {
|
||||
weightedSum += weightedValue;
|
||||
}
|
||||
double progress = weightedSum / totalWeight;
|
||||
sink.reportProgress(progress);
|
||||
} else {
|
||||
sink.reportProgress(0);
|
||||
}
|
||||
}
|
||||
|
||||
ProgressBar::ProgressBar() {
|
||||
updateLoopFuture = std::async(std::launch::async, &ProgressBar::updateLoop, this);
|
||||
}
|
||||
|
|
|
@ -3,12 +3,44 @@
|
|||
#include <string>
|
||||
#include <atomic>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
class ProgressBar {
|
||||
class ProgressSink {
|
||||
public:
|
||||
virtual ~ProgressSink() {}
|
||||
virtual void reportProgress(double value) = 0;
|
||||
};
|
||||
|
||||
class ProgressForwarder : public ProgressSink {
|
||||
public:
|
||||
ProgressForwarder(std::function<void(double progress)> callback);
|
||||
void reportProgress(double value) override;
|
||||
private:
|
||||
std::function<void(double progress)> callback;
|
||||
};
|
||||
|
||||
class ProgressMerger {
|
||||
public:
|
||||
ProgressMerger(ProgressSink& sink);
|
||||
ProgressSink& addSink(double weight);
|
||||
private:
|
||||
void report();
|
||||
|
||||
ProgressSink& sink;
|
||||
std::mutex mutex;
|
||||
double totalWeight = 0;
|
||||
std::list<ProgressForwarder> forwarders;
|
||||
std::vector<double> weightedValues;
|
||||
};
|
||||
|
||||
class ProgressBar : public ProgressSink {
|
||||
public:
|
||||
ProgressBar();
|
||||
~ProgressBar();
|
||||
void reportProgress(double value);
|
||||
void reportProgress(double value) override;
|
||||
|
||||
private:
|
||||
void updateLoop();
|
||||
|
@ -21,5 +53,3 @@ private:
|
|||
std::string currentText;
|
||||
int animationIndex = 0;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -88,9 +88,7 @@ int main(int argc, char *argv[]) {
|
|||
map<centiseconds, Phone> phones;
|
||||
{
|
||||
ProgressBar progressBar;
|
||||
phones = detectPhones(
|
||||
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
|
||||
[&progressBar](double progress) { progressBar.reportProgress(progress); });
|
||||
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
|
||||
}
|
||||
std::cerr << "Done" << std::endl;
|
||||
|
||||
|
|
|
@ -75,13 +75,12 @@ int16_t floatSampleToInt16(float sample) {
|
|||
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
|
||||
}
|
||||
|
||||
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, function<void(double)> reportProgress) {
|
||||
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
|
||||
// Process entire sound file
|
||||
vector<int16_t> buffer;
|
||||
const int capacity = 1600; // 0.1 second capacity
|
||||
buffer.reserve(capacity);
|
||||
int sampleCount = 0;
|
||||
reportProgress(0);
|
||||
do {
|
||||
// Read to buffer
|
||||
buffer.clear();
|
||||
|
@ -105,7 +104,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
|
|||
processBuffer(buffer);
|
||||
|
||||
sampleCount += buffer.size();
|
||||
reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
|
||||
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
|
||||
} while (buffer.size());
|
||||
}
|
||||
|
||||
|
@ -137,7 +136,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
|
|||
*errorString += message;
|
||||
}
|
||||
|
||||
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
|
||||
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||
// Convert audio stream to the exact format PocketSphinx requires
|
||||
audioStream = to16kHzMono(std::move(audioStream));
|
||||
|
||||
|
@ -150,7 +149,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
|||
int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
|
||||
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
|
||||
};
|
||||
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
|
||||
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
|
||||
// End recognition
|
||||
error = ps_end_utt(&recognizer);
|
||||
|
@ -170,7 +169,7 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
|||
return result;
|
||||
}
|
||||
|
||||
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, function<void(double)> reportProgress) {
|
||||
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||
// Create alignment list
|
||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||
ps_alignment_init(recognizer.d2p),
|
||||
|
@ -211,7 +210,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
|||
}
|
||||
}
|
||||
};
|
||||
processAudioStream(*audioStream.get(), processBuffer, reportProgress);
|
||||
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
|
||||
// End search
|
||||
ps_search_finish(search.get());
|
||||
|
@ -240,7 +239,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
|||
return result;
|
||||
}
|
||||
|
||||
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, function<void(double)> reportProgress) {
|
||||
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
|
||||
// Discard Pocketsphinx output
|
||||
err_set_logfp(nullptr);
|
||||
|
||||
|
@ -256,11 +255,15 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
|
|||
// Create speech recognizer
|
||||
auto recognizer = createSpeechRecognizer(*config.get());
|
||||
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
|
||||
|
||||
// Recognize words
|
||||
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress);
|
||||
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
||||
|
||||
// Align the word's phones with speech
|
||||
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress);
|
||||
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
|
||||
return result;
|
||||
}
|
||||
catch (...) {
|
||||
|
|
|
@ -6,5 +6,6 @@
|
|||
#include "audioInput/AudioStream.h"
|
||||
#include "Phone.h"
|
||||
#include "centiseconds.h"
|
||||
#include "progressBar.h"
|
||||
|
||||
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, std::function<void(double)> reportProgress);
|
||||
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);
|
||||
|
|
Loading…
Reference in New Issue