Using WebRTC for voice activity detection (VAD)

My simple power-based approach wasn't reliable enough.
This commit is contained in:
Daniel Wolf 2016-06-21 22:20:18 +02:00
parent aec3dbae01
commit c6c31a831c
6 changed files with 116 additions and 88 deletions

View File

@ -93,6 +93,31 @@ set_target_properties(gtest_main PROPERTIES FOLDER lib)
# ... GSL # ... GSL
include_directories(SYSTEM "lib/gsl/include") include_directories(SYSTEM "lib/gsl/include")
# ... WebRTC
include_directories(SYSTEM "lib/webrtc-8d2248ff")
set(webRTCFiles
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/cross_correlation.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/division_operations.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/downsample_fast.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/energy.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/get_scaling_square.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/min_max_operations.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_48khz.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_by_2_internal.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_fractional.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/spl_init.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/spl_inl.c
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/vector_scaling_operations.c
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_core.c
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_filterbank.c
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_gmm.c
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_sp.c
lib/webrtc-8d2248ff/webrtc/common_audio/vad/webrtc_vad.c
)
add_library(webRTC ${webRTCFiles})
target_compile_options(webRTC PRIVATE ${disableWarningsFlags})
set_target_properties(webRTC PROPERTIES FOLDER lib)
# ... Flite # ... Flite
include_directories("lib/flite-1.4/include" "lib/flite-1.4") include_directories("lib/flite-1.4/include" "lib/flite-1.4")
set(fliteFiles set(fliteFiles
@ -180,6 +205,7 @@ set(SOURCE_FILES
src/audio/voiceActivityDetection.cpp src/audio/voiceActivityDetection.h src/audio/voiceActivityDetection.cpp src/audio/voiceActivityDetection.h
src/audio/WaveFileReader.cpp src/audio/WaveFileReader.h src/audio/WaveFileReader.cpp src/audio/WaveFileReader.h
src/audio/waveFileWriting.cpp src/audio/waveFileWriting.h src/audio/waveFileWriting.cpp src/audio/waveFileWriting.h
src/audio/processing.cpp src/audio/processing.h
src/stringTools.cpp src/stringTools.h src/stringTools.cpp src/stringTools.h
src/NiceCmdLineOutput.cpp src/NiceCmdLineOutput.h src/NiceCmdLineOutput.cpp src/NiceCmdLineOutput.h
src/TablePrinter.cpp src/TablePrinter.h src/TablePrinter.cpp src/TablePrinter.h
@ -198,7 +224,7 @@ set(SOURCE_FILES
src/tupleHash.h src/tupleHash.h
) )
add_executable(rhubarb ${SOURCE_FILES}) add_executable(rhubarb ${SOURCE_FILES})
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite) target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)
target_compile_options(rhubarb PUBLIC ${enableWarningsFlags}) target_compile_options(rhubarb PUBLIC ${enableWarningsFlags})
# Define test project # Define test project

View File

@ -161,3 +161,17 @@ The [Sound Change Applier](http://www.zompist.com/sounds.htm) and its [rule set
> The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. > The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. > THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
### WebRTC
The [WebRTC](https://chromium.googlesource.com/external/webrtc) library is released under the **3-clause BSD License**.
> Copyright (c) 2011, The WebRTC project authors. All rights reserved.
> Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
> * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
> * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
> * Neither the name of Google nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

40
src/audio/processing.cpp Normal file
View File

@ -0,0 +1,40 @@
#include "processing.h"
using std::function;
using std::vector;
// Converts a float in the range -1..1 to a signed 16-bit int
inline int16_t floatSampleToInt16(float sample) {
sample = std::max(sample, -1.0f);
sample = std::min(sample, 1.0f);
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
}
void process16bitAudioStream(AudioStream& audioStream, function<void(const vector<int16_t>&)> processBuffer, size_t bufferCapacity, ProgressSink& progressSink) {
// Process entire sound stream
vector<int16_t> buffer;
buffer.reserve(bufferCapacity);
int sampleCount = 0;
do {
// Read to buffer
buffer.clear();
while (buffer.size() < bufferCapacity && !audioStream.endOfStream()) {
// Read sample
float floatSample = audioStream.readSample();
int16_t sample = floatSampleToInt16(floatSample);
buffer.push_back(sample);
}
// Process buffer
processBuffer(buffer);
sampleCount += buffer.size();
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream.getSampleCount());
} while (buffer.size());
}
void process16bitAudioStream(AudioStream& audioStream, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
const size_t capacity = 1600; // 0.1 second capacity
process16bitAudioStream(audioStream, processBuffer, capacity, progressSink);
}

9
src/audio/processing.h Normal file
View File

@ -0,0 +1,9 @@
#pragma once
#include <vector>
#include <functional>
#include "audio/AudioStream.h"
#include "ProgressBar.h"
void process16bitAudioStream(AudioStream& audioStream, std::function<void(const std::vector<int16_t>&)> processBuffer, size_t bufferCapacity, ProgressSink& progressSink);
void process16bitAudioStream(AudioStream& audioStream, std::function<void(const std::vector<int16_t>&)> processBuffer, ProgressSink& progressSink);

View File

@ -1,76 +1,45 @@
#include "voiceActivityDetection.h" #include "voiceActivityDetection.h"
#include <audio/DCOffset.h> #include <audio/DCOffset.h>
#include <audio/SampleRateConverter.h> #include <audio/SampleRateConverter.h>
#include <boost/optional/optional.hpp>
#include <logging.h> #include <logging.h>
#include <pairs.h> #include <pairs.h>
#include <boost/range/adaptor/transformed.hpp> #include <boost/range/adaptor/transformed.hpp>
#include <stringTools.h> #include <webrtc/common_audio/vad/include/webrtc_vad.h>
#include "processing.h"
using std::numeric_limits;
using std::vector; using std::vector;
using boost::optional;
using boost::adaptors::transformed; using boost::adaptors::transformed;
using fmt::format; using fmt::format;
using std::runtime_error;
float getRMS(AudioStream& audioStream, int maxSampleCount = numeric_limits<int>::max()) {
double sum = 0; // Use double to prevent rounding errors with large number of summands
int sampleCount;
for (sampleCount = 0; sampleCount < maxSampleCount && !audioStream.endOfStream(); sampleCount++) {
sum += std::pow(static_cast<double>(audioStream.readSample()), 2);
}
return sampleCount > 0 ? static_cast<float>(std::sqrt(sum / sampleCount)) : 0.0f;
}
float getRMS(const vector<float>& rmsSegments) {
if (rmsSegments.empty()) return 0;
double sum = 0; // Use double to prevent rounding errors with large number of summands
for (float rmsSegment : rmsSegments) {
sum += rmsSegment;
}
return static_cast<float>(std::sqrt(sum / rmsSegments.size()));
}
BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) { BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
// Make sure audio stream has no DC offset // Prepare audio for VAD
audioStream = removeDCOffset(std::move(audioStream)); audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
// Resample to remove noise VadInst* vadHandle = WebRtcVad_Create();
constexpr int maxFrequency = 4000; if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
constexpr int sampleRate = 2 * maxFrequency;
audioStream = convertSampleRate(std::move(audioStream), sampleRate);
// Collect RMS data int error = WebRtcVad_Init(vadHandle);
vector<float> rmsSegments; if (error) throw runtime_error("Error initializing WebRTC VAD handle.");
logging::debug("RMS calculation -- start");
int64_t centisecondCount = (audioStream->getSampleCount() - audioStream->getSampleIndex()) / 100;
for (int cs = 0; cs < centisecondCount; ++cs) {
rmsSegments.push_back(getRMS(*audioStream, sampleRate / 100));
progressSink.reportProgress(static_cast<double>(cs) / centisecondCount);
}
logging::debug("RMS calculation -- end");
const float rms = getRMS(rmsSegments); const int aggressiveness = 1; // 0..3. The higher, the more is cut off.
logging::debugFormat("RMS value: {0:.5f}", rms); error = WebRtcVad_set_mode(vadHandle, aggressiveness);
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
// Detect activity // Detect activity
const float cutoff = rms / 25;
logging::debugFormat("RMS cutoff for voice activity detection: {0:.5f}", cutoff);
BoundedTimeline<void> activity(audioStream->getTruncatedRange()); BoundedTimeline<void> activity(audioStream->getTruncatedRange());
for (centiseconds time = centiseconds::zero(); static_cast<size_t>(time.count()) < rmsSegments.size(); ++time) { centiseconds time = centiseconds::zero();
float currentRMS = rmsSegments[time.count()]; auto processBuffer = [&](const vector<int16_t>& buffer) {
bool active = currentRMS > cutoff; bool isActive = WebRtcVad_Process(vadHandle, audioStream->getSampleRate(), buffer.data(), buffer.size()) == 1;
if (active) { if (isActive) {
activity.set(time, time + centiseconds(1)); activity.set(time, time + centiseconds(1));
} }
} time += centiseconds(1);
};
const size_t bufferCapacity = audioStream->getSampleRate() / 100;
process16bitAudioStream(*audioStream.get(), processBuffer, bufferCapacity, progressSink);
// Pad each activity to prevent cropping WebRtcVad_Free(vadHandle);
const centiseconds padding(3);
for (const auto& element : BoundedTimeline<void>(activity)) {
activity.set(element.getStart() - padding, element.getEnd() + padding);
}
// Fill small gaps in activity // Fill small gaps in activity
const centiseconds maxGap(5); const centiseconds maxGap(5);

View File

@ -1,4 +1,3 @@
#include <iostream>
#include <boost/filesystem.hpp> #include <boost/filesystem.hpp>
#include "phoneExtraction.h" #include "phoneExtraction.h"
#include "audio/SampleRateConverter.h" #include "audio/SampleRateConverter.h"
@ -17,6 +16,7 @@
#include "tokenization.h" #include "tokenization.h"
#include "g2p.h" #include "g2p.h"
#include "ContinuousTimeline.h" #include "ContinuousTimeline.h"
#include "audio/processing.h"
extern "C" { extern "C" {
#include <pocketsphinx.h> #include <pocketsphinx.h>
@ -70,37 +70,6 @@ lambda_unique_ptr<ps_decoder_t> createDecoder() {
return recognizer; return recognizer;
} }
// Converts a float in the range -1..1 to a signed 16-bit int
int16_t floatSampleToInt16(float sample) {
sample = std::max(sample, -1.0f);
sample = std::min(sample, 1.0f);
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
}
void processAudioStream(AudioStream& audioStream16kHz, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
// Process entire sound stream
vector<int16_t> buffer;
const int capacity = 1600; // 0.1 second capacity
buffer.reserve(capacity);
int sampleCount = 0;
do {
// Read to buffer
buffer.clear();
while (buffer.size() < capacity && !audioStream16kHz.endOfStream()) {
// Read sample
float floatSample = audioStream16kHz.readSample();
int16_t sample = floatSampleToInt16(floatSample);
buffer.push_back(sample);
}
// Process buffer
processBuffer(buffer);
sampleCount += buffer.size();
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHz.getSampleCount());
} while (buffer.size());
}
logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) { logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
switch (errorLevel) { switch (errorLevel) {
case ERR_DEBUG: case ERR_DEBUG:
@ -161,7 +130,7 @@ BoundedTimeline<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_d
int searchedFrameCount = ps_process_raw(&decoder, buffer.data(), buffer.size(), false, false); int searchedFrameCount = ps_process_raw(&decoder, buffer.data(), buffer.size(), false, false);
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition."); if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
}; };
processAudioStream(*audioStream.get(), processBuffer, progressSink); process16bitAudioStream(*audioStream.get(), processBuffer, progressSink);
// End recognition // End recognition
error = ps_end_utt(&decoder); error = ps_end_utt(&decoder);
@ -235,7 +204,7 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
} }
} }
}; };
processAudioStream(*audioStream.get(), processBuffer, progressSink); process16bitAudioStream(*audioStream.get(), processBuffer, progressSink);
// End search // End search
error = ps_search_finish(search.get()); error = ps_search_finish(search.get());
@ -302,6 +271,7 @@ BoundedTimeline<Phone> detectPhones(
try { try {
// Split audio into utterances // Split audio into utterances
BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink); BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
// For progress reporting: weigh utterances by length // For progress reporting: weigh utterances by length
ProgressMerger dialogProgressMerger(dialogProgressSink); ProgressMerger dialogProgressMerger(dialogProgressSink);
vector<ProgressSink*> utteranceProgressSinks; vector<ProgressSink*> utteranceProgressSinks;