diff --git a/src/audio/voiceActivityDetection.cpp b/src/audio/voiceActivityDetection.cpp index 59324ef..b121071 100644 --- a/src/audio/voiceActivityDetection.cpp +++ b/src/audio/voiceActivityDetection.cpp @@ -7,16 +7,15 @@ #include #include "processing.h" #include +#include +#include "AudioStreamSegment.h" using std::vector; using boost::adaptors::transformed; using fmt::format; using std::runtime_error; -BoundedTimeline detectVoiceActivity(std::unique_ptr audioStream, ProgressSink& progressSink) { - // Prepare audio for VAD - audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000)); - +BoundedTimeline webRtcDetectVoiceActivity(AudioStream& audioStream, ProgressSink& progressSink) { VadInst* vadHandle = WebRtcVad_Create(); if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle."); @@ -30,17 +29,48 @@ BoundedTimeline detectVoiceActivity(std::unique_ptr audioStre if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness."); // Detect activity - BoundedTimeline activity(audioStream->getTruncatedRange()); + BoundedTimeline activity(audioStream.getTruncatedRange()); centiseconds time = centiseconds::zero(); auto processBuffer = [&](const vector& buffer) { - bool isActive = WebRtcVad_Process(vadHandle, audioStream->getSampleRate(), buffer.data(), buffer.size()) == 1; + bool isActive = WebRtcVad_Process(vadHandle, audioStream.getSampleRate(), buffer.data(), buffer.size()) == 1; if (isActive) { activity.set(time, time + centiseconds(1)); } time += centiseconds(1); }; - const size_t bufferCapacity = audioStream->getSampleRate() / 100; - process16bitAudioStream(*audioStream.get(), processBuffer, bufferCapacity, progressSink); + const size_t bufferCapacity = audioStream.getSampleRate() / 100; + process16bitAudioStream(audioStream, processBuffer, bufferCapacity, progressSink); + + return activity; +} + +BoundedTimeline detectVoiceActivity(std::unique_ptr audioStream, ProgressSink& progressSink) { + // Prepare audio for VAD + audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000)); + + BoundedTimeline activity(audioStream->getTruncatedRange()); + std::mutex activityMutex; + + // Split audio into segments and perform parallel VAD + ThreadPool threadPool; + int segmentCount = threadPool.getThreadCount(); + centiseconds audioLength = audioStream->getTruncatedRange().getLength(); + ProgressMerger progressMerger(progressSink); + for (int i = 0; i < segmentCount; ++i) { + TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount); + ProgressSink& segmentProgressSink = progressMerger.addSink(1.0); + threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] { + std::unique_ptr audioSegment = createSegment(audioStream->clone(false), segmentRange); + BoundedTimeline activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink); + + std::lock_guard lock(activityMutex); + for (auto activityRange : activitySegment) { + activityRange.getTimeRange().shift(segmentRange.getStart()); + activity.set(activityRange); + } + }); + } + threadPool.waitAll(); // Fill small gaps in activity const centiseconds maxGap(5);