Multithreaded recognition: refactoring and fixes
* Decoders are correctly released after use
* Determining optimal thread count for multithreading
parent f13449f810
commit 2d314f4bc7
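Both points from the commit message show up in the phone-recognition changes at the bottom of this diff. As a stand-alone illustration of the thread-count heuristic introduced there (the function and parameter names below are hypothetical, not part of the commit):

```cpp
#include <algorithm>
#include <chrono>

// Illustrative stand-alone version of the heuristic used in detectPhones below.
int chooseThreadCount(int cpuCoreCount, int utteranceCount, std::chrono::seconds audioLength) {
	return std::min({
		// Don't use more threads than there are CPU cores
		cpuCoreCount,
		// Don't use more threads than there are utterances to be processed
		utteranceCount,
		// Don't spend time creating threads (and decoders) that a short
		// recording cannot keep busy: one thread per 10 seconds of audio
		static_cast<int>(audioLength.count() / 10)
	});
}
```

Note that, exactly as in the diff, the last bound evaluates to zero for recordings under ten seconds, so a caller may want to clamp the result to at least 1.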
@@ -11,6 +11,10 @@ public:
 	using typename Timeline<T>::iterator;
 	using Timeline<T>::end;
 
+	BoundedTimeline() :
+		range(TimeRange::zero())
+	{}
+
 	explicit BoundedTimeline(TimeRange range) :
 		range(range)
 	{}
@@ -32,7 +32,7 @@ ThreadPool::~ThreadPool() {
 	}
 }
 
-void ThreadPool::addJob(job_t job) {
+void ThreadPool::schedule(job_t job) {
 	std::lock_guard<std::mutex> guard(queueMutex);
 	jobQueue.emplace_back(job);
 	++remainingJobCount;
@@ -6,6 +6,7 @@
 #include <list>
 #include <functional>
 #include <vector>
+#include "progressBar.h"
 
 // Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.
 
@@ -32,7 +33,15 @@ public:
 	// Adds a new job to the pool.
 	// If there are no queued jobs, a thread is woken up to take the job.
 	// If all threads are busy, the job is added to the end of the queue.
-	void addJob(job_t job);
+	void schedule(job_t job);
 
+	// Asynchronously runs a function for every element of a collection.
+	template<typename TCollection>
+	void schedule(
+		TCollection& collection,
+		std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
+		ProgressSink& progressSink,
+		std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });
+
 	// Blocks until all jobs have finshed executing
 	void waitAll();
@@ -57,3 +66,21 @@ private:
 	job_t getNextJob();
 };
 
+template <typename TCollection>
+void ThreadPool::schedule(
+	TCollection& collection,
+	std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
+	ProgressSink& progressSink,
+	std::function<double(const typename TCollection::reference)> getElementProgressWeight)
+{
+	// Use shared pointer to keep progress merger alive throughout execution
+	auto progressMerger = std::make_shared<ProgressMerger>(progressSink);
+
+	// Schedule all elements
+	for (auto& element : collection) {
+		ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
+		schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
+			processElement(element, elementProgressSink);
+		});
+	}
+}
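For context, a hypothetical call site for this new collection overload. Everything not shown in the diff is an assumption here: `loadLines`, `processLine`, and `overallProgressSink` are stand-ins, and `ProgressSink` is assumed to expose a `reportProgress(double)` member.

```cpp
// Hypothetical usage: process every element of a vector in parallel,
// weighting each element's share of total progress by its length.
std::vector<std::string> lines = loadLines();  // stand-in input
ThreadPool pool(4);
pool.schedule(
	lines,
	[](std::string& line, ProgressSink& sink) {
		processLine(line);        // stand-in for real per-element work
		sink.reportProgress(1.0); // assumed ProgressSink interface
	},
	overallProgressSink,          // ProgressSink supplied by the caller
	[](const std::string& line) { return static_cast<double>(line.size()); });
pool.waitAll();
```

The `progressMerger` shared pointer captured by value in the definition above is what keeps per-element progress reporting alive until the last scheduled job has finished.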
@@ -52,8 +52,9 @@ public:
 	using reverse_iterator = typename set_type::reverse_iterator;
 	using size_type = size_t;
 	using value_type = Timed<T>;
+	using reference = const value_type&;
 
-	class reference {
+	class ReferenceWrapper {
 	public:
 		operator boost::optional<const T&>() const {
 			auto optional = timeline.get(time);
@@ -71,7 +72,7 @@ public:
 			return optional->getValue();
 		}
 
-		reference& operator=(boost::optional<const T&> value) {
+		ReferenceWrapper& operator=(boost::optional<const T&> value) {
 			if (value) {
 				timeline.set(time, time + time_type(1), *value);
 			} else {
@@ -83,7 +84,7 @@ public:
 	private:
 		friend class Timeline;
 
-		reference(Timeline& timeline, time_type time) :
+		ReferenceWrapper(Timeline& timeline, time_type time) :
 			timeline(timeline),
 			time(time)
 		{}
@@ -232,13 +233,13 @@ public:
 		return set(Timed<void>(start, end));
 	}
 
-	reference operator[](time_type time) {
-		return reference(*this, time);
+	ReferenceWrapper operator[](time_type time) {
+		return ReferenceWrapper(*this, time);
 	}
 
 	// ReSharper disable once CppConstValueFunctionReturnType
-	const reference operator[](time_type time) const {
-		return reference(*this, time);
+	const ReferenceWrapper operator[](time_type time) const {
+		return ReferenceWrapper(*this, time);
 	}
 
 	virtual void shift(time_type offset) {
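The rename from `reference` to `ReferenceWrapper` frees the `reference` typedef to be a plain `const value_type&`, which appears to be what generic code such as the new `ThreadPool::schedule(TCollection&, ...)` expects from `TCollection::reference` when it is later called with a timeline of utterances. The proxy itself keeps its read/write behavior through `operator[]`. A minimal sketch of that behavior, assuming `time_type` is `centiseconds` (an assumption; the typedef is outside this diff):

```cpp
// Hypothetical usage of the proxy returned by Timeline::operator[].
Timeline<int> timeline;
int value = 42;
timeline[centiseconds(5)] = value;  // ReferenceWrapper::operator= sets the range [5, 6)
boost::optional<const int&> read = timeline[centiseconds(5)];  // implicit conversion reads back
if (read) {
	// *read == 42 at this point
}
```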
@@ -59,7 +59,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream
 	for (int i = 0; i < segmentCount; ++i) {
 		TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
 		ProgressSink& segmentProgressSink = progressMerger.addSink(1.0);
-		threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
+		threadPool.schedule([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
 			std::unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
 			BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
@@ -42,6 +42,7 @@ using std::regex_replace;
 using std::chrono::duration;
 using boost::optional;
 using std::u32string;
+using std::chrono::duration_cast;
 
 constexpr int sphinxSampleRate = 16000;
 
@@ -134,7 +135,7 @@ s3wid_t getWordId(const string& word, dict_t& dictionary) {
 	return wordId;
 }
 
-optional<BoundedTimeline<Phone>> getPhoneAlignment(
+optional<Timeline<Phone>> getPhoneAlignment(
 	const vector<s3wid_t>& wordIds,
 	unique_ptr<AudioStream> audioStream,
 	ps_decoder_t& decoder,
@@ -193,7 +194,7 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
 
 	// Extract phones with timestamps
 	char** phoneNames = decoder.dict->mdef->ciname;
-	BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+	Timeline<Phone> result;
 	for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) {
 		// Get phone
 		ps_alignment_entry_t* phoneEntry = ps_alignment_iter_get(it);
@@ -276,82 +277,128 @@ lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
 	return decoder;
 }
 
+Timeline<Phone> utteranceToPhones(
+	AudioStream& audioStream,
+	TimeRange utterance,
+	ps_decoder_t& decoder,
+	ProgressSink& utteranceProgressSink)
+{
+	ProgressMerger utteranceProgressMerger(utteranceProgressSink);
+	ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
+	ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
+
+	auto streamSegment = createSegment(audioStream.clone(true), utterance);
+
+	// Get words
+	BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), decoder, wordRecognitionProgressSink);
+	for (Timed<string> timedWord : words) {
+		timedWord.getTimeRange().shift(utterance.getStart());
+		logging::logTimedEvent("word", timedWord);
+	}
+
+	// Look up words in dictionary
+	vector<s3wid_t> wordIds;
+	for (const auto& timedWord : words) {
+		wordIds.push_back(getWordId(timedWord.getValue(), *decoder.dict));
+	}
+	if (wordIds.empty()) return Timeline<Phone>();
+
+	// Align the words' phones with speech
+	Timeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), decoder, alignmentProgressSink)
+		.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
+	segmentPhones.shift(utterance.getStart());
+	for (const auto& timedPhone : segmentPhones) {
+		logging::logTimedEvent("phone", timedPhone);
+	}
+
+	return segmentPhones;
+}
+
 BoundedTimeline<Phone> detectPhones(
 	unique_ptr<AudioStream> audioStream,
 	optional<u32string> dialog,
 	ProgressSink& progressSink)
 {
+	ProgressMerger totalProgressMerger(progressSink);
+	ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
+	ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
+
+	// Make sure audio stream has no DC offset
+	audioStream = removeDCOffset(std::move(audioStream));
+
+	// Split audio into utterances
+	BoundedTimeline<void> utterances;
+	try {
+		utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
+	}
+	catch (...) {
+		std::throw_with_nested(runtime_error("Error detecting segments of speech."));
+	}
+
 	// Discard Pocketsphinx output
 	err_set_logfp(nullptr);
 
 	// Redirect Pocketsphinx output to log
 	err_set_callback(sphinxLogCallback, nullptr);
 
-	// Make sure audio stream has no DC offset
-	audioStream = removeDCOffset(std::move(audioStream));
+	// Prepare pool of decoders
+	std::stack<lambda_unique_ptr<ps_decoder_t>> decoderPool;
+	std::mutex decoderPoolMutex;
+	auto getDecoder = [&] {
+		std::lock_guard<std::mutex> lock(decoderPoolMutex);
+		if (decoderPool.empty()) {
+			decoderPool.push(createDecoder(dialog));
+		}
+		auto decoder = std::move(decoderPool.top());
+		decoderPool.pop();
+		return std::move(decoder);
+	};
+	auto returnDecoder = [&](lambda_unique_ptr<ps_decoder_t> decoder) {
+		std::lock_guard<std::mutex> lock(decoderPoolMutex);
+		decoderPool.push(std::move(decoder));
+	};
 
-	ProgressMerger totalProgressMerger(progressSink);
-	ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
-	ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
+	BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
+	std::mutex resultMutex, audioStreamMutex;
+	auto processUtterance = [&](Timed<void> timedUtterance, ProgressSink& utteranceProgressSink) {
+		logging::logTimedEvent("utterance", timedUtterance.getTimeRange(), string(""));
+
+		// Detect phones for utterance
+		auto decoder = getDecoder();
+		std::unique_ptr<AudioStream> audioStreamCopy;
+		{
+			std::lock_guard<std::mutex> lock(audioStreamMutex);
+			audioStreamCopy = audioStream->clone(true);
+		}
+		Timeline<Phone> phones =
+			utteranceToPhones(*audioStreamCopy, timedUtterance.getTimeRange(), *decoder, utteranceProgressSink);
+		returnDecoder(std::move(decoder));
+
+		// Copy phones to result timeline
+		std::lock_guard<std::mutex> lock(resultMutex);
+		for (const auto& timedPhone : phones) {
+			result.set(timedPhone);
+		}
+	};
+
+	auto getUtteranceProgressWeight = [](const Timed<void> timedUtterance) {
+		return timedUtterance.getTimeRange().getLength().count();
+	};
+
+	// Perform speech recognition
 	try {
-		// Split audio into utterances
-		BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
-
-		// For progress reporting: weigh utterances by length
-		ProgressMerger dialogProgressMerger(dialogProgressSink);
-		vector<ProgressSink*> utteranceProgressSinks;
-		for (const auto& timedUtterance : utterances) {
-			utteranceProgressSinks.push_back(&dialogProgressMerger.addSink(timedUtterance.getTimeRange().getLength().count()));
-		}
-		auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
-
-		BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
-		std::mutex resultMutex;
+		// Determine how many parallel threads to use
+		int threadCount = std::min({
+			// Don't use more threads than there are CPU cores
+			ThreadPool::getRecommendedThreadCount(),
+			// Don't use more threads than there are utterances to be processed
+			static_cast<int>(utterances.size()),
+			// Don't waste time creating additional threads (and decoders!) if the recording is short
+			static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
+		});
+		ThreadPool threadPool(threadCount);
+
 		logging::debug("Speech recognition -- start");
-		ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
-		ThreadPool threadPool;
-		for (const auto& timedUtterance : utterances) {
-			threadPool.addJob([&] {
-				auto decoder = decoderPool.acquire();
-				ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
-				ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
-				ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
-
-				const TimeRange timeRange = timedUtterance.getTimeRange();
-				logging::logTimedEvent("utterance", timeRange, string(""));
-
-				auto streamSegment = createSegment(audioStream->clone(true), timeRange);
-
-				// Get words
-				BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
-				for (Timed<string> timedWord : words) {
-					timedWord.getTimeRange().shift(timedUtterance.getStart());
-					logging::logTimedEvent("word", timedWord);
-				}
-
-				// Look up words in dictionary
-				vector<s3wid_t> wordIds;
-				for (const auto& timedWord : words) {
-					wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
-				}
-				if (wordIds.empty()) return;
-
-				// Align the words' phones with speech
-				BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
-					.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
-				segmentPhones.shift(timedUtterance.getStart());
-				for (const auto& timedPhone : segmentPhones) {
-					logging::logTimedEvent("phone", timedPhone);
-				}
-
-				// Fill result
-				std::lock_guard<std::mutex> lock(resultMutex);
-				for (const auto& timedPhone : segmentPhones) {
-					result.set(timedPhone);
-				}
-			});
-		}
+		threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
 		threadPool.waitAll();
 		logging::debug("Speech recognition -- end");
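The `getDecoder`/`returnDecoder` pair above implements a small lazy object pool: decoders are created on demand, handed to exactly one job at a time, and pushed back for reuse rather than destroyed and re-created per utterance. The same pattern, distilled into a self-contained class for illustration (`LazyPool` and `Resource` are hypothetical names; the commit itself uses two lambdas over a `std::stack`):

```cpp
#include <functional>
#include <memory>
#include <mutex>
#include <stack>

// Illustrative generic version of the decoder pool used in detectPhones.
template <typename Resource>
class LazyPool {
public:
	explicit LazyPool(std::function<std::unique_ptr<Resource>()> create) :
		create(std::move(create)) {}

	// Hands out a resource, creating one only when the pool is empty,
	// so at most one resource exists per concurrently running job.
	std::unique_ptr<Resource> acquire() {
		std::lock_guard<std::mutex> lock(mutex);
		if (pool.empty()) return create();
		std::unique_ptr<Resource> resource = std::move(pool.top());
		pool.pop();
		return resource;
	}

	// Returns a resource to the pool for reuse by the next job.
	void release(std::unique_ptr<Resource> resource) {
		std::lock_guard<std::mutex> lock(mutex);
		pool.push(std::move(resource));
	}

private:
	std::function<std::unique_ptr<Resource>()> create;
	std::stack<std::unique_ptr<Resource>> pool;
	std::mutex mutex;
};
```

Because every job pairs an acquire with a release, all decoders end up back in the pool and are freed when it goes out of scope, which appears to be the "Decoders are correctly released after use" point from the commit message.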