Sped up recognition via multithreading

This commit is contained in:
Daniel Wolf 2016-06-26 21:05:09 +02:00
parent 3a0a38575f
commit 9bf8355742
3 changed files with 129 additions and 74 deletions

View File

@ -223,6 +223,7 @@ set(SOURCE_FILES
src/languageModels.cpp src/languageModels.h
src/tupleHash.h
src/ThreadPool.cpp src/ThreadPool.h
src/ObjectPool.h
)
add_executable(rhubarb ${SOURCE_FILES})
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)

47
src/ObjectPool.h Normal file
View File

@ -0,0 +1,47 @@
#pragma once
#include <memory>
#include <functional>
#include <stack>
#include <mutex>
// Thread-safe pool of reusable, lazily created objects.
//
// acquire() hands out an object wrapped in a smart pointer whose deleter
// returns the object to the pool instead of destroying it. Objects are
// created on demand via the factory passed to the constructor.
//
// NOTE: the deleter captures `this`, so pointers obtained from acquire()
// must not outlive the pool itself.
template <class T>
class ObjectPool {
public:
	// Smart pointer type handed out by acquire(). Its custom deleter
	// recycles the object into the pool rather than deleting it.
	using ptr_type = std::unique_ptr<T, std::function<void(T*)>>;

	// createObject: factory invoked whenever the pool runs empty.
	// Must return a heap-allocated object; the pool takes ownership.
	ObjectPool(std::function<T*()> createObject) :
		createObject(createObject)
	{}

	virtual ~ObjectPool() {}

	// Takes an object from the pool, creating one first if the pool is
	// empty. When the returned pointer is destroyed, the object is put
	// back into the pool for reuse.
	ptr_type acquire() {
		std::lock_guard<std::mutex> lock(poolMutex);
		if (pool.empty()) {
			pool.push(std::unique_ptr<T>(createObject()));
		}
		ptr_type result(pool.top().release(), [this](T* p) {
			// Recycle: hand the object back to the pool instead of deleting it
			std::lock_guard<std::mutex> lock(poolMutex);
			this->pool.push(std::unique_ptr<T>(p));
		});
		pool.pop();
		// Plain return (no std::move) — std::move on a named local blocks
		// copy elision / automatic move of the return value
		return result;
	}

	// True if no idle object is currently stored in the pool.
	// Objects currently handed out via acquire() do not count.
	bool empty() const {
		std::lock_guard<std::mutex> lock(poolMutex);
		return pool.empty();
	}

	// Number of idle objects currently stored in the pool.
	size_t size() const {
		std::lock_guard<std::mutex> lock(poolMutex);
		return pool.size();
	}

private:
	std::function<T*()> createObject;     // factory for new pool objects
	std::stack<std::unique_ptr<T>> pool;  // idle objects; guarded by poolMutex
	mutable std::mutex poolMutex;         // protects `pool`
};

View File

@ -17,6 +17,8 @@
#include "g2p.h"
#include "ContinuousTimeline.h"
#include "audio/processing.h"
#include "ThreadPool.h"
#include "ObjectPool.h"
extern "C" {
#include <pocketsphinx.h>
@ -48,28 +50,6 @@ const path& getSphinxModelDirectory() {
return sphinxModelDirectory;
}
// Creates a PocketSphinx decoder configured with the bundled acoustic model
// and pronunciation dictionary. Throws runtime_error if either the
// configuration or the decoder cannot be created.
lambda_unique_ptr<ps_decoder_t> createDecoder() {
	const path modelDirectory = getSphinxModelDirectory();
	// Build the decoder configuration:
	// - acoustic model and pronunciation dictionary from the model directory
	// - dithering to add noise against zero silence
	//   (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
	lambda_unique_ptr<cmd_ln_t> configuration(
		cmd_ln_init(
			nullptr, ps_args(), true,
			"-hmm", (modelDirectory / "acoustic-model").string().c_str(),
			"-dict", (modelDirectory / "cmudict-en-us.dict").string().c_str(),
			"-dither", "yes",
			nullptr),
		[](cmd_ln_t* c) { cmd_ln_free_r(c); });
	if (!configuration) throw runtime_error("Error creating configuration.");

	// Initialize the decoder from that configuration
	lambda_unique_ptr<ps_decoder_t> decoder(
		ps_init(configuration.get()),
		[](ps_decoder_t* d) { ps_free(d); });
	if (!decoder) throw runtime_error("Error creating speech decoder.");
	return decoder;
}
logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
switch (errorLevel) {
case ERR_DEBUG:
@ -254,6 +234,48 @@ void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decode
}
}
// Creates a PocketSphinx speech decoder.
// If a dialog text is given, a dialog-specific language model is built from it
// and any of its words missing from the pronunciation dictionary are added;
// otherwise the generic US-English language model is loaded from disk.
// Throws runtime_error if the configuration or the decoder cannot be created.
lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
lambda_unique_ptr<cmd_ln_t> config(
cmd_ln_init(
nullptr, ps_args(), true,
// Set acoustic model
"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
// Set pronunciation dictionary
"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
"-dither", "yes",
nullptr),
[](cmd_ln_t* config) { cmd_ln_free_r(config); });
if (!config) throw runtime_error("Error creating configuration.");
lambda_unique_ptr<ps_decoder_t> decoder(
ps_init(config.get()),
[](ps_decoder_t* recognizer) { ps_free(recognizer); });
if (!decoder) throw runtime_error("Error creating speech decoder.");
// Set language model
lambda_unique_ptr<ngram_model_t> languageModel;
if (dialog) {
// Create dialog-specific language model
// Keep only words present in the pronunciation dictionary
vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
// Wrap the word sequence in sentence-start/-end markers for the n-gram model
words.insert(words.begin(), "<s>");
words.push_back("</s>");
languageModel = createLanguageModel(words, *decoder->lmath);
// Add any dialog-specific words to the dictionary
// NOTE: must happen before ps_set_lm/ps_set_search activate the model
addMissingDictionaryWords(words, *decoder);
} else {
// No dialog given: fall back to the generic US-English language model
path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
languageModel = lambda_unique_ptr<ngram_model_t>(
ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
[](ngram_model_t* lm) { ngram_model_free(lm); });
}
// Register the language model under the name "lm" and make it the active search
ps_set_lm(decoder.get(), "lm", languageModel.get());
ps_set_search(decoder.get(), "lm");
return decoder;
}
BoundedTimeline<Phone> detectPhones(
unique_ptr<AudioStream> audioStream,
optional<u32string> dialog,
@ -284,32 +306,14 @@ BoundedTimeline<Phone> detectPhones(
}
auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
// Create speech recognizer
auto decoder = createDecoder();
// Set language model
lambda_unique_ptr<ngram_model_t> languageModel;
if (dialog) {
// Create dialog-specific language model
vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
words.insert(words.begin(), "<s>");
words.push_back("</s>");
languageModel = createLanguageModel(words, *decoder->lmath);
// Add any dialog-specific words to the dictionary
addMissingDictionaryWords(words, *decoder);
} else {
path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
languageModel = lambda_unique_ptr<ngram_model_t>(
ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
[](ngram_model_t* lm) { ngram_model_free(lm); });
}
ps_set_lm(decoder.get(), "lm", languageModel.get());
ps_set_search(decoder.get(), "lm");
BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
std::mutex resultMutex;
logging::debug("Speech recognition -- start");
ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
ThreadPool threadPool;
for (const auto& timedUtterance : utterances) {
threadPool.addJob([&] {
auto decoder = decoderPool.acquire();
ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
@ -331,7 +335,7 @@ BoundedTimeline<Phone> detectPhones(
for (const auto& timedWord : words) {
wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
}
if (wordIds.empty()) continue;
if (wordIds.empty()) return;
// Align the words' phones with speech
BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
@ -342,10 +346,13 @@ BoundedTimeline<Phone> detectPhones(
}
// Fill result
std::lock_guard<std::mutex> lock(resultMutex);
for (const auto& timedPhone : segmentPhones) {
result.set(timedPhone);
}
});
}
threadPool.waitAll();
logging::debug("Speech recognition -- end");
return result;