Sped up recognition via multithreading

Author: Daniel Wolf, 2016-06-26 21:05:09 +02:00
Parent: 3a0a38575f
Commit: 9bf8355742
3 changed files with 129 additions and 74 deletions

CMakeLists.txt

@@ -223,6 +223,7 @@ set(SOURCE_FILES
 	src/languageModels.cpp src/languageModels.h
 	src/tupleHash.h
 	src/ThreadPool.cpp src/ThreadPool.h
+	src/ObjectPool.h
 )
 add_executable(rhubarb ${SOURCE_FILES})
 target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)

src/ObjectPool.h (new file, 47 lines)

@@ -0,0 +1,47 @@
+#pragma once
+
+#include <memory>
+#include <functional>
+#include <stack>
+#include <mutex>
+
+template <class T>
+class ObjectPool {
+public:
+	using ptr_type = std::unique_ptr<T, std::function<void(T*)>>;
+
+	ObjectPool(std::function<T*()> createObject) :
+		createObject(createObject)
+	{}
+
+	virtual ~ObjectPool() {}
+
+	ptr_type acquire() {
+		std::lock_guard<std::mutex> lock(poolMutex);
+
+		if (pool.empty()) {
+			pool.push(std::unique_ptr<T>(createObject()));
+		}
+
+		ptr_type tmp(pool.top().release(), [this](T* p) {
+			std::lock_guard<std::mutex> lock(poolMutex);
+			this->pool.push(std::unique_ptr<T>(p));
+		});
+		pool.pop();
+		return std::move(tmp);
+	}
+
+	bool empty() const {
+		std::lock_guard<std::mutex> lock(poolMutex);
+		return pool.empty();
+	}
+
+	size_t size() const {
+		std::lock_guard<std::mutex> lock(poolMutex);
+		return pool.size();
+	}
+
+private:
+	std::function<T*()> createObject;
+	std::stack<std::unique_ptr<T>> pool;
+	mutable std::mutex poolMutex;
+};
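
A minimal usage sketch of this pool, for illustration only. The std::string factory merely stands in for an expensive resource such as a PocketSphinx decoder; this main.cpp is not part of the commit:

	#include <iostream>
	#include <string>
	#include "ObjectPool.h"

	int main() {
		// The factory runs only when no pooled instance is available.
		ObjectPool<std::string> pool([] { return new std::string("expensive resource"); });
		{
			auto handle = pool.acquire(); // pool is empty, so the factory creates one instance
			std::cout << *handle << std::endl;
		} // the handle's deleter returns the object to the pool instead of deleting it
		auto reused = pool.acquire();          // reuses the pooled instance; no new allocation
		std::cout << pool.size() << std::endl; // prints 0: the only instance is checked out
		// Caveat: handles must not outlive the pool, since the deleter captures `this`.
		return 0;
	}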


@@ -17,6 +17,8 @@
 #include "g2p.h"
 #include "ContinuousTimeline.h"
 #include "audio/processing.h"
+#include "ThreadPool.h"
+#include "ObjectPool.h"

 extern "C" {
 #include <pocketsphinx.h>
@@ -48,28 +50,6 @@ const path& getSphinxModelDirectory() {
 	return sphinxModelDirectory;
 }

-lambda_unique_ptr<ps_decoder_t> createDecoder() {
-	lambda_unique_ptr<cmd_ln_t> config(
-		cmd_ln_init(
-			nullptr, ps_args(), true,
-			// Set acoustic model
-			"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
-			// Set pronunciation dictionary
-			"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
-			// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
-			"-dither", "yes",
-			nullptr),
-		[](cmd_ln_t* config) { cmd_ln_free_r(config); });
-	if (!config) throw runtime_error("Error creating configuration.");
-
-	lambda_unique_ptr<ps_decoder_t> recognizer(
-		ps_init(config.get()),
-		[](ps_decoder_t* recognizer) { ps_free(recognizer); });
-	if (!recognizer) throw runtime_error("Error creating speech decoder.");
-
-	return recognizer;
-}
-
 logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
 	switch (errorLevel) {
 	case ERR_DEBUG:
@@ -254,6 +234,48 @@ void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decode
 	}
 }

+lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
+	lambda_unique_ptr<cmd_ln_t> config(
+		cmd_ln_init(
+			nullptr, ps_args(), true,
+			// Set acoustic model
+			"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
+			// Set pronunciation dictionary
+			"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
+			// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
+			"-dither", "yes",
+			nullptr),
+		[](cmd_ln_t* config) { cmd_ln_free_r(config); });
+	if (!config) throw runtime_error("Error creating configuration.");
+
+	lambda_unique_ptr<ps_decoder_t> decoder(
+		ps_init(config.get()),
+		[](ps_decoder_t* recognizer) { ps_free(recognizer); });
+	if (!decoder) throw runtime_error("Error creating speech decoder.");
+
+	// Set language model
+	lambda_unique_ptr<ngram_model_t> languageModel;
+	if (dialog) {
+		// Create dialog-specific language model
+		vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
+		words.insert(words.begin(), "<s>");
+		words.push_back("</s>");
+		languageModel = createLanguageModel(words, *decoder->lmath);
+
+		// Add any dialog-specific words to the dictionary
+		addMissingDictionaryWords(words, *decoder);
+	} else {
+		path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
+		languageModel = lambda_unique_ptr<ngram_model_t>(
+			ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
+			[](ngram_model_t* lm) { ngram_model_free(lm); });
+	}
+	ps_set_lm(decoder.get(), "lm", languageModel.get());
+	ps_set_search(decoder.get(), "lm");
+
+	return decoder;
+}
+
 BoundedTimeline<Phone> detectPhones(
 	unique_ptr<AudioStream> audioStream,
 	optional<u32string> dialog,
@@ -284,68 +306,53 @@ BoundedTimeline<Phone> detectPhones(
 	}
 	auto utteranceProgressSinkIt = utteranceProgressSinks.begin();

-	// Create speech recognizer
-	auto decoder = createDecoder();
-
-	// Set language model
-	lambda_unique_ptr<ngram_model_t> languageModel;
-	if (dialog) {
-		// Create dialog-specific language model
-		vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
-		words.insert(words.begin(), "<s>");
-		words.push_back("</s>");
-		languageModel = createLanguageModel(words, *decoder->lmath);
-
-		// Add any dialog-specific words to the dictionary
-		addMissingDictionaryWords(words, *decoder);
-	} else {
-		path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
-		languageModel = lambda_unique_ptr<ngram_model_t>(
-			ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
-			[](ngram_model_t* lm) { ngram_model_free(lm); });
-	}
-	ps_set_lm(decoder.get(), "lm", languageModel.get());
-	ps_set_search(decoder.get(), "lm");
-
 	BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+	std::mutex resultMutex;

 	logging::debug("Speech recognition -- start");

+	ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
+	ThreadPool threadPool;
 	for (const auto& timedUtterance : utterances) {
+		threadPool.addJob([&] {
+			auto decoder = decoderPool.acquire();
 			ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
 			ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
 			ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);

 			const TimeRange timeRange = timedUtterance.getTimeRange();
 			logging::logTimedEvent("utterance", timeRange, string(""));

 			auto streamSegment = createSegment(audioStream->clone(true), timeRange);

 			// Get words
 			BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
 			for (Timed<string> timedWord : words) {
 				timedWord.getTimeRange().shift(timedUtterance.getStart());
 				logging::logTimedEvent("word", timedWord);
 			}

 			// Look up words in dictionary
 			vector<s3wid_t> wordIds;
 			for (const auto& timedWord : words) {
 				wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
 			}
-			if (wordIds.empty()) continue;
+			if (wordIds.empty()) return;

 			// Align the words' phones with speech
 			BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
 				.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
 			segmentPhones.shift(timedUtterance.getStart());
 			for (const auto& timedPhone : segmentPhones) {
 				logging::logTimedEvent("phone", timedPhone);
 			}

 			// Fill result
+			std::lock_guard<std::mutex> lock(resultMutex);
 			for (const auto& timedPhone : segmentPhones) {
 				result.set(timedPhone);
 			}
+		});
 	}
+	threadPool.waitAll();

 	logging::debug("Speech recognition -- end");

 	return result;
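
Taken together, the hunks implement a common fan-out pattern: each utterance becomes a job on a thread pool, each job borrows a pooled decoder (decoders are expensive to create and are not shared across threads), and only the final merge into the shared timeline is serialized. A self-contained sketch of that pattern, using std::async in place of the project's ThreadPool; Decoder and recognize are hypothetical stand-ins, not part of the commit:

	#include <future>
	#include <mutex>
	#include <vector>
	#include "ObjectPool.h"

	struct Decoder {};                                               // stand-in for ps_decoder_t
	int recognize(Decoder&, int utterance) { return utterance * 2; } // stand-in for the per-utterance work

	int main() {
		ObjectPool<Decoder> decoderPool([] { return new Decoder(); });
		std::vector<int> results;
		std::mutex resultMutex;

		std::vector<std::future<void>> jobs;
		for (int utterance : { 1, 2, 3, 4 }) {
			jobs.push_back(std::async(std::launch::async, [&, utterance] {
				auto decoder = decoderPool.acquire();          // borrow a decoder instead of re-creating one
				int phones = recognize(*decoder, utterance);   // heavy work runs with no shared state
				std::lock_guard<std::mutex> lock(resultMutex); // only the merge is serialized
				results.push_back(phones);
			}));                                               // the pooled decoder is returned when the lambda body ends
		}
		for (auto& job : jobs) job.wait();                     // like threadPool.waitAll()
		return 0;
	}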