Using WebRTC for voice activity detection (VAD)
My simple power-based approach wasn't reliable enough.
This commit is contained in:
parent
aec3dbae01
commit
c6c31a831c
|
@ -93,6 +93,31 @@ set_target_properties(gtest_main PROPERTIES FOLDER lib)
|
|||
# ... GSL
|
||||
include_directories(SYSTEM "lib/gsl/include")
|
||||
|
||||
# ... WebRTC
|
||||
include_directories(SYSTEM "lib/webrtc-8d2248ff")
|
||||
set(webRTCFiles
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/cross_correlation.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/division_operations.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/downsample_fast.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/energy.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/get_scaling_square.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/min_max_operations.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_48khz.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_by_2_internal.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/resample_fractional.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/spl_init.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/spl_inl.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/signal_processing/vector_scaling_operations.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_core.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_filterbank.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_gmm.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/vad/vad_sp.c
|
||||
lib/webrtc-8d2248ff/webrtc/common_audio/vad/webrtc_vad.c
|
||||
)
|
||||
add_library(webRTC ${webRTCFiles})
|
||||
target_compile_options(webRTC PRIVATE ${disableWarningsFlags})
|
||||
set_target_properties(webRTC PROPERTIES FOLDER lib)
|
||||
|
||||
# ... Flite
|
||||
include_directories("lib/flite-1.4/include" "lib/flite-1.4")
|
||||
set(fliteFiles
|
||||
|
@ -180,6 +205,7 @@ set(SOURCE_FILES
|
|||
src/audio/voiceActivityDetection.cpp src/audio/voiceActivityDetection.h
|
||||
src/audio/WaveFileReader.cpp src/audio/WaveFileReader.h
|
||||
src/audio/waveFileWriting.cpp src/audio/waveFileWriting.h
|
||||
src/audio/processing.cpp src/audio/processing.h
|
||||
src/stringTools.cpp src/stringTools.h
|
||||
src/NiceCmdLineOutput.cpp src/NiceCmdLineOutput.h
|
||||
src/TablePrinter.cpp src/TablePrinter.h
|
||||
|
@ -198,7 +224,7 @@ set(SOURCE_FILES
|
|||
src/tupleHash.h
|
||||
)
|
||||
add_executable(rhubarb ${SOURCE_FILES})
|
||||
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite)
|
||||
target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite webRTC)
|
||||
target_compile_options(rhubarb PUBLIC ${enableWarningsFlags})
|
||||
|
||||
# Define test project
|
||||
|
|
16
LICENSE.md
16
LICENSE.md
|
@ -160,4 +160,18 @@ The [Sound Change Applier](http://www.zompist.com/sounds.htm) and its [rule set
|
|||
|
||||
> The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
### WebRTC
|
||||
|
||||
The [WebRTC](https://chromium.googlesource.com/external/webrtc) library is released under the **3-clause BSD License**.
|
||||
|
||||
> Copyright (c) 2011, The WebRTC project authors. All rights reserved.
|
||||
|
||||
> Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
> * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
> * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
> * Neither the name of Google nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#include "processing.h"
|
||||
|
||||
using std::function;
|
||||
using std::vector;
|
||||
|
||||
// Converts a float in the range -1..1 to a signed 16-bit int
|
||||
inline int16_t floatSampleToInt16(float sample) {
|
||||
sample = std::max(sample, -1.0f);
|
||||
sample = std::min(sample, 1.0f);
|
||||
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
|
||||
}
|
||||
|
||||
void process16bitAudioStream(AudioStream& audioStream, function<void(const vector<int16_t>&)> processBuffer, size_t bufferCapacity, ProgressSink& progressSink) {
|
||||
// Process entire sound stream
|
||||
vector<int16_t> buffer;
|
||||
buffer.reserve(bufferCapacity);
|
||||
int sampleCount = 0;
|
||||
do {
|
||||
// Read to buffer
|
||||
buffer.clear();
|
||||
while (buffer.size() < bufferCapacity && !audioStream.endOfStream()) {
|
||||
// Read sample
|
||||
float floatSample = audioStream.readSample();
|
||||
int16_t sample = floatSampleToInt16(floatSample);
|
||||
buffer.push_back(sample);
|
||||
}
|
||||
|
||||
// Process buffer
|
||||
processBuffer(buffer);
|
||||
|
||||
sampleCount += buffer.size();
|
||||
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream.getSampleCount());
|
||||
} while (buffer.size());
|
||||
}
|
||||
|
||||
void process16bitAudioStream(AudioStream& audioStream, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
|
||||
const size_t capacity = 1600; // 0.1 second capacity
|
||||
process16bitAudioStream(audioStream, processBuffer, capacity, progressSink);
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "audio/AudioStream.h"
|
||||
#include "ProgressBar.h"
|
||||
|
||||
void process16bitAudioStream(AudioStream& audioStream, std::function<void(const std::vector<int16_t>&)> processBuffer, size_t bufferCapacity, ProgressSink& progressSink);
|
||||
void process16bitAudioStream(AudioStream& audioStream, std::function<void(const std::vector<int16_t>&)> processBuffer, ProgressSink& progressSink);
|
|
@ -1,76 +1,45 @@
|
|||
#include "voiceActivityDetection.h"
|
||||
#include <audio/DCOffset.h>
|
||||
#include <audio/SampleRateConverter.h>
|
||||
#include <boost/optional/optional.hpp>
|
||||
#include <logging.h>
|
||||
#include <pairs.h>
|
||||
#include <boost/range/adaptor/transformed.hpp>
|
||||
#include <stringTools.h>
|
||||
#include <webrtc/common_audio/vad/include/webrtc_vad.h>
|
||||
#include "processing.h"
|
||||
|
||||
using std::numeric_limits;
|
||||
using std::vector;
|
||||
using boost::optional;
|
||||
using boost::adaptors::transformed;
|
||||
using fmt::format;
|
||||
|
||||
float getRMS(AudioStream& audioStream, int maxSampleCount = numeric_limits<int>::max()) {
|
||||
double sum = 0; // Use double to prevent rounding errors with large number of summands
|
||||
int sampleCount;
|
||||
for (sampleCount = 0; sampleCount < maxSampleCount && !audioStream.endOfStream(); sampleCount++) {
|
||||
sum += std::pow(static_cast<double>(audioStream.readSample()), 2);
|
||||
}
|
||||
return sampleCount > 0 ? static_cast<float>(std::sqrt(sum / sampleCount)) : 0.0f;
|
||||
}
|
||||
|
||||
float getRMS(const vector<float>& rmsSegments) {
|
||||
if (rmsSegments.empty()) return 0;
|
||||
|
||||
double sum = 0; // Use double to prevent rounding errors with large number of summands
|
||||
for (float rmsSegment : rmsSegments) {
|
||||
sum += rmsSegment;
|
||||
}
|
||||
return static_cast<float>(std::sqrt(sum / rmsSegments.size()));
|
||||
}
|
||||
using std::runtime_error;
|
||||
|
||||
BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStream, ProgressSink& progressSink) {
|
||||
// Make sure audio stream has no DC offset
|
||||
audioStream = removeDCOffset(std::move(audioStream));
|
||||
// Prepare audio for VAD
|
||||
audioStream = removeDCOffset(convertSampleRate(std::move(audioStream), 16000));
|
||||
|
||||
// Resample to remove noise
|
||||
constexpr int maxFrequency = 4000;
|
||||
constexpr int sampleRate = 2 * maxFrequency;
|
||||
audioStream = convertSampleRate(std::move(audioStream), sampleRate);
|
||||
VadInst* vadHandle = WebRtcVad_Create();
|
||||
if (!vadHandle) throw runtime_error("Error creating WebRTC VAD handle.");
|
||||
|
||||
// Collect RMS data
|
||||
vector<float> rmsSegments;
|
||||
logging::debug("RMS calculation -- start");
|
||||
int64_t centisecondCount = (audioStream->getSampleCount() - audioStream->getSampleIndex()) / 100;
|
||||
for (int cs = 0; cs < centisecondCount; ++cs) {
|
||||
rmsSegments.push_back(getRMS(*audioStream, sampleRate / 100));
|
||||
progressSink.reportProgress(static_cast<double>(cs) / centisecondCount);
|
||||
}
|
||||
logging::debug("RMS calculation -- end");
|
||||
int error = WebRtcVad_Init(vadHandle);
|
||||
if (error) throw runtime_error("Error initializing WebRTC VAD handle.");
|
||||
|
||||
const float rms = getRMS(rmsSegments);
|
||||
logging::debugFormat("RMS value: {0:.5f}", rms);
|
||||
const int aggressiveness = 1; // 0..3. The higher, the more is cut off.
|
||||
error = WebRtcVad_set_mode(vadHandle, aggressiveness);
|
||||
if (error) throw runtime_error("Error setting WebRTC VAD aggressiveness.");
|
||||
|
||||
// Detect activity
|
||||
const float cutoff = rms / 25;
|
||||
logging::debugFormat("RMS cutoff for voice activity detection: {0:.5f}", cutoff);
|
||||
BoundedTimeline<void> activity(audioStream->getTruncatedRange());
|
||||
for (centiseconds time = centiseconds::zero(); static_cast<size_t>(time.count()) < rmsSegments.size(); ++time) {
|
||||
float currentRMS = rmsSegments[time.count()];
|
||||
bool active = currentRMS > cutoff;
|
||||
if (active) {
|
||||
centiseconds time = centiseconds::zero();
|
||||
auto processBuffer = [&](const vector<int16_t>& buffer) {
|
||||
bool isActive = WebRtcVad_Process(vadHandle, audioStream->getSampleRate(), buffer.data(), buffer.size()) == 1;
|
||||
if (isActive) {
|
||||
activity.set(time, time + centiseconds(1));
|
||||
}
|
||||
}
|
||||
time += centiseconds(1);
|
||||
};
|
||||
const size_t bufferCapacity = audioStream->getSampleRate() / 100;
|
||||
process16bitAudioStream(*audioStream.get(), processBuffer, bufferCapacity, progressSink);
|
||||
|
||||
// Pad each activity to prevent cropping
|
||||
const centiseconds padding(3);
|
||||
for (const auto& element : BoundedTimeline<void>(activity)) {
|
||||
activity.set(element.getStart() - padding, element.getEnd() + padding);
|
||||
}
|
||||
WebRtcVad_Free(vadHandle);
|
||||
|
||||
// Fill small gaps in activity
|
||||
const centiseconds maxGap(5);
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#include <iostream>
|
||||
#include <boost/filesystem.hpp>
|
||||
#include "phoneExtraction.h"
|
||||
#include "audio/SampleRateConverter.h"
|
||||
|
@ -17,6 +16,7 @@
|
|||
#include "tokenization.h"
|
||||
#include "g2p.h"
|
||||
#include "ContinuousTimeline.h"
|
||||
#include "audio/processing.h"
|
||||
|
||||
extern "C" {
|
||||
#include <pocketsphinx.h>
|
||||
|
@ -70,37 +70,6 @@ lambda_unique_ptr<ps_decoder_t> createDecoder() {
|
|||
return recognizer;
|
||||
}
|
||||
|
||||
// Converts a float in the range -1..1 to a signed 16-bit int
|
||||
int16_t floatSampleToInt16(float sample) {
|
||||
sample = std::max(sample, -1.0f);
|
||||
sample = std::min(sample, 1.0f);
|
||||
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
|
||||
}
|
||||
|
||||
void processAudioStream(AudioStream& audioStream16kHz, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
|
||||
// Process entire sound stream
|
||||
vector<int16_t> buffer;
|
||||
const int capacity = 1600; // 0.1 second capacity
|
||||
buffer.reserve(capacity);
|
||||
int sampleCount = 0;
|
||||
do {
|
||||
// Read to buffer
|
||||
buffer.clear();
|
||||
while (buffer.size() < capacity && !audioStream16kHz.endOfStream()) {
|
||||
// Read sample
|
||||
float floatSample = audioStream16kHz.readSample();
|
||||
int16_t sample = floatSampleToInt16(floatSample);
|
||||
buffer.push_back(sample);
|
||||
}
|
||||
|
||||
// Process buffer
|
||||
processBuffer(buffer);
|
||||
|
||||
sampleCount += buffer.size();
|
||||
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHz.getSampleCount());
|
||||
} while (buffer.size());
|
||||
}
|
||||
|
||||
logging::Level ConvertSphinxErrorLevel(err_lvl_t errorLevel) {
|
||||
switch (errorLevel) {
|
||||
case ERR_DEBUG:
|
||||
|
@ -161,7 +130,7 @@ BoundedTimeline<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_d
|
|||
int searchedFrameCount = ps_process_raw(&decoder, buffer.data(), buffer.size(), false, false);
|
||||
if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition.");
|
||||
};
|
||||
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
process16bitAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
|
||||
// End recognition
|
||||
error = ps_end_utt(&decoder);
|
||||
|
@ -235,7 +204,7 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
|
|||
}
|
||||
}
|
||||
};
|
||||
processAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
process16bitAudioStream(*audioStream.get(), processBuffer, progressSink);
|
||||
|
||||
// End search
|
||||
error = ps_search_finish(search.get());
|
||||
|
@ -302,6 +271,7 @@ BoundedTimeline<Phone> detectPhones(
|
|||
try {
|
||||
// Split audio into utterances
|
||||
BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
|
||||
|
||||
// For progress reporting: weigh utterances by length
|
||||
ProgressMerger dialogProgressMerger(dialogProgressSink);
|
||||
vector<ProgressSink*> utteranceProgressSinks;
|
||||
|
|
Loading…
Reference in New Issue