Do not use multithreading for VAD
WebRTC's VAD adapts to the audio signal over time. If we slice the audio clip into multiple shorter clips and run VAD on them in parallel, each slice starts from an unadapted state, so the combined result may not be as good as processing the clip in one piece.
parent a723942f22
commit 2bbad258c0
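
For context, here is a minimal standalone sketch (not part of this commit) of the sequential approach the commit keeps: a single VAD handle processes every 10 ms frame of the clip in order, so the detector's adaptation carries over from frame to frame instead of being reset at segment boundaries. The function name detectSpeechFrames and the include path are illustrative placeholders; the WebRtcVad_* calls are the same ones that appear in the diff below.

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>
#include <webrtc/common_audio/vad/include/webrtc_vad.h>

// Returns one flag per 10 ms frame: true where speech was detected.
// One VadInst sees the whole clip, so its internal noise model keeps adapting
// across frames. Giving each worker thread its own VadInst over a slice of the
// clip would restart that adaptation at every slice boundary.
std::vector<bool> detectSpeechFrames(const std::vector<int16_t>& samples) {
	constexpr int samplingRate = 8000;               // WebRTC-internal rate, as in the diff below
	constexpr size_t frameSize = samplingRate / 100; // 10 ms = 80 samples per frame

	const std::unique_ptr<VadInst, decltype(&WebRtcVad_Free)> handle(
		WebRtcVad_Create(), &WebRtcVad_Free);
	if (!handle) throw std::runtime_error("Error creating WebRTC VAD handle.");
	if (WebRtcVad_Init(handle.get())) throw std::runtime_error("Error initializing WebRTC VAD.");
	if (WebRtcVad_set_mode(handle.get(), 2)) throw std::runtime_error("Error setting WebRTC VAD aggressiveness.");

	std::vector<bool> frameIsActive;
	for (size_t offset = 0; offset + frameSize <= samples.size(); offset += frameSize) {
		const int result = WebRtcVad_Process(handle.get(), samplingRate, samples.data() + offset, frameSize);
		if (result == -1) throw std::runtime_error("Error processing audio buffer using WebRTC VAD.");
		frameIsActive.push_back(result == 1);
	}
	return frameIsActive;
}

Running several such handles on slices of the same recording would hand each slice a freshly initialized detector, which is the situation the diff below removes.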
@@ -212,6 +212,7 @@
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=qwhy/@EntryIndexedValue">True</s:Boolean>
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=rbegin/@EntryIndexedValue">True</s:Boolean>
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=resample/@EntryIndexedValue">True</s:Boolean>
+	<s:Boolean x:Key="/Default/UserDictionary/Words/=resamples/@EntryIndexedValue">True</s:Boolean>
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=retime/@EntryIndexedValue">True</s:Boolean>
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=retimed/@EntryIndexedValue">True</s:Boolean>
 	<s:Boolean x:Key="/Default/UserDictionary/Words/=synth/@EntryIndexedValue">True</s:Boolean>

@@ -8,7 +8,7 @@
 #include "processing.h"
 #include <gsl_util.h>
 #include "tools/parallel.h"
-#include "AudioSegment.h"
+#include <webrtc/common_audio/vad/vad_core.h>
 
 using std::vector;
 using boost::adaptors::transformed;

@@ -16,124 +16,65 @@ using fmt::format;
 using std::runtime_error;
 using std::unique_ptr;
 
-JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
-	const AudioClip& audioClip,
+JoiningBoundedTimeline<void> detectVoiceActivity(
+	const AudioClip& inputAudioClip,
 	ProgressSink& progressSink
 ) {
+	// Prepare audio for VAD
+	constexpr int webRtcSamplingRate = 8000;
+	const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
+		| resample(webRtcSamplingRate)
+		| removeDcOffset();
+
 	VadInst* vadHandle = WebRtcVad_Create();
 	if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
 
 	auto freeHandle = gsl::finally([&]() { WebRtcVad_Free(vadHandle); });
 
 	int error = WebRtcVad_Init(vadHandle);
-	if (error) throw runtime_error("Error initializing WebRTC VAD handle.");
+	if (error) throw runtime_error("Error initializing WebRTC VAD.");
 
 	const int aggressiveness = 2; // 0..3. The higher, the more is cut off.
 	error = WebRtcVad_set_mode(vadHandle, aggressiveness);
 	if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
 
-	ProgressMerger progressMerger(progressSink);
-	ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
-	ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
-
 	// Detect activity
-	JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
+	JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
 	centiseconds time = 0_cs;
-	const size_t bufferCapacity = audioClip.getSampleRate() / 100;
+	const size_t frameSize = webRtcSamplingRate / 100;
 	const auto processBuffer = [&](const vector<int16_t>& buffer) {
 		// WebRTC is picky regarding buffer size
-		if (buffer.size() < bufferCapacity) return;
+		if (buffer.size() < frameSize) return;
 
 		const int result = WebRtcVad_Process(
 			vadHandle,
-			audioClip.getSampleRate(),
+			webRtcSamplingRate,
 			buffer.data(),
 			buffer.size()
-		) == 1;
+		);
 		if (result == -1) throw runtime_error("Error processing audio buffer using WebRTC VAD.");
 
-		const bool isActive = result != 0;
+		// Ignore the result of WebRtcVad_Process, instead directly interpret the internal VAD flag.
+		// The result of WebRtcVad_Process stays 1 for a number of frames after the last detected
+		// activity.
+		const bool isActive = reinterpret_cast<VadInstT*>(vadHandle)->vad == 1;
 
 		if (isActive) {
 			activity.set(time, time + 1_cs);
 		}
 
 		time += 1_cs;
 	};
-	process16bitAudioClip(audioClip, processBuffer, bufferCapacity, pass1ProgressSink);
-
-	// WebRTC adapts to the audio. This means results may not be correct at the very beginning.
-	// It sometimes returns false activity at the very beginning, mistaking the background noise for
-	// speech.
-	// So we delete the first recognized utterance and re-process the corresponding audio segment.
-	if (!activity.empty()) {
-		TimeRange firstActivity = activity.begin()->getTimeRange();
-		activity.clear(firstActivity);
-		const unique_ptr<AudioClip> streamStart = audioClip.clone()
-			| segment(TimeRange(0_cs, firstActivity.getEnd()));
-		time = 0_cs;
-		process16bitAudioClip(*streamStart, processBuffer, bufferCapacity, pass2ProgressSink);
-	}
-
-	return activity;
-}
-
-JoiningBoundedTimeline<void> detectVoiceActivity(
-	const AudioClip& inputAudioClip,
-	int maxThreadCount,
-	ProgressSink& progressSink
-) {
-	// Prepare audio for VAD
-	const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
-		| resample(8000) // Convert to the internal sampling rate of WebRTC
-		| removeDcOffset();
-
-	JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
-	std::mutex activityMutex;
-
-	// Split audio into segments and perform parallel VAD
-	const int segmentCount = maxThreadCount;
-	const centiseconds audioDuration = audioClip->getTruncatedRange().getDuration();
-	vector<TimeRange> audioSegments;
-	for (int i = 0; i < segmentCount; ++i) {
-		TimeRange segmentRange = TimeRange(
-			i * audioDuration / segmentCount,
-			(i + 1) * audioDuration / segmentCount
-		);
-		audioSegments.push_back(segmentRange);
-	}
-	runParallel(
-		"VAD",
-		[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
-			const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
-			JoiningBoundedTimeline<void> activitySegment =
-				webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
-
-			std::lock_guard<std::mutex> lock(activityMutex);
-			for (auto activityRange : activitySegment) {
-				activityRange.getTimeRange().shift(segmentRange.getStart());
-				activity.set(activityRange);
-			}
-		},
-		audioSegments,
-		segmentCount,
-		progressSink
-	);
+	process16bitAudioClip(*audioClip, processBuffer, frameSize, progressSink);
 
 	// Fill small gaps in activity
-	const centiseconds maxGap(5);
+	const centiseconds maxGap(10);
 	for (const auto& pair : getPairs(activity)) {
 		if (pair.second.getStart() - pair.first.getEnd() <= maxGap) {
 			activity.set(pair.first.getEnd(), pair.second.getStart());
 		}
 	}
 
-	// Shorten activities. WebRTC adds a bit of buffer at the end.
-	const centiseconds tail(5);
-	for (const auto& utterance : JoiningBoundedTimeline<void>(activity)) {
-		if (utterance.getDuration() > tail && utterance.getEnd() < audioDuration) {
-			activity.clear(utterance.getEnd() - tail, utterance.getEnd());
-		}
-	}
-
 	logging::debugFormat(
 		"Found {} sections of voice activity: {}",
 		activity.size(),

@@ -5,6 +5,5 @@
 
 JoiningBoundedTimeline<void> detectVoiceActivity(
 	const AudioClip& audioClip,
-	int maxThreadCount,
 	ProgressSink& progressSink
 );

@@ -102,7 +102,7 @@ BoundedTimeline<Phone> recognizePhones(
 	// Split audio into utterances
 	JoiningBoundedTimeline<void> utterances;
 	try {
-		utterances = detectVoiceActivity(*audioClip, maxThreadCount, voiceActivationProgressSink);
+		utterances = detectVoiceActivity(*audioClip, voiceActivationProgressSink);
 	} catch (...) {
 		std::throw_with_nested(runtime_error("Error detecting segments of speech."));
 	}