Sped up VAD via multithreading

Daniel Wolf 2016-06-26 14:02:48 +02:00
parent 84097756c8
commit 3a0a38575f
1 changed file with 38 additions and 8 deletions
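The diff below splits voice activity detection across worker threads: the existing WebRTC loop is extracted into webRtcDetectVoiceActivity(AudioStream&, ProgressSink&), and detectVoiceActivity now cuts the audio into one segment per thread, runs VAD on each segment through a thread pool, and merges the per-segment timelines into a shared result under a mutex. The standalone sketch below illustrates only that split/merge pattern using std::thread and std::mutex; the Range type and analyzeSegment function are hypothetical placeholders, whereas the commit itself relies on the project's ThreadPool, AudioStreamSegment, and ProgressMerger classes.

#include <mutex>
#include <thread>
#include <vector>

struct Range { int start, end; };                      // hypothetical stand-in for TimeRange

// Hypothetical per-segment analysis; here it just echoes its input range.
static Range analyzeSegment(Range range) { return range; }

std::vector<Range> analyzeInParallel(int totalLength, int workerCount) {
    std::vector<Range> results;
    std::mutex resultsMutex;
    std::vector<std::thread> workers;
    for (int i = 0; i < workerCount; ++i) {
        // Split the total range into equal segments, one per worker.
        Range segment { i * totalLength / workerCount, (i + 1) * totalLength / workerCount };
        workers.emplace_back([&results, &resultsMutex, segment] {
            Range result = analyzeSegment(segment);
            // Merge into the shared result under a lock, as the commit does with activityMutex.
            std::lock_guard<std::mutex> lock(resultsMutex);
            results.push_back(result);
        });
    }
    for (std::thread& worker : workers) worker.join();
    return results;
}

Merging under the lock is cheap in this scheme because each worker produces one small per-segment result after doing the bulk of its analysis without holding the lock.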


@@ -7,16 +7,15 @@
 #include <webrtc/common_audio/vad/include/webrtc_vad.h>
 #include "processing.h"
 #include <gsl_util.h>
+#include <ThreadPool.h>
+#include "AudioStreamSegment.h"
 using std::vector;
 using boost::adaptors::transformed;
 using fmt::format;
 using std::runtime_error;
-BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
-    // Prepare audio for VAD
-    audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
+BoundedTimeline<void> webRtcDetectVoiceActivity(AudioStream& audioStream, ProgressSink& progressSink) {
     VadInst* vadHandle = WebRtcVad_Create();
     if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
@@ -30,17 +29,48 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
     if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
     // Detect activity
-    BoundedTimeline<void> activity(audioStream->getTruncatedRange());
+    BoundedTimeline<void> activity(audioStream.getTruncatedRange());
     centiseconds time = centiseconds::zero();
     auto processBuffer = [&](const vector<int16_t>& buffer) {
-        bool isActive = WebRtcVad_Process(vadHandle, audioStream->getSampleRate(), buffer.data(), buffer.size()) == 1;
+        bool isActive = WebRtcVad_Process(vadHandle, audioStream.getSampleRate(), buffer.data(), buffer.size()) == 1;
         if (isActive) {
             activity.set(time, time + centiseconds(1));
         }
         time += centiseconds(1);
     };
-    const size_t bufferCapacity = audioStream->getSampleRate() / 100;
-    process16bitAudioStream(*audioStream.get(), processBuffer, bufferCapacity, progressSink);
+    const size_t bufferCapacity = audioStream.getSampleRate() / 100;
+    process16bitAudioStream(audioStream, processBuffer, bufferCapacity, progressSink);
+    return activity;
+}
+
+BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
+    // Prepare audio for VAD
+    audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
+    BoundedTimeline<void> activity(audioStream->getTruncatedRange());
+    std::mutex activityMutex;
+    // Split audio into segments and perform parallel VAD
+    ThreadPool threadPool;
+    int segmentCount = threadPool.getThreadCount();
+    centiseconds audioLength = audioStream->getTruncatedRange().getLength();
+    ProgressMerger progressMerger(progressSink);
+    for (int i = 0; i < segmentCount; ++i) {
+        TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
+        ProgressSink& segmentProgressSink = progressMerger.addSink(1.0);
+        threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
+            std::unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
+            BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
+            std::lock_guard<std::mutex> lock(activityMutex);
+            for (auto activityRange : activitySegment) {
+                activityRange.getTimeRange().shift(segmentRange.getStart());
+                activity.set(activityRange);
+            }
+        });
+    }
+    threadPool.waitAll();
     // Fill small gaps in activity
     const centiseconds maxGap(5);