Sped up recognition via multithreading
This commit is contained in:
parent
3a0a38575f
commit
9bf8355742
|
@ -223,6 +223,7 @@ set(SOURCE_FILES
|
|||
src/languageModels.cpp src/languageModels.h
|
||||
src/tupleHash.h
|
||||
src/ThreadPool.cpp src/ThreadPool.h
|
||||
src/ObjectPool.h
|
||||
)
|
||||
add_executable(rhubarb ${SOURCE_FILES})
|
||||
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
#pragma once
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <stack>
|
||||
#include <mutex>
|
||||
|
||||
template <class T>
|
||||
class ObjectPool {
|
||||
public:
|
||||
using ptr_type = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
||||
ObjectPool(std::function<T*()> createObject) :
|
||||
createObject(createObject)
|
||||
{}
|
||||
|
||||
virtual ~ObjectPool() {}
|
||||
|
||||
ptr_type acquire() {
|
||||
std::lock_guard<std::mutex> lock(poolMutex);
|
||||
|
||||
if (pool.empty()) {
|
||||
pool.push(std::unique_ptr<T>(createObject()));
|
||||
}
|
||||
|
||||
ptr_type tmp(pool.top().release(), [this](T* p) {
|
||||
std::lock_guard<std::mutex> lock(poolMutex);
|
||||
this->pool.push(std::unique_ptr<T>(p));
|
||||
});
|
||||
pool.pop();
|
||||
return std::move(tmp);
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
std::lock_guard<std::mutex> lock(poolMutex);
|
||||
return pool.empty();
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
std::lock_guard<std::mutex> lock(poolMutex);
|
||||
return pool.size();
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<T*()> createObject;
|
||||
std::stack<std::unique_ptr<T>> pool;
|
||||
mutable std::mutex poolMutex;
|
||||
};
|
|
@ -17,6 +17,8 @@
|
|||
#include "g2p.h"
|
||||
#include "ContinuousTimeline.h"
|
||||
#include "audio/processing.h"
|
||||
#include "ThreadPool.h"
|
||||
#include "ObjectPool.h"
|
||||
|
||||
extern "C" {
|
||||
#include <pocketsphinx.h>
|
||||
|
@ -48,28 +50,6 @@ const path& getSphinxModelDirectory() {
|
|||
return sphinxModelDirectory;
|
||||
}
|
||||
|
||||
lambda_unique_ptr<ps_decoder_t> createDecoder() {
|
||||
lambda_unique_ptr<cmd_ln_t> config(
|
||||
cmd_ln_init(
|
||||
nullptr, ps_args(), true,
|
||||
// Set acoustic model
|
||||
"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
|
||||
// Set pronunciation dictionary
|
||||
"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
|
||||
// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
|
||||
"-dither", "yes",
|
||||
nullptr),
|
||||
[](cmd_ln_t* config) { cmd_ln_free_r(config); });
|
||||
if (!config) throw runtime_error("Error creating configuration.");
|
||||
|
||||
lambda_unique_ptr<ps_decoder_t> recognizer(
|
||||
ps_init(config.get()),
|
||||
[](ps_decoder_t* recognizer) { ps_free(recognizer); });
|
||||
if (!recognizer) throw runtime_error("Error creating speech decoder.");
|
||||
|
||||
return recognizer;
|
||||
}
|
||||
|
||||
logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
|
||||
switch (errorLevel) {
|
||||
case ERR_DEBUG:
|
||||
|
@ -254,6 +234,48 @@ void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decode
|
|||
}
|
||||
}
|
||||
|
||||
lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
|
||||
lambda_unique_ptr<cmd_ln_t> config(
|
||||
cmd_ln_init(
|
||||
nullptr, ps_args(), true,
|
||||
// Set acoustic model
|
||||
"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
|
||||
// Set pronunciation dictionary
|
||||
"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
|
||||
// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
|
||||
"-dither", "yes",
|
||||
nullptr),
|
||||
[](cmd_ln_t* config) { cmd_ln_free_r(config); });
|
||||
if (!config) throw runtime_error("Error creating configuration.");
|
||||
|
||||
lambda_unique_ptr<ps_decoder_t> decoder(
|
||||
ps_init(config.get()),
|
||||
[](ps_decoder_t* recognizer) { ps_free(recognizer); });
|
||||
if (!decoder) throw runtime_error("Error creating speech decoder.");
|
||||
|
||||
// Set language model
|
||||
lambda_unique_ptr<ngram_model_t> languageModel;
|
||||
if (dialog) {
|
||||
// Create dialog-specific language model
|
||||
vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
|
||||
words.insert(words.begin(), "<s>");
|
||||
words.push_back("</s>");
|
||||
languageModel = createLanguageModel(words, *decoder->lmath);
|
||||
|
||||
// Add any dialog-specific words to the dictionary
|
||||
addMissingDictionaryWords(words, *decoder);
|
||||
} else {
|
||||
path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
|
||||
languageModel = lambda_unique_ptr<ngram_model_t>(
|
||||
ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
|
||||
[](ngram_model_t* lm) { ngram_model_free(lm); });
|
||||
}
|
||||
ps_set_lm(decoder.get(), "lm", languageModel.get());
|
||||
ps_set_search(decoder.get(), "lm");
|
||||
|
||||
return decoder;
|
||||
}
|
||||
|
||||
BoundedTimeline<Phone> detectPhones(
|
||||
unique_ptr<AudioStream> audioStream,
|
||||
optional<u32string> dialog,
|
||||
|
@ -284,68 +306,53 @@ BoundedTimeline<Phone> detectPhones(
|
|||
}
|
||||
auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
|
||||
|
||||
// Create speech recognizer
|
||||
auto decoder = createDecoder();
|
||||
|
||||
// Set language model
|
||||
lambda_unique_ptr<ngram_model_t> languageModel;
|
||||
if (dialog) {
|
||||
// Create dialog-specific language model
|
||||
vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
|
||||
words.insert(words.begin(), "<s>");
|
||||
words.push_back("</s>");
|
||||
languageModel = createLanguageModel(words, *decoder->lmath);
|
||||
|
||||
// Add any dialog-specific words to the dictionary
|
||||
addMissingDictionaryWords(words, *decoder);
|
||||
} else {
|
||||
path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
|
||||
languageModel = lambda_unique_ptr<ngram_model_t>(
|
||||
ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
|
||||
[](ngram_model_t* lm) { ngram_model_free(lm); });
|
||||
}
|
||||
ps_set_lm(decoder.get(), "lm", languageModel.get());
|
||||
ps_set_search(decoder.get(), "lm");
|
||||
|
||||
BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
|
||||
std::mutex resultMutex;
|
||||
logging::debug("Speech recognition -- start");
|
||||
ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
|
||||
ThreadPool threadPool;
|
||||
for (const auto& timedUtterance : utterances) {
|
||||
ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
|
||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
||||
threadPool.addJob([&] {
|
||||
auto decoder = decoderPool.acquire();
|
||||
ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
|
||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
||||
|
||||
const TimeRange timeRange = timedUtterance.getTimeRange();
|
||||
logging::logTimedEvent("utterance", timeRange, string(""));
|
||||
const TimeRange timeRange = timedUtterance.getTimeRange();
|
||||
logging::logTimedEvent("utterance", timeRange, string(""));
|
||||
|
||||
auto streamSegment = createSegment(audioStream->clone(true), timeRange);
|
||||
auto streamSegment = createSegment(audioStream->clone(true), timeRange);
|
||||
|
||||
// Get words
|
||||
BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
|
||||
for (Timed<string> timedWord : words) {
|
||||
timedWord.getTimeRange().shift(timedUtterance.getStart());
|
||||
logging::logTimedEvent("word", timedWord);
|
||||
}
|
||||
// Get words
|
||||
BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
|
||||
for (Timed<string> timedWord : words) {
|
||||
timedWord.getTimeRange().shift(timedUtterance.getStart());
|
||||
logging::logTimedEvent("word", timedWord);
|
||||
}
|
||||
|
||||
// Look up words in dictionary
|
||||
vector<s3wid_t> wordIds;
|
||||
for (const auto& timedWord : words) {
|
||||
wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
|
||||
}
|
||||
if (wordIds.empty()) continue;
|
||||
// Look up words in dictionary
|
||||
vector<s3wid_t> wordIds;
|
||||
for (const auto& timedWord : words) {
|
||||
wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
|
||||
}
|
||||
if (wordIds.empty()) return;
|
||||
|
||||
// Align the words' phones with speech
|
||||
BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
|
||||
.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
|
||||
segmentPhones.shift(timedUtterance.getStart());
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
logging::logTimedEvent("phone", timedPhone);
|
||||
}
|
||||
// Align the words' phones with speech
|
||||
BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
|
||||
.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
|
||||
segmentPhones.shift(timedUtterance.getStart());
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
logging::logTimedEvent("phone", timedPhone);
|
||||
}
|
||||
|
||||
// Fill result
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
result.set(timedPhone);
|
||||
}
|
||||
// Fill result
|
||||
std::lock_guard<std::mutex> lock(resultMutex);
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
result.set(timedPhone);
|
||||
}
|
||||
});
|
||||
}
|
||||
threadPool.waitAll();
|
||||
logging::debug("Speech recognition -- end");
|
||||
|
||||
return result;
|
||||
|
|
Loading…
Reference in New Issue