Do not use multithreading for VAD
WebRTC adapts to the audio signal. If we slice the audio clip into multiple shorter clips, then perform VAD on them in parallel, the result may not be as good.
This commit is contained in:
parent
a723942f22
commit
2bbad258c0
|
@ -212,6 +212,7 @@
|
|||
<s:Boolean x:Key="/Default/UserDictionary/Words/=qwhy/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=rbegin/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=resample/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=resamples/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=retime/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=retimed/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/UserDictionary/Words/=synth/@EntryIndexedValue">True</s:Boolean>
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#include "processing.h"
|
||||
#include <gsl_util.h>
|
||||
#include "tools/parallel.h"
|
||||
#include "AudioSegment.h"
|
||||
#include <webrtc/common_audio/vad/vad_core.h>
|
||||
|
||||
using std::vector;
|
||||
using boost::adaptors::transformed;
|
||||
|
@ -16,124 +16,65 @@ using fmt::format;
|
|||
using std::runtime_error;
|
||||
using std::unique_ptr;
|
||||
|
||||
JoiningBoundedTimeline<void> webRtcDetectVoiceActivity(
|
||||
const AudioClip& audioClip,
|
||||
JoiningBoundedTimeline<void> detectVoiceActivity(
|
||||
const AudioClip& inputAudioClip,
|
||||
ProgressSink& progressSink
|
||||
) {
|
||||
// Prepare audio for VAD
|
||||
constexpr int webRtcSamplingRate = 8000;
|
||||
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
|
||||
| resample(webRtcSamplingRate)
|
||||
| removeDcOffset();
|
||||
|
||||
VadInst* vadHandle = WebRtcVad_Create();
|
||||
if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
|
||||
|
||||
auto freeHandle = gsl::finally([&]() { WebRtcVad_Free(vadHandle); });
|
||||
|
||||
int error = WebRtcVad_Init(vadHandle);
|
||||
if (error) throw runtime_error("Error initializing WebRTC VAD handle.");
|
||||
if (error) throw runtime_error("Error initializing WebRTC VAD.");
|
||||
|
||||
const int aggressiveness = 2; // 0..3. The higher, the more is cut off.
|
||||
error = WebRtcVad_set_mode(vadHandle, aggressiveness);
|
||||
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
||||
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
ProgressSink& pass1ProgressSink = progressMerger.addSource("VAD pass 1", 1.0);
|
||||
ProgressSink& pass2ProgressSink = progressMerger.addSource("VAD pass 2", 0.3);
|
||||
|
||||
// Detect activity
|
||||
JoiningBoundedTimeline<void> activity(audioClip.getTruncatedRange());
|
||||
JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
|
||||
centiseconds time = 0_cs;
|
||||
const size_t bufferCapacity = audioClip.getSampleRate() / 100;
|
||||
const size_t frameSize = webRtcSamplingRate / 100;
|
||||
const auto processBuffer = [&](const vector<int16_t>& buffer) {
|
||||
// WebRTC is picky regarding buffer size
|
||||
if (buffer.size() < bufferCapacity) return;
|
||||
if (buffer.size() < frameSize) return;
|
||||
|
||||
const int result = WebRtcVad_Process(
|
||||
vadHandle,
|
||||
audioClip.getSampleRate(),
|
||||
webRtcSamplingRate,
|
||||
buffer.data(),
|
||||
buffer.size()
|
||||
) == 1;
|
||||
);
|
||||
if (result == -1) throw runtime_error("Error processing audio buffer using WebRTC VAD.");
|
||||
|
||||
const bool isActive = result != 0;
|
||||
// Ignore the result of WebRtcVad_Process, instead directly interpret the internal VAD flag.
|
||||
// The result of WebRtcVad_Process stays 1 for a number of frames after the last detected
|
||||
// activity.
|
||||
const bool isActive = reinterpret_cast<VadInstT*>(vadHandle)->vad == 1;
|
||||
|
||||
if (isActive) {
|
||||
activity.set(time, time + 1_cs);
|
||||
}
|
||||
|
||||
time += 1_cs;
|
||||
};
|
||||
process16bitAudioClip(audioClip, processBuffer, bufferCapacity, pass1ProgressSink);
|
||||
|
||||
// WebRTC adapts to the audio. This means results may not be correct at the very beginning.
|
||||
// It sometimes returns false activity at the very beginning, mistaking the background noise for
|
||||
// speech.
|
||||
// So we delete the first recognized utterance and re-process the corresponding audio segment.
|
||||
if (!activity.empty()) {
|
||||
TimeRange firstActivity = activity.begin()->getTimeRange();
|
||||
activity.clear(firstActivity);
|
||||
const unique_ptr<AudioClip> streamStart = audioClip.clone()
|
||||
| segment(TimeRange(0_cs, firstActivity.getEnd()));
|
||||
time = 0_cs;
|
||||
process16bitAudioClip(*streamStart, processBuffer, bufferCapacity, pass2ProgressSink);
|
||||
}
|
||||
|
||||
return activity;
|
||||
}
|
||||
|
||||
JoiningBoundedTimeline<void> detectVoiceActivity(
|
||||
const AudioClip& inputAudioClip,
|
||||
int maxThreadCount,
|
||||
ProgressSink& progressSink
|
||||
) {
|
||||
// Prepare audio for VAD
|
||||
const unique_ptr<AudioClip> audioClip = inputAudioClip.clone()
|
||||
| resample(8000) // Convert to the internal sampling rate of WebRTC
|
||||
| removeDcOffset();
|
||||
|
||||
JoiningBoundedTimeline<void> activity(audioClip->getTruncatedRange());
|
||||
std::mutex activityMutex;
|
||||
|
||||
// Split audio into segments and perform parallel VAD
|
||||
const int segmentCount = maxThreadCount;
|
||||
const centiseconds audioDuration = audioClip->getTruncatedRange().getDuration();
|
||||
vector<TimeRange> audioSegments;
|
||||
for (int i = 0; i < segmentCount; ++i) {
|
||||
TimeRange segmentRange = TimeRange(
|
||||
i * audioDuration / segmentCount,
|
||||
(i + 1) * audioDuration / segmentCount
|
||||
);
|
||||
audioSegments.push_back(segmentRange);
|
||||
}
|
||||
runParallel(
|
||||
"VAD",
|
||||
[&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||
const unique_ptr<AudioClip> audioSegment = audioClip->clone() | segment(segmentRange);
|
||||
JoiningBoundedTimeline<void> activitySegment =
|
||||
webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||
|
||||
std::lock_guard<std::mutex> lock(activityMutex);
|
||||
for (auto activityRange : activitySegment) {
|
||||
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||
activity.set(activityRange);
|
||||
}
|
||||
},
|
||||
audioSegments,
|
||||
segmentCount,
|
||||
progressSink
|
||||
);
|
||||
process16bitAudioClip(*audioClip, processBuffer, frameSize, progressSink);
|
||||
|
||||
// Fill small gaps in activity
|
||||
const centiseconds maxGap(5);
|
||||
const centiseconds maxGap(10);
|
||||
for (const auto& pair : getPairs(activity)) {
|
||||
if (pair.second.getStart() - pair.first.getEnd() <= maxGap) {
|
||||
activity.set(pair.first.getEnd(), pair.second.getStart());
|
||||
}
|
||||
}
|
||||
|
||||
// Shorten activities. WebRTC adds a bit of buffer at the end.
|
||||
const centiseconds tail(5);
|
||||
for (const auto& utterance : JoiningBoundedTimeline<void>(activity)) {
|
||||
if (utterance.getDuration() > tail && utterance.getEnd() < audioDuration) {
|
||||
activity.clear(utterance.getEnd() - tail, utterance.getEnd());
|
||||
}
|
||||
}
|
||||
|
||||
logging::debugFormat(
|
||||
"Found {} sections of voice activity: {}",
|
||||
activity.size(),
|
||||
|
|
|
@ -5,6 +5,5 @@
|
|||
|
||||
JoiningBoundedTimeline<void> detectVoiceActivity(
|
||||
const AudioClip& audioClip,
|
||||
int maxThreadCount,
|
||||
ProgressSink& progressSink
|
||||
);
|
||||
|
|
|
@ -102,7 +102,7 @@ BoundedTimeline<Phone> recognizePhones(
|
|||
// Split audio into utterances
|
||||
JoiningBoundedTimeline<void> utterances;
|
||||
try {
|
||||
utterances = detectVoiceActivity(*audioClip, maxThreadCount, voiceActivationProgressSink);
|
||||
utterances = detectVoiceActivity(*audioClip, voiceActivationProgressSink);
|
||||
} catch (...) {
|
||||
std::throw_with_nested(runtime_error("Error detecting segments of speech."));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue