Sped up VAD via multithreading
This commit is contained in:
parent
84097756c8
commit
3a0a38575f
|
@ -7,16 +7,15 @@
|
||||||
#include <webrtc/common_audio/vad/include/webrtc_vad.h>
|
#include <webrtc/common_audio/vad/include/webrtc_vad.h>
|
||||||
#include "processing.h"
|
#include "processing.h"
|
||||||
#include <gsl_util.h>
|
#include <gsl_util.h>
|
||||||
|
#include <ThreadPool.h>
|
||||||
|
#include "AudioStreamSegment.h"
|
||||||
|
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using boost::adaptors::transformed;
|
using boost::adaptors::transformed;
|
||||||
using fmt::format;
|
using fmt::format;
|
||||||
using std::runtime_error;
|
using std::runtime_error;
|
||||||
|
|
||||||
BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
|
BoundedTimeline<void> webRtcDetectVoiceActivity(AudioStream& audioStream, ProgressSink& progressSink) {
|
||||||
// Prepare audio for VAD
|
|
||||||
audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
|
|
||||||
|
|
||||||
VadInst* vadHandle = WebRtcVad_Create();
|
VadInst* vadHandle = WebRtcVad_Create();
|
||||||
if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
|
if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
|
||||||
|
|
||||||
|
@ -30,17 +29,48 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
|
||||||
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
||||||
|
|
||||||
// Detect activity
|
// Detect activity
|
||||||
BoundedTimeline<void> activity(audioStream->getTruncatedRange());
|
BoundedTimeline<void> activity(audioStream.getTruncatedRange());
|
||||||
centiseconds time = centiseconds::zero();
|
centiseconds time = centiseconds::zero();
|
||||||
auto processBuffer = [&](const vector<int16_t>& buffer) {
|
auto processBuffer = [&](const vector<int16_t>& buffer) {
|
||||||
bool isActive = WebRtcVad_Process(vadHandle, audioStream->getSampleRate(), buffer.data(), buffer.size()) == 1;
|
bool isActive = WebRtcVad_Process(vadHandle, audioStream.getSampleRate(), buffer.data(), buffer.size()) == 1;
|
||||||
if (isActive) {
|
if (isActive) {
|
||||||
activity.set(time, time + centiseconds(1));
|
activity.set(time, time + centiseconds(1));
|
||||||
}
|
}
|
||||||
time += centiseconds(1);
|
time += centiseconds(1);
|
||||||
};
|
};
|
||||||
const size_t bufferCapacity = audioStream->getSampleRate() / 100;
|
const size_t bufferCapacity = audioStream.getSampleRate() / 100;
|
||||||
process16bitAudioStream(*audioStream.get(), processBuffer, bufferCapacity, progressSink);
|
process16bitAudioStream(audioStream, processBuffer, bufferCapacity, progressSink);
|
||||||
|
|
||||||
|
return activity;
|
||||||
|
}
|
||||||
|
|
||||||
|
BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
|
||||||
|
// Prepare audio for VAD
|
||||||
|
audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
|
||||||
|
|
||||||
|
BoundedTimeline<void> activity(audioStream->getTruncatedRange());
|
||||||
|
std::mutex activityMutex;
|
||||||
|
|
||||||
|
// Split audio into segments and perform parallel VAD
|
||||||
|
ThreadPool threadPool;
|
||||||
|
int segmentCount = threadPool.getThreadCount();
|
||||||
|
centiseconds audioLength = audioStream->getTruncatedRange().getLength();
|
||||||
|
ProgressMerger progressMerger(progressSink);
|
||||||
|
for (int i = 0; i < segmentCount; ++i) {
|
||||||
|
TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
|
||||||
|
ProgressSink& segmentProgressSink = progressMerger.addSink(1.0);
|
||||||
|
threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
|
||||||
|
std::unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
|
||||||
|
BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(activityMutex);
|
||||||
|
for (auto activityRange : activitySegment) {
|
||||||
|
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||||
|
activity.set(activityRange);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
threadPool.waitAll();
|
||||||
|
|
||||||
// Fill small gaps in activity
|
// Fill small gaps in activity
|
||||||
const centiseconds maxGap(5);
|
const centiseconds maxGap(5);
|
||||||
|
|
Loading…
Reference in New Issue