Sped up recognition via multithreading

parent 3a0a38575f
commit 9bf8355742
CMakeLists.txt
@@ -223,6 +223,7 @@ set(SOURCE_FILES
 	src/languageModels.cpp src/languageModels.h
 	src/tupleHash.h
 	src/ThreadPool.cpp src/ThreadPool.h
+	src/ObjectPool.h
 )
 add_executable(rhubarb ${SOURCE_FILES})
 target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)
src/ObjectPool.h (new file)
@@ -0,0 +1,47 @@
+#pragma once
+#include <memory>
+#include <functional>
+#include <stack>
+#include <mutex>
+
+template <class T>
+class ObjectPool {
+public:
+	using ptr_type = std::unique_ptr<T, std::function<void(T*)>>;
+
+	ObjectPool(std::function<T*()> createObject) :
+		createObject(createObject)
+	{}
+
+	virtual ~ObjectPool() {}
+
+	ptr_type acquire() {
+		std::lock_guard<std::mutex> lock(poolMutex);
+
+		if (pool.empty()) {
+			pool.push(std::unique_ptr<T>(createObject()));
+		}
+
+		ptr_type tmp(pool.top().release(), [this](T* p) {
+			std::lock_guard<std::mutex> lock(poolMutex);
+			this->pool.push(std::unique_ptr<T>(p));
+		});
+		pool.pop();
+		return std::move(tmp);
+	}
+
+	bool empty() const {
+		std::lock_guard<std::mutex> lock(poolMutex);
+		return pool.empty();
+	}
+
+	size_t size() const {
+		std::lock_guard<std::mutex> lock(poolMutex);
+		return pool.size();
+	}
+
+private:
+	std::function<T*()> createObject;
+	std::stack<std::unique_ptr<T>> pool;
+	mutable std::mutex poolMutex;
+};
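ObjectPool hands out objects through acquire(); the custom deleter baked into ptr_type pushes each object back onto the stack when its pointer goes out of scope, so callers never return objects explicitly, and the pool only grows to the maximum number of simultaneous users. A minimal usage sketch — the ExpensiveResource type and nextId counter are illustrative, not part of this commit:

#include <iostream>
#include "ObjectPool.h"

// Illustrative resource type; stands in for something costly to create,
// such as a PocketSphinx decoder.
struct ExpensiveResource {
	int id;
};

int main() {
	int nextId = 0;
	ObjectPool<ExpensiveResource> pool([&nextId] {
		return new ExpensiveResource{ nextId++ };
	});

	{
		auto resource = pool.acquire(); // pool empty -> factory creates object #0
		std::cout << "Using resource " << resource->id << "\n";
	} // custom deleter runs here: object #0 returns to the pool, not deleted

	auto again = pool.acquire(); // reuses object #0; no second allocation
	std::cout << "Reused resource " << again->id << "\n";
	std::cout << "Idle objects in pool: " << pool.size() << "\n"; // prints 0
}

Because acquire() and the deleter both lock poolMutex, multiple threads can safely share one pool, which is exactly how detectPhones() uses it further down in this commit.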
src/phoneExtraction.cpp
@@ -17,6 +17,8 @@
 #include "g2p.h"
 #include "ContinuousTimeline.h"
 #include "audio/processing.h"
+#include "ThreadPool.h"
+#include "ObjectPool.h"
 
 extern "C" {
 #include <pocketsphinx.h>
@@ -48,28 +50,6 @@ const path& getSphinxModelDirectory() {
 	return sphinxModelDirectory;
 }
 
-lambda_unique_ptr<ps_decoder_t> createDecoder() {
-	lambda_unique_ptr<cmd_ln_t> config(
-		cmd_ln_init(
-			nullptr, ps_args(), true,
-			// Set acoustic model
-			"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
-			// Set pronunciation dictionary
-			"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
-			// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
-			"-dither", "yes",
-			nullptr),
-		[](cmd_ln_t* config) { cmd_ln_free_r(config); });
-	if (!config) throw runtime_error("Error creating configuration.");
-
-	lambda_unique_ptr<ps_decoder_t> recognizer(
-		ps_init(config.get()),
-		[](ps_decoder_t* recognizer) { ps_free(recognizer); });
-	if (!recognizer) throw runtime_error("Error creating speech decoder.");
-
-	return recognizer;
-}
-
 logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
 	switch (errorLevel) {
 	case ERR_DEBUG:
@@ -254,6 +234,48 @@ void addMissingDictionaryWords(const vector<string>& words, ps_decoder_t& decode
 	}
 }
 
+lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
+	lambda_unique_ptr<cmd_ln_t> config(
+		cmd_ln_init(
+			nullptr, ps_args(), true,
+			// Set acoustic model
+			"-hmm", (getSphinxModelDirectory() / "acoustic-model").string().c_str(),
+			// Set pronunciation dictionary
+			"-dict", (getSphinxModelDirectory() / "cmudict-en-us.dict").string().c_str(),
+			// Add noise against zero silence (see http://cmusphinx.sourceforge.net/wiki/faq#qwhy_my_accuracy_is_poor)
+			"-dither", "yes",
+			nullptr),
+		[](cmd_ln_t* config) { cmd_ln_free_r(config); });
+	if (!config) throw runtime_error("Error creating configuration.");
+
+	lambda_unique_ptr<ps_decoder_t> decoder(
+		ps_init(config.get()),
+		[](ps_decoder_t* recognizer) { ps_free(recognizer); });
+	if (!decoder) throw runtime_error("Error creating speech decoder.");
+
+	// Set language model
+	lambda_unique_ptr<ngram_model_t> languageModel;
+	if (dialog) {
+		// Create dialog-specific language model
+		vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
+		words.insert(words.begin(), "<s>");
+		words.push_back("</s>");
+		languageModel = createLanguageModel(words, *decoder->lmath);
+
+		// Add any dialog-specific words to the dictionary
+		addMissingDictionaryWords(words, *decoder);
+	} else {
+		path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
+		languageModel = lambda_unique_ptr<ngram_model_t>(
+			ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
+			[](ngram_model_t* lm) { ngram_model_free(lm); });
+	}
+	ps_set_lm(decoder.get(), "lm", languageModel.get());
+	ps_set_search(decoder.get(), "lm");
+
+	return decoder;
+}
+
 BoundedTimeline<Phone> detectPhones(
 	unique_ptr<AudioStream> audioStream,
 	optional<u32string> dialog,
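With the language-model setup folded into createDecoder(), a single call now yields a fully configured, self-contained decoder. That is what allows the next hunk to give each concurrently processed utterance its own decoder rather than sharing one stateful PocketSphinx instance. A distilled sketch of that concurrency pattern, using plain std::thread and placeholder types instead of the commit's ThreadPool, ObjectPool, and PocketSphinx types:

#include <mutex>
#include <thread>
#include <vector>

// Placeholder for a stateful, non-shareable worker object such as ps_decoder_t.
struct Decoder {
	int process(int utterance) { return utterance * 2; }
};

int main() {
	std::vector<int> utterances = { 1, 2, 3, 4 };
	std::vector<int> results;
	std::mutex resultMutex;

	// One thread and one decoder per utterance. The commit caps resource
	// usage instead by drawing threads from ThreadPool and decoders from
	// ObjectPool, but the synchronization structure is the same.
	std::vector<std::thread> threads;
	for (int utterance : utterances) {
		threads.emplace_back([&results, &resultMutex, utterance] {
			Decoder decoder; // each job owns its decoder; no shared state
			int phones = decoder.process(utterance);
			std::lock_guard<std::mutex> lock(resultMutex);
			results.push_back(phones); // only the result write is serialized
		});
	}
	for (std::thread& thread : threads) thread.join();
}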
@@ -284,68 +306,53 @@ BoundedTimeline<Phone> detectPhones(
 	}
 	auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
 
-	// Create speech recognizer
-	auto decoder = createDecoder();
-
-	// Set language model
-	lambda_unique_ptr<ngram_model_t> languageModel;
-	if (dialog) {
-		// Create dialog-specific language model
-		vector<string> words = tokenizeText(*dialog, [&](const string& word) { return dictionaryContains(*decoder->dict, word); });
-		words.insert(words.begin(), "<s>");
-		words.push_back("</s>");
-		languageModel = createLanguageModel(words, *decoder->lmath);
-
-		// Add any dialog-specific words to the dictionary
-		addMissingDictionaryWords(words, *decoder);
-	} else {
-		path modelPath = getSphinxModelDirectory() / "en-us.lm.bin";
-		languageModel = lambda_unique_ptr<ngram_model_t>(
-			ngram_model_read(decoder->config, modelPath.string().c_str(), NGRAM_AUTO, decoder->lmath),
-			[](ngram_model_t* lm) { ngram_model_free(lm); });
-	}
-	ps_set_lm(decoder.get(), "lm", languageModel.get());
-	ps_set_search(decoder.get(), "lm");
-
 	BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+	std::mutex resultMutex;
 	logging::debug("Speech recognition -- start");
+	ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
+	ThreadPool threadPool;
 	for (const auto& timedUtterance : utterances) {
-		ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
-		ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
-		ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
+		threadPool.addJob([&] {
+			auto decoder = decoderPool.acquire();
+			ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
+			ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
+			ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
 
 			const TimeRange timeRange = timedUtterance.getTimeRange();
 			logging::logTimedEvent("utterance", timeRange, string(""));
 
 			auto streamSegment = createSegment(audioStream->clone(true), timeRange);
 
 			// Get words
 			BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
 			for (Timed<string> timedWord : words) {
 				timedWord.getTimeRange().shift(timedUtterance.getStart());
 				logging::logTimedEvent("word", timedWord);
 			}
 
 			// Look up words in dictionary
 			vector<s3wid_t> wordIds;
 			for (const auto& timedWord : words) {
 				wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
 			}
-			if (wordIds.empty()) continue;
+			if (wordIds.empty()) return;
 
 			// Align the words' phones with speech
 			BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
 				.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
 			segmentPhones.shift(timedUtterance.getStart());
 			for (const auto& timedPhone : segmentPhones) {
 				logging::logTimedEvent("phone", timedPhone);
 			}
 
 			// Fill result
-		for (const auto& timedPhone : segmentPhones) {
-			result.set(timedPhone);
-		}
+			std::lock_guard<std::mutex> lock(resultMutex);
+			for (const auto& timedPhone : segmentPhones) {
+				result.set(timedPhone);
+			}
+		});
 	}
+	threadPool.waitAll();
 	logging::debug("Speech recognition -- end");
 
 	return result;
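ThreadPool.cpp and ThreadPool.h are added by this commit but do not appear in this excerpt. Judging from the call sites above — default construction, addJob() taking a nullary callable, waitAll() blocking until every job has finished — a minimal implementation consistent with that interface could look like the following sketch. This is an assumption; the commit's actual ThreadPool may differ in details such as thread count and shutdown handling:

#pragma once
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class ThreadPool {
public:
	ThreadPool(size_t threadCount = std::thread::hardware_concurrency()) {
		if (threadCount == 0) threadCount = 1; // hardware_concurrency may return 0
		for (size_t i = 0; i < threadCount; ++i) {
			workers.emplace_back([this] { workerLoop(); });
		}
	}

	~ThreadPool() {
		{
			std::lock_guard<std::mutex> lock(queueMutex);
			stopping = true;
		}
		jobAvailable.notify_all();
		for (std::thread& worker : workers) worker.join();
	}

	// Enqueues a job for asynchronous execution on one of the worker threads.
	void addJob(std::function<void()> job) {
		{
			std::lock_guard<std::mutex> lock(queueMutex);
			jobs.push(std::move(job));
			++unfinishedJobs;
		}
		jobAvailable.notify_one();
	}

	// Blocks until every job added so far has run to completion.
	void waitAll() {
		std::unique_lock<std::mutex> lock(queueMutex);
		allDone.wait(lock, [this] { return unfinishedJobs == 0; });
	}

private:
	void workerLoop() {
		for (;;) {
			std::function<void()> job;
			{
				std::unique_lock<std::mutex> lock(queueMutex);
				jobAvailable.wait(lock, [this] { return stopping || !jobs.empty(); });
				if (stopping && jobs.empty()) return;
				job = std::move(jobs.front());
				jobs.pop();
			}
			job();
			{
				std::lock_guard<std::mutex> lock(queueMutex);
				if (--unfinishedJobs == 0) allDone.notify_all();
			}
		}
	}

	std::vector<std::thread> workers;
	std::queue<std::function<void()>> jobs;
	std::mutex queueMutex;
	std::condition_variable jobAvailable;
	std::condition_variable allDone;
	bool stopping = false;
	size_t unfinishedJobs = 0;
};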