Do not use multithreading for VAD

WebRTC adapts to the audio signal. If we slice the audio clip into multiple
shorter clips, then perform VAD on them in parallel, the result may not be as
good.
This commit is contained in:
Daniel Wolf 2019-01-04 20:28:57 +01:00
parent a723942f22
commit 2bbad258c0
4 changed files with 25 additions and 84 deletions

View File

@ -212,6 +212,7 @@
<s:Boolean x:Key="/Default/UserDictionary/Words/=qwhy/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=rbegin/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=resample/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=resamples/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=retime/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=retimed/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=synth/@EntryIndexedValue">True</s:Boolean>

View File

@ -8,7 +8,7 @@
#include "processing.h"
#include <gsl_util.h>
#include "tools/parallel.h"
#include "AudioSegment.h"
#include <webrtc/common_audio/vad/vad_core.h>
using std::vector;
using boost::adaptors::transformed;
@ -16,124 +16,65 @@ using fmt::format;
using std::runtime_error;
using std::unique_ptr;
JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
const AudioClip& audioClip,
JoiningBoundedTimeline<void> detectVoiceActivity(
const AudioClip& inputAudioClip,
ProgressSink& progressSink
) {
// Prepare audio for VAD
constexpr int webRtcSamplingRate = 8000;
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
| resample(webRtcSamplingRate)
| removeDcOffset();
VadInst* vadHandle = WebRtcVad_Create();
if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
auto freeHandle = gsl::finally([&]() { WebRtcVad_Free(vadHandle); });
int error = WebRtcVad_Init(vadHandle);
if (error) throw runtime_error("Error initializing WebRTC VAD handle.");
if (error) throw runtime_error("Error initializing WebRTC VAD.");
const int aggressiveness = 2; // 0..3. The higher, the more is cut off.
error = WebRtcVad_set_mode(vadHandle, aggressiveness);
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
ProgressMerger progressMerger(progressSink);
ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
// Detect activity
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
centiseconds time = 0_cs;
const size_t bufferCapacity = audioClip.getSampleRate() / 100;
const size_t frameSize = webRtcSamplingRate / 100;
const auto processBuffer = [&](const vector<int16_t>& buffer) {
// WebRTC is picky regarding buffer size
if (buffer.size() < bufferCapacity) return;
if (buffer.size() < frameSize) return;
const int result = WebRtcVad_Process(
vadHandle,
audioClip.getSampleRate(),
webRtcSamplingRate,
buffer.data(),
buffer.size()
) == 1;
);
if (result == -1) throw runtime_error("Error processing audio buffer using WebRTC VAD.");
const bool isActive = result != 0;
// Ignore the result of WebRtcVad_Process, instead directly interpret the internal VAD flag.
// The result of WebRtcVad_Process stays 1 for a number of frames after the last detected
// activity.
const bool isActive = reinterpret_cast<VadInstT*>(vadHandle)->vad == 1;
if (isActive) {
activity.set(time, time + 1_cs);
}
time += 1_cs;
};
process16bitAudioClip(audioClip, processBuffer, bufferCapacity, pass1ProgressSink);
// WebRTC adapts to the audio. This means results may not be correct at the very beginning.
// It sometimes returns false activity at the very beginning, mistaking the background noise for
// speech.
// So we delete the first recognized utterance and re-process the corresponding audio segment.
if (!activity.empty()) {
TimeRange firstActivity = activity.begin()->getTimeRange();
activity.clear(firstActivity);
const unique_ptr<AudioClip> streamStart = audioClip.clone()
| segment(TimeRange(0_cs, firstActivity.getEnd()));
time = 0_cs;
process16bitAudioClip(*streamStart, processBuffer, bufferCapacity, pass2ProgressSink);
}
return activity;
}
JoiningBoundedTimeline<void> detectVoiceActivity(
const AudioClip& inputAudioClip,
int maxThreadCount,
ProgressSink& progressSink
) {
// Prepare audio for VAD
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
| resample(8000) // Convert to the internal sampling rate of WebRTC
| removeDcOffset();
JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
std::mutex activityMutex;
// Split audio into segments and perform parallel VAD
const int segmentCount = maxThreadCount;
const centiseconds audioDuration = audioClip->getTruncatedRange().getDuration();
vector<TimeRange> audioSegments;
for (int i = 0; i < segmentCount; ++i) {
TimeRange segmentRange = TimeRange(
i * audioDuration / segmentCount,
(i + 1) * audioDuration / segmentCount
);
audioSegments.push_back(segmentRange);
}
runParallel(
"VAD",
[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
JoiningBoundedTimeline<void> activitySegment =
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
std::lock_guard<std::mutex> lock(activityMutex);
for (auto activityRange : activitySegment) {
activityRange.getTimeRange().shift(segmentRange.getStart());
activity.set(activityRange);
}
},
audioSegments,
segmentCount,
progressSink
);
process16bitAudioClip(*audioClip, processBuffer, frameSize, progressSink);
// Fill small gaps in activity
const centiseconds maxGap(5);
const centiseconds maxGap(10);
for (const auto& pair : getPairs(activity)) {
if (pair.second.getStart() - pair.first.getEnd() <= maxGap) {
activity.set(pair.first.getEnd(), pair.second.getStart());
}
}
// Shorten activities. WebRTC adds a bit of buffer at the end.
const centiseconds tail(5);
for (const auto& utterance : JoiningBoundedTimeline<void>(activity)) {
if (utterance.getDuration() > tail && utterance.getEnd() < audioDuration) {
activity.clear(utterance.getEnd() - tail, utterance.getEnd());
}
}
logging::debugFormat(
"Found {} sections of voice activity: {}",
activity.size(),

View File

@ -5,6 +5,5 @@
JoiningBoundedTimeline<void> detectVoiceActivity(
const AudioClip& audioClip,
int maxThreadCount,
ProgressSink& progressSink
);

View File

@ -102,7 +102,7 @@ BoundedTimeline<Phone> recognizePhones(
// Split audio into utterances
JoiningBoundedTimeline<void> utterances;
try {
utterances = detectVoiceActivity(*audioClip, maxThreadCount, voiceActivationProgressSink);
utterances = detectVoiceActivity(*audioClip, voiceActivationProgressSink);
} catch (...) {
std::throw_with_nested(runtime_error("Error detecting segments of speech."));
}