Multithreaded recognition: refactoring and fixes

* Decoders are correctly released after use
* Optimal thread count for multithreading is now determined automatically
Daniel Wolf 2016-06-29 21:44:17 +02:00
parent f13449f810
commit 2d314f4bc7
6 changed files with 152 additions and 73 deletions
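
The thread-count change boils down to a clamped minimum; the full context is in the recognition diff below. A condensed sketch of the same heuristic (determineThreadCount is an illustrative name; the committed code calls ThreadPool::getRecommendedThreadCount() for the core count):

#include <algorithm>
#include <chrono>

// Clamp the worker count by three independent upper bounds.
int determineThreadCount(int cpuCores, int utteranceCount, std::chrono::seconds audioLength) {
    return std::min({
        cpuCores,                                   // no more threads than CPU cores
        utteranceCount,                             // no more threads than utterances to process
        static_cast<int>(audioLength.count() / 10)  // short recordings don't amortize extra threads (and decoders)
    });
}

// Example: 8 cores, 20 utterances, 60 seconds of audio yields min({8, 20, 6}) = 6 threads.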

View File

@@ -11,6 +11,10 @@ public:
using typename Timeline<T>::iterator;
using Timeline<T>::end;
+BoundedTimeline() :
+range(TimeRange::zero())
+{}
explicit BoundedTimeline(TimeRange range) :
range(range)
{}

View File

@@ -32,7 +32,7 @@ ThreadPool::~ThreadPool() {
}
}
-void ThreadPool::addJob(job_t job) {
+void ThreadPool::schedule(job_t job) {
std::lock_guard<std::mutex> guard(queueMutex);
jobQueue.emplace_back(job);
++remainingJobCount;

View File

@@ -6,6 +6,7 @@
#include <list>
#include <functional>
#include <vector>
+#include "progressBar.h"
// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.
@@ -32,7 +33,15 @@ public:
// Adds a new job to the pool.
// If there are no queued jobs, a thread is woken up to take the job.
// If all threads are busy, the job is added to the end of the queue.
-void addJob(job_t job);
+void schedule(job_t job);
+// Asynchronously runs a function for every element of a collection.
+template<typename TCollection>
+void schedule(
+TCollection& collection,
+std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
+ProgressSink& progressSink,
+std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });
// Blocks until all jobs have finished executing.
void waitAll();
@@ -57,3 +66,21 @@ private:
job_t getNextJob();
};
+template <typename TCollection>
+void ThreadPool::schedule(
+TCollection& collection,
+std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
+ProgressSink& progressSink,
+std::function<double(const typename TCollection::reference)> getElementProgressWeight)
+{
+// Use shared pointer to keep progress merger alive throughout execution
+auto progressMerger = std::make_shared<ProgressMerger>(progressSink);
+// Schedule all elements
+for (auto& element : collection) {
+ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
+schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
+processElement(element, elementProgressSink);
+});
+}
+}
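
A hypothetical usage sketch of the new collection overload. The vector, the doubling job, and the reportProgress call are illustrative assumptions; only schedule and waitAll are taken from the header above:

#include <vector>

void example(ThreadPool& pool, ProgressSink& progressSink) {
    std::vector<int> numbers { 1, 2, 3, 4 };
    // Each element is processed on a pool thread; per-element progress is
    // merged into progressSink, each element carrying the default weight of 1.0.
    pool.schedule(
        numbers,
        [](int& n, ProgressSink& elementSink) {
            n *= 2;                          // process the element
            elementSink.reportProgress(1.0); // assumed ProgressSink API
        },
        progressSink);
    pool.waitAll(); // returns once every element has been processed
}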

View File

@@ -52,8 +52,9 @@ public:
using reverse_iterator = typename set_type::reverse_iterator;
using size_type = size_t;
using value_type = Timed<T>;
+using reference = const value_type&;
-class reference {
+class ReferenceWrapper {
public:
operator boost::optional<const T&>() const {
auto optional = timeline.get(time);
@@ -71,7 +72,7 @@ public:
return optional->getValue();
}
-reference& operator=(boost::optional<const T&> value) {
+ReferenceWrapper& operator=(boost::optional<const T&> value) {
if (value) {
timeline.set(time, time + time_type(1), *value);
} else {
@@ -83,7 +84,7 @@ public:
private:
friend class Timeline;
-reference(Timeline& timeline, time_type time) :
+ReferenceWrapper(Timeline& timeline, time_type time) :
timeline(timeline),
time(time)
{}
@@ -232,13 +233,13 @@ public:
return set(Timed<void>(start, end));
}
-reference operator[](time_type time) {
-return reference(*this, time);
+ReferenceWrapper operator[](time_type time) {
+return ReferenceWrapper(*this, time);
}
// ReSharper disable once CppConstValueFunctionReturnType
-const reference operator[](time_type time) const {
-return reference(*this, time);
+const ReferenceWrapper operator[](time_type time) const {
+return ReferenceWrapper(*this, time);
}
virtual void shift(time_type offset) {
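
In short, operator[] now returns a ReferenceWrapper proxy that reads or writes the timeline at a single point in time. A sketch of its use, assuming time_type is centiseconds as elsewhere in the project:

Timeline<Phone> timeline;
Phone phone = Phone::Unknown;
timeline[centiseconds(5)] = phone;       // ReferenceWrapper::operator= sets the value at time 5
timeline[centiseconds(5)] = boost::none; // ...or clears it again
boost::optional<const Phone&> value = timeline[centiseconds(5)]; // the conversion operator reads it back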

View File

@@ -59,7 +59,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream
for (int i = 0; i < segmentCount; ++i) {
TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
ProgressSink& segmentProgressSink = progressMerger.addSink(1.0);
-threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
+threadPool.schedule([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
std::unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);

View File

@@ -42,6 +42,7 @@ using std::regex_replace;
using std::chrono::duration;
using boost::optional;
using std::u32string;
+using std::chrono::duration_cast;
constexpr int sphinxSampleRate = 16000;
@@ -134,7 +135,7 @@ s3wid_t getWordId(const string& word, dict_t& dictionary) {
return wordId;
}
-optional<BoundedTimeline<Phone>> getPhoneAlignment(
+optional<Timeline<Phone>> getPhoneAlignment(
const vector<s3wid_t>& wordIds,
unique_ptr<AudioStream> audioStream,
ps_decoder_t& decoder,
@@ -193,7 +194,7 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
// Extract phones with timestamps
char** phoneNames = decoder.dict->mdef->ciname;
-BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+Timeline<Phone> result;
for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) {
// Get phone
ps_alignment_entry_t* phoneEntry = ps_alignment_iter_get(it);
@@ -276,82 +277,128 @@ lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
return decoder;
}
+Timeline<Phone> utteranceToPhones(
+AudioStream& audioStream,
+TimeRange utterance,
+ps_decoder_t& decoder,
+ProgressSink& utteranceProgressSink)
+{
+ProgressMerger utteranceProgressMerger(utteranceProgressSink);
+ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
+ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
+auto streamSegment = createSegment(audioStream.clone(true), utterance);
+// Get words
+BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), decoder, wordRecognitionProgressSink);
+for (Timed<string> timedWord : words) {
+timedWord.getTimeRange().shift(utterance.getStart());
+logging::logTimedEvent("word", timedWord);
+}
+// Look up words in dictionary
+vector<s3wid_t> wordIds;
+for (const auto& timedWord : words) {
+wordIds.push_back(getWordId(timedWord.getValue(), *decoder.dict));
+}
+if (wordIds.empty()) return Timeline<Phone>();
+// Align the words' phones with speech.
+// Read the segment range *before* streamSegment is moved into getPhoneAlignment;
+// evaluating it in the value_or() argument afterwards would be a use-after-move.
+const TimeRange segmentRange = streamSegment->getTruncatedRange();
+Timeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), decoder, alignmentProgressSink)
+.value_or(ContinuousTimeline<Phone>(segmentRange, Phone::Unknown));
+segmentPhones.shift(utterance.getStart());
+for (const auto& timedPhone : segmentPhones) {
+logging::logTimedEvent("phone", timedPhone);
+}
+return segmentPhones;
+}
BoundedTimeline<Phone> detectPhones(
unique_ptr<AudioStream> audioStream,
optional<u32string> dialog,
ProgressSink& progressSink)
{
+ProgressMerger totalProgressMerger(progressSink);
+ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
+ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
+// Make sure audio stream has no DC offset
+audioStream = removeDCOffset(std::move(audioStream));
+// Split audio into utterances
+BoundedTimeline<void> utterances;
+try {
+utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
+}
+catch (...) {
+std::throw_with_nested(runtime_error("Error detecting segments of speech."));
+}
// Discard Pocketsphinx output
err_set_logfp(nullptr);
// Redirect Pocketsphinx output to log
err_set_callback(sphinxLogCallback, nullptr);
-// Make sure audio stream has no DC offset
-audioStream = removeDCOffset(std::move(audioStream));
+// Prepare pool of decoders
+std::stack<lambda_unique_ptr<ps_decoder_t>> decoderPool;
+std::mutex decoderPoolMutex;
+auto getDecoder = [&] {
+std::lock_guard<std::mutex> lock(decoderPoolMutex);
+if (decoderPool.empty()) {
+decoderPool.push(createDecoder(dialog));
+}
+auto decoder = std::move(decoderPool.top());
+decoderPool.pop();
+return std::move(decoder);
+};
+auto returnDecoder = [&](lambda_unique_ptr<ps_decoder_t> decoder) {
+std::lock_guard<std::mutex> lock(decoderPoolMutex);
+decoderPool.push(std::move(decoder));
+};
-ProgressMerger totalProgressMerger(progressSink);
-ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
-ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
+BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+std::mutex resultMutex, audioStreamMutex;
+auto processUtterance = [&](Timed<void> timedUtterance, ProgressSink& utteranceProgressSink) {
+logging::logTimedEvent("utterance", timedUtterance.getTimeRange(), string(""));
+// Detect phones for utterance
+auto decoder = getDecoder();
+std::unique_ptr<AudioStream> audioStreamCopy;
+{
+std::lock_guard<std::mutex> lock(audioStreamMutex);
+audioStreamCopy = audioStream->clone(true);
+}
+Timeline<Phone> phones =
+utteranceToPhones(*audioStreamCopy, timedUtterance.getTimeRange(), *decoder, utteranceProgressSink);
+returnDecoder(std::move(decoder));
+// Copy phones to result timeline
+std::lock_guard<std::mutex> lock(resultMutex);
+for (const auto& timedPhone : phones) {
+result.set(timedPhone);
+}
+};
+auto getUtteranceProgressWeight = [](const Timed<void> timedUtterance) {
+return timedUtterance.getTimeRange().getLength().count();
+};
// Perform speech recognition
try {
-// Split audio into utterances
-BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
-// For progress reporting: weigh utterances by length
-ProgressMerger dialogProgressMerger(dialogProgressSink);
-vector<ProgressSink*> utteranceProgressSinks;
-for (const auto& timedUtterance : utterances) {
-utteranceProgressSinks.push_back(&dialogProgressMerger.addSink(timedUtterance.getTimeRange().getLength().count()));
-}
-auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
-BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
-std::mutex resultMutex;
+// Determine how many parallel threads to use
+int threadCount = std::min({
+// Don't use more threads than there are CPU cores
+ThreadPool::getRecommendedThreadCount(),
+// Don't use more threads than there are utterances to be processed
+static_cast<int>(utterances.size()),
+// Don't waste time creating additional threads (and decoders!) if the recording is short
+static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
+});
+ThreadPool threadPool(threadCount);
logging::debug("Speech recognition -- start");
-ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
-ThreadPool threadPool;
-for (const auto& timedUtterance : utterances) {
-threadPool.addJob([&] {
-auto decoder = decoderPool.acquire();
-ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
-ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
-ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
-const TimeRange timeRange = timedUtterance.getTimeRange();
-logging::logTimedEvent("utterance", timeRange, string(""));
-auto streamSegment = createSegment(audioStream->clone(true), timeRange);
-// Get words
-BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
-for (Timed<string> timedWord : words) {
-timedWord.getTimeRange().shift(timedUtterance.getStart());
-logging::logTimedEvent("word", timedWord);
-}
-// Look up words in dictionary
-vector<s3wid_t> wordIds;
-for (const auto& timedWord : words) {
-wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
-}
-if (wordIds.empty()) return;
-// Align the words' phones with speech
-BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
-.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
-segmentPhones.shift(timedUtterance.getStart());
-for (const auto& timedPhone : segmentPhones) {
-logging::logTimedEvent("phone", timedPhone);
-}
-// Fill result
-std::lock_guard<std::mutex> lock(resultMutex);
-for (const auto& timedPhone : segmentPhones) {
-result.set(timedPhone);
-}
-});
-}
+threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
threadPool.waitAll();
logging::debug("Speech recognition -- end");
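
lambda_unique_ptr, which the new decoder pool stores, is not part of this diff. A minimal definition consistent with its usage here would be a unique_ptr with a type-erased deleter:

#include <functional>
#include <memory>

template<typename T>
using lambda_unique_ptr = std::unique_ptr<T, std::function<void(T*)>>;

// Hypothetical construction for a Pocketsphinx decoder, assuming ps_free as the cleanup call:
// lambda_unique_ptr<ps_decoder_t> decoder(rawDecoder, [](ps_decoder_t* d) { ps_free(d); });

Because the pool stack owns these smart pointers, every decoder is either pushed back for reuse via returnDecoder or freed by its deleter when the stack is destroyed, which is what the first commit-message bullet means by decoders being correctly released after use.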