Refactored audio streams

* All streams are now mono (simplifies reasoning about samples)
* Streams can be cloned
* Streams support seeking (see the usage sketch below)
Daniel Wolf 2016-03-07 21:28:31 +01:00
parent 419b0ec469
commit b78e418a8f
13 changed files with 214 additions and 200 deletions
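
The bullet points translate into a small consumer-facing API: mono samples, clone(), seek(), readSample(), and endOfStream(). As a rough usage sketch — not part of the commit; openAudioFile, readAllSamples, and example are made-up names for illustration — the refactored interface can be driven like this:

#include <memory>
#include <vector>
#include <boost/filesystem/path.hpp>
#include "audioInput/AudioStream.h"
#include "audioInput/WaveFileReader.h"

// Hypothetical factory; any AudioStream implementation would do.
std::unique_ptr<AudioStream> openAudioFile(boost::filesystem::path path) {
    return std::make_unique<WaveFileReader>(path);
}

// Drain a stream into a buffer. All streams are mono, so one readSample() call yields one value.
std::vector<float> readAllSamples(AudioStream& stream) {
    std::vector<float> samples;
    samples.reserve(stream.getSampleCount());
    while (!stream.endOfStream()) {
        samples.push_back(stream.readSample());
    }
    return samples;
}

void example(boost::filesystem::path path) {
    std::unique_ptr<AudioStream> stream = openAudioFile(path);

    // clone(true) yields an independent stream positioned at sample 0.
    std::unique_ptr<AudioStream> secondPass = stream->clone(true);
    std::vector<float> all = readAllSamples(*stream);

    // Seek the clone to the middle and read only the second half.
    secondPass->seek(secondPass->getSampleCount() / 2);
    std::vector<float> secondHalf = readAllSamples(*secondPass);
}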

View File

@ -100,7 +100,7 @@ set(SOURCE_FILES
src/phoneExtraction.cpp
src/platformTools.cpp
src/tools.cpp
src/audioInput/ChannelDownmixer.cpp
src/audioInput/AudioStream.cpp
src/audioInput/SampleRateConverter.cpp
src/audioInput/WaveFileReader.cpp
src/audioInput/waveFileWriting.cpp

View File

@ -0,0 +1,5 @@
#include "AudioStream.h"
bool AudioStream::endOfStream() {
return getSampleIndex() >= getSampleCount();
}

View File

@ -1,9 +1,16 @@
#pragma once
#include <memory>
// A mono stream of floating-point samples.
class AudioStream {
public:
virtual int getFrameRate() = 0;
virtual int getFrameCount() = 0;
virtual int getChannelCount() = 0;
virtual bool getNextSample(float &sample) = 0;
virtual ~AudioStream() {}
virtual std::unique_ptr<AudioStream> clone(bool reset) = 0;
virtual int getSampleRate() = 0;
virtual int getSampleCount() = 0;
virtual int getSampleIndex() = 0;
virtual void seek(int sampleIndex) = 0;
bool endOfStream();
virtual float readSample() = 0;
};
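
To make the contract concrete, here is a hypothetical in-memory implementation — illustration only, not part of the commit; VectorAudioStream is a made-up name. It shows what the pure virtuals, clone(reset), and seek() are expected to do:

#include <memory>
#include <stdexcept>
#include <vector>
#include "AudioStream.h"

// Hypothetical in-memory mono stream, for illustration only.
class VectorAudioStream : public AudioStream {
public:
    VectorAudioStream(std::vector<float> samples, int sampleRate) :
        samples(std::move(samples)), sampleRate(sampleRate), index(0)
    {}

    std::unique_ptr<AudioStream> clone(bool reset) override {
        auto copy = std::make_unique<VectorAudioStream>(samples, sampleRate);
        copy->index = reset ? 0 : index;
        return copy;
    }

    int getSampleRate() override { return sampleRate; }
    int getSampleCount() override { return static_cast<int>(samples.size()); }
    int getSampleIndex() override { return index; }

    void seek(int sampleIndex) override {
        if (sampleIndex < 0 || sampleIndex >= getSampleCount()) throw std::invalid_argument("sampleIndex out of range.");
        index = sampleIndex;
    }

    float readSample() override {
        if (endOfStream()) throw std::out_of_range("End of stream.");
        return samples[index++];
    }

private:
    std::vector<float> samples;
    int sampleRate;
    int index;
};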

View File

@ -1,31 +0,0 @@
#include "ChannelDownmixer.h"
ChannelDownmixer::ChannelDownmixer(std::unique_ptr<AudioStream> inputStream) :
inputStream(std::move(inputStream)),
inputChannelCount(this->inputStream->getChannelCount())
{}
int ChannelDownmixer::getFrameRate() {
return inputStream->getFrameRate();
}
int ChannelDownmixer::getFrameCount() {
return inputStream->getFrameCount();
}
int ChannelDownmixer::getChannelCount() {
return 1;
}
bool ChannelDownmixer::getNextSample(float &sample) {
float sum = 0;
for (int channelIndex = 0; channelIndex < inputChannelCount; channelIndex++) {
float currentSample;
if (!inputStream->getNextSample(currentSample)) return false;
sum += currentSample;
}
sample = sum / inputChannelCount;
return true;
}

View File

@ -1,18 +0,0 @@
#pragma once
#include "AudioStream.h"
#include <memory>
// Converts a multi-channel audio stream to mono.
class ChannelDownmixer : public AudioStream {
public:
ChannelDownmixer(std::unique_ptr<AudioStream> inputStream);
virtual int getFrameRate() override;
virtual int getFrameCount() override;
virtual int getChannelCount() override;
virtual bool getNextSample(float &sample) override;
private:
std::unique_ptr<AudioStream> inputStream;
int inputChannelCount;
};

View File

@ -1,80 +1,94 @@
#include <cmath>
#include "SampleRateConverter.h"
#include <stdexcept>
#include <algorithm>
using std::runtime_error;
SampleRateConverter::SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate) :
inputStream(std::move(inputStream)),
downscalingFactor(static_cast<double>(this->inputStream->getFrameRate()) / outputFrameRate),
downscalingFactor(static_cast<double>(this->inputStream->getSampleRate()) / outputFrameRate),
outputFrameRate(outputFrameRate),
outputFrameCount(std::lround(this->inputStream->getFrameCount() / downscalingFactor)),
outputFrameCount(std::lround(this->inputStream->getSampleCount() / downscalingFactor)),
lastInputSample(0),
lastInputSampleIndex(-1),
nextOutputSampleIndex(0)
{
if (this->inputStream->getChannelCount() != 1) {
throw runtime_error("Only mono input streams are supported.");
}
if (this->inputStream->getFrameRate() < outputFrameRate) {
if (this->inputStream->getSampleRate() < outputFrameRate) {
throw runtime_error("Upsampling not supported.");
}
}
int SampleRateConverter::getFrameRate() {
SampleRateConverter::SampleRateConverter(const SampleRateConverter& rhs, bool reset) :
SampleRateConverter(rhs.inputStream->clone(reset), rhs.outputFrameRate)
{
nextOutputSampleIndex = reset ? 0 : rhs.nextOutputSampleIndex;
}
std::unique_ptr<AudioStream> SampleRateConverter::clone(bool reset) {
return std::make_unique<SampleRateConverter>(*this, reset);
}
int SampleRateConverter::getSampleRate() {
return outputFrameRate;
}
int SampleRateConverter::getFrameCount() {
int SampleRateConverter::getSampleCount() {
return outputFrameCount;
}
int SampleRateConverter::getChannelCount() {
return 1;
int SampleRateConverter::getSampleIndex() {
return nextOutputSampleIndex;
}
bool SampleRateConverter::getNextSample(float &sample) {
if (nextOutputSampleIndex >= outputFrameCount) return false;
void SampleRateConverter::seek(int sampleIndex) {
if (sampleIndex < 0 || sampleIndex >= outputFrameCount) throw std::invalid_argument("sampleIndex out of range.");
double start = nextOutputSampleIndex * downscalingFactor;
double end = (nextOutputSampleIndex + 1) * downscalingFactor;
nextOutputSampleIndex = sampleIndex;
}
float SampleRateConverter::readSample() {
if (nextOutputSampleIndex >= outputFrameCount) throw std::out_of_range("End of stream.");
double inputStart = nextOutputSampleIndex * downscalingFactor;
double inputEnd = (nextOutputSampleIndex + 1) * downscalingFactor;
sample = mean(start, end);
nextOutputSampleIndex++;
return true;
return mean(inputStart, inputEnd);
}
float SampleRateConverter::mean(double start, double end) {
float SampleRateConverter::mean(double inputStart, double inputEnd) {
// Calculate weighted sum...
double sum = 0;
// ... first sample (weight <= 1)
int startIndex = static_cast<int>(start);
sum += getInputSample(startIndex) * ((startIndex + 1) - start);
int startIndex = static_cast<int>(inputStart);
sum += getInputSample(startIndex) * ((startIndex + 1) - inputStart);
// ... middle samples (weight 1 each)
int endIndex = static_cast<int>(end);
int endIndex = static_cast<int>(inputEnd);
for (int index = startIndex + 1; index < endIndex; index++) {
sum += getInputSample(index);
}
// ... last sample (weight < 1)
sum += getInputSample(endIndex) * (end - endIndex);
sum += getInputSample(endIndex) * (inputEnd - endIndex);
return static_cast<float>(sum / (end - start));
return static_cast<float>(sum / (inputEnd - inputStart));
}
float SampleRateConverter::getInputSample(int sampleIndex) {
sampleIndex = std::min(sampleIndex, inputStream->getSampleCount() - 1);
if (sampleIndex < 0) return 0.0f;
if (sampleIndex == lastInputSampleIndex) {
return lastInputSample;
}
if (sampleIndex == lastInputSampleIndex + 1) {
lastInputSampleIndex++;
// Read the next sample.
// If the input stream has no more samples (at the very end),
// we'll just reuse the last sample we have.
inputStream->getNextSample(lastInputSample);
return lastInputSample;
}
throw runtime_error("Can only return the last sample or the one following it.");
if (sampleIndex != inputStream->getSampleIndex()) {
inputStream->seek(sampleIndex);
}
lastInputSample = inputStream->readSample();
lastInputSampleIndex = sampleIndex;
return lastInputSample;
}
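
For intuition on mean(): with 44100 Hz input and 16000 Hz output, downscalingFactor ≈ 2.756, so output sample 0 covers input samples 0 through ≈2.756 — samples 0 and 1 get weight 1, sample 2 gets weight ≈0.756, and the weighted sum is divided by ≈2.756. A standalone re-implementation of that weighting over a plain buffer (illustration only, not code from the commit; fractionalMean is a made-up name) might look like this:

#include <vector>

// Weighted mean of buffer[inputStart .. inputEnd), where both bounds may be fractional.
// Mirrors the weighting in SampleRateConverter::mean, but over a plain buffer.
float fractionalMean(const std::vector<float>& buffer, double inputStart, double inputEnd) {
    double sum = 0;

    // First sample, weighted by how much of it lies inside the window (weight <= 1)
    int startIndex = static_cast<int>(inputStart);
    sum += buffer[startIndex] * ((startIndex + 1) - inputStart);

    // Middle samples, fully inside the window (weight 1 each)
    int endIndex = static_cast<int>(inputEnd);
    for (int index = startIndex + 1; index < endIndex; index++) {
        sum += buffer[index];
    }

    // Last sample, weighted by the remaining fraction (weight < 1)
    if (endIndex < static_cast<int>(buffer.size())) {
        sum += buffer[endIndex] * (inputEnd - endIndex);
    }

    return static_cast<float>(sum / (inputEnd - inputStart));
}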

View File

@ -1,22 +1,21 @@
#pragma once
#include <memory>
#include <vector>
#include "AudioStream.h"
class SampleRateConverter : public AudioStream {
public:
SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate);
virtual int getFrameRate() override;
virtual int getFrameCount() override;
virtual int getChannelCount() override;
virtual bool getNextSample(float &sample) override;
SampleRateConverter(const SampleRateConverter& rhs, bool reset);
std::unique_ptr<AudioStream> clone(bool reset) override;
int getSampleRate() override;
int getSampleCount() override;
int getSampleIndex() override;
void seek(int sampleIndex) override;
float readSample() override;
private:
// The stream we're reading from
std::unique_ptr<AudioStream> inputStream;
// input frame rate / output frame rate
double downscalingFactor;
double downscalingFactor; // input frame rate / output frame rate
int outputFrameRate;
int outputFrameCount;

View File

@ -24,10 +24,12 @@ enum class Codec {
Float = 0x03
};
WaveFileReader::WaveFileReader(boost::filesystem::path filePath) {
// Open file
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
file.open(filePath, std::ios::binary);
WaveFileReader::WaveFileReader(boost::filesystem::path filePath) :
filePath(filePath),
file(),
sampleIndex(0)
{
openFile();
// Read header
uint32_t rootChunkId = read<uint32_t>(file);
@ -42,111 +44,148 @@ WaveFileReader::WaveFileReader(boost::filesystem::path filePath) {
// Read chunks until we reach the data chunk
bool reachedDataChunk = false;
int bytesPerSample = 0;
bytesPerSample = 0;
do {
uint32_t chunkId = read<uint32_t>(file);
int chunkSize = read<uint32_t>(file);
switch (chunkId) {
case fourcc('f', 'm', 't', ' '): {
// Read relevant data
Codec codec = (Codec) read<uint16_t>(file);
channelCount = read<uint16_t>(file);
frameRate = read<uint32_t>(file);
read<uint32_t>(file); // Bytes per second
int frameSize = read<uint16_t>(file);
int bitsPerSample = read<uint16_t>(file);
case fourcc('f', 'm', 't', ' '): {
// Read relevant data
Codec codec = (Codec)read<uint16_t>(file);
channelCount = read<uint16_t>(file);
frameRate = read<uint32_t>(file);
read<uint32_t>(file); // Bytes per second
int frameSize = read<uint16_t>(file);
int bitsPerSample = read<uint16_t>(file);
// We're read 16 bytes so far. Skip the remainder.
file.seekg(roundToEven(chunkSize) - 16, file.cur);
// We've read 16 bytes so far. Skip the remainder.
file.seekg(roundToEven(chunkSize) - 16, file.cur);
// Determine sample format
switch (codec) {
case Codec::PCM:
// Determine sample size.
// According to the WAVE standard, sample sizes that are not multiples of 8 bits
// (e.g. 12 bits) can be treated like the next-larger byte size.
if (bitsPerSample == 8) {
sampleFormat = SampleFormat::UInt8;
bytesPerSample = 1;
} else if (bitsPerSample <= 16) {
sampleFormat = SampleFormat::Int16;
bytesPerSample = 2;
} else if (bitsPerSample <= 24) {
sampleFormat = SampleFormat::Int24;
bytesPerSample = 3;
} else {
throw runtime_error(
format("Unsupported sample format: {}-bit integer samples.", bitsPerSample));
}
if (bytesPerSample != frameSize / channelCount) {
throw runtime_error("Unsupported sample organization.");
}
break;
case Codec::Float:
if (bitsPerSample == 32) {
sampleFormat = SampleFormat::Float32;
bytesPerSample = 4;
} else {
throw runtime_error(format("Unsupported sample format: {}-bit floating-point samples.", bitsPerSample));
}
break;
default:
throw runtime_error("Unsupported sample format. Only uncompressed formats are supported.");
// Determine sample format
switch (codec) {
case Codec::PCM:
// Determine sample size.
// According to the WAVE standard, sample sizes that are not multiples of 8 bits
// (e.g. 12 bits) can be treated like the next-larger byte size.
if (bitsPerSample == 8) {
sampleFormat = SampleFormat::UInt8;
bytesPerSample = 1;
} else if (bitsPerSample <= 16) {
sampleFormat = SampleFormat::Int16;
bytesPerSample = 2;
} else if (bitsPerSample <= 24) {
sampleFormat = SampleFormat::Int24;
bytesPerSample = 3;
} else {
throw runtime_error(
format("Unsupported sample format: {}-bit integer samples.", bitsPerSample));
}
if (bytesPerSample != frameSize / channelCount) {
throw runtime_error("Unsupported sample organization.");
}
break;
}
case fourcc('d', 'a', 't', 'a'): {
reachedDataChunk = true;
remainingSamples = chunkSize / bytesPerSample;
frameCount = remainingSamples / channelCount;
break;
}
default: {
// Skip unknown chunk
file.seekg(roundToEven(chunkSize), file.cur);
case Codec::Float:
if (bitsPerSample == 32) {
sampleFormat = SampleFormat::Float32;
bytesPerSample = 4;
} else {
throw runtime_error(format("Unsupported sample format: {}-bit floating-point samples.", bitsPerSample));
}
break;
default:
throw runtime_error("Unsupported sample format. Only uncompressed formats are supported.");
}
break;
}
case fourcc('d', 'a', 't', 'a'): {
reachedDataChunk = true;
dataOffset = file.tellg();
sampleCount = chunkSize / bytesPerSample;
frameCount = sampleCount / channelCount;
break;
}
default: {
// Skip unknown chunk
file.seekg(roundToEven(chunkSize), file.cur);
break;
}
}
} while (!reachedDataChunk);
}
int WaveFileReader::getFrameRate() {
WaveFileReader::WaveFileReader(const WaveFileReader& rhs, bool reset) :
filePath(rhs.filePath),
file(),
bytesPerSample(rhs.bytesPerSample),
sampleFormat(rhs.sampleFormat),
frameRate(rhs.frameRate),
frameCount(rhs.frameCount),
channelCount(rhs.channelCount),
sampleCount(rhs.sampleCount),
dataOffset(rhs.dataOffset),
sampleIndex(-1)
{
openFile();
seek(reset ? 0 : rhs.sampleIndex);
}
std::unique_ptr<AudioStream> WaveFileReader::clone(bool reset) {
return std::make_unique<WaveFileReader>(*this, reset);
}
void WaveFileReader::openFile() {
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
file.open(filePath, std::ios::binary);
}
int WaveFileReader::getSampleRate() {
return frameRate;
}
int WaveFileReader::getFrameCount() {
int WaveFileReader::getSampleCount() {
return frameCount;
}
int WaveFileReader::getChannelCount() {
return channelCount;
int WaveFileReader::getSampleIndex() {
return sampleIndex;
}
bool WaveFileReader::getNextSample(float &sample) {
if (remainingSamples == 0) return false;
remainingSamples--;
void WaveFileReader::seek(int sampleIndex) {
if (sampleIndex < 0 || sampleIndex >= sampleCount) throw std::invalid_argument("sampleIndex out of range.");
switch (sampleFormat) {
file.seekg(dataOffset + sampleIndex * channelCount * bytesPerSample);
this->sampleIndex = sampleIndex;
}
float WaveFileReader::readSample() {
if (sampleIndex + channelCount > sampleCount) throw std::out_of_range("End of stream.");
sampleIndex += channelCount;
float sum = 0;
for (int channelIndex = 0; channelIndex < channelCount; channelIndex++) {
switch (sampleFormat) {
case SampleFormat::UInt8: {
uint8_t raw = read<uint8_t>(file);
sample = toNormalizedFloat(raw, 0, UINT8_MAX);
sum += toNormalizedFloat(raw, 0, UINT8_MAX);
break;
}
case SampleFormat::Int16: {
int16_t raw = read<int16_t>(file);
sample = toNormalizedFloat(raw, INT16_MIN, INT16_MAX);
sum += toNormalizedFloat(raw, INT16_MIN, INT16_MAX);
break;
}
case SampleFormat::Int24: {
int raw = read<int, 24>(file);
if (raw & 0x800000) raw |= 0xFF000000; // Fix two's complement
sample = toNormalizedFloat(raw, INT24_MIN, INT24_MAX);
sum += toNormalizedFloat(raw, INT24_MIN, INT24_MAX);
break;
}
case SampleFormat::Float32: {
sample = read<float>(file);
sum += read<float>(file);
break;
}
}
}
return true;
return sum / channelCount;
}
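
The parsing above walks the canonical RIFF layout: every chunk is a four-character ID followed by a 32-bit little-endian size, and the 16-byte body of the 'fmt ' chunk carries exactly the fields read here, in this order. For reference only — this struct is standard WAVE layout, not code from the commit:

#include <cstdint>

// Canonical 16-byte 'fmt ' chunk body (PCM / IEEE float), matching the order
// in which WaveFileReader reads its fields.
#pragma pack(push, 1)
struct WaveFormatChunkBody {
    uint16_t codec;          // 0x01 = PCM, 0x03 = IEEE float
    uint16_t channelCount;
    uint32_t frameRate;      // frames (sample groups) per second
    uint32_t bytesPerSecond; // frameRate * frameSize
    uint16_t frameSize;      // channelCount * bytesPerSample (a.k.a. block align)
    uint16_t bitsPerSample;
};
#pragma pack(pop)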

View File

@ -1,8 +1,5 @@
#pragma once
#include <string>
#include <cstdint>
#include <fstream>
#include <boost/filesystem/path.hpp>
#include <boost/filesystem/fstream.hpp>
#include "AudioStream.h"
@ -17,16 +14,26 @@ enum class SampleFormat {
class WaveFileReader : public AudioStream {
public:
WaveFileReader(boost::filesystem::path filePath);
virtual int getFrameRate() override ;
virtual int getFrameCount() override;
virtual int getChannelCount() override;
virtual bool getNextSample(float &sample) override;
WaveFileReader(const WaveFileReader& rhs, bool reset);
std::unique_ptr<AudioStream> clone(bool reset) override;
int getSampleRate() override ;
int getSampleCount() override;
int getSampleIndex() override;
void seek(int sampleIndex) override;
float readSample() override;
private:
void openFile();
private:
boost::filesystem::path filePath;
boost::filesystem::ifstream file;
int bytesPerSample;
SampleFormat sampleFormat;
int frameRate;
int frameCount;
int channelCount;
int remainingSamples;
int sampleCount;
size_t dataOffset;
int sampleIndex;
};

View File

@ -13,9 +13,9 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
// Write RIFF chunk
write<uint32_t>(fourcc('R', 'I', 'F', 'F'), file);
uint32_t formatChunkSize = 16;
uint16_t channelCount = static_cast<uint16_t>(inputStream->getChannelCount());
uint16_t channelCount = 1;
uint16_t frameSize = static_cast<uint16_t>(channelCount * sizeof(float));
uint32_t dataChunkSize = static_cast<uint32_t>(inputStream->getFrameCount() * frameSize);
uint32_t dataChunkSize = static_cast<uint32_t>(inputStream->getSampleCount() * frameSize);
uint32_t riffChunkSize = 4 + (8 + formatChunkSize) + (8 + dataChunkSize);
write<uint32_t>(riffChunkSize, file);
write<uint32_t>(fourcc('W', 'A', 'V', 'E'), file);
@ -26,7 +26,7 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
uint16_t codec = 0x03; // 32-bit float
write<uint16_t>(codec, file);
write<uint16_t>(channelCount, file);
uint32_t frameRate = static_cast<uint16_t>(inputStream->getFrameRate());
uint32_t frameRate = static_cast<uint16_t>(inputStream->getSampleRate());
write<uint32_t>(frameRate, file);
uint32_t bytesPerSecond = frameRate * frameSize;
write<uint32_t>(bytesPerSecond, file);
@ -37,8 +37,8 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
// Write data chunk
write<uint32_t>(fourcc('d', 'a', 't', 'a'), file);
write<uint32_t>(dataChunkSize, file);
float sample;
while (inputStream->getNextSample(sample)) {
while (!inputStream->endOfStream()) {
float sample = inputStream->readSample();
write<float>(sample, file);
}
}

View File

@ -107,7 +107,7 @@ int main(int argc, char *argv[]) {
{
ProgressBar progressBar;
phones = detectPhones(
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
createAudioStream(inputFileName.getValue()),
dialog.getValue(),
progressBar);
}

View File

@ -3,7 +3,6 @@
#include <boost/algorithm/string.hpp>
#include "phoneExtraction.h"
#include "audioInput/SampleRateConverter.h"
#include "audioInput/ChannelDownmixer.h"
#include "platformTools.h"
#include "tools.h"
#include <format.h>
@ -33,17 +32,12 @@ using std::regex;
using std::regex_replace;
using std::chrono::duration;
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
// Downmix, if required
if (stream->getChannelCount() != 1) {
stream.reset(new ChannelDownmixer(std::move(stream)));
}
unique_ptr<AudioStream> to16kHz(unique_ptr<AudioStream> stream) {
// Downsample, if required
if (stream->getFrameRate() < 16000) {
if (stream->getSampleRate() < 16000) {
throw invalid_argument("Audio sample rate must not be below 16kHz.");
}
if (stream->getFrameRate() != 16000) {
if (stream->getSampleRate() != 16000) {
stream.reset(new SampleRateConverter(std::move(stream), 16000));
}
@ -88,7 +82,7 @@ int16_t floatSampleToInt16(float sample) {
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
}
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
void processAudioStream(AudioStream& audioStream16kHz, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
// Process entire sound file
vector<int16_t> buffer;
const int capacity = 1600; // 0.1 second capacity
@ -97,10 +91,9 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
do {
// Read to buffer
buffer.clear();
while (buffer.size() < capacity) {
while (buffer.size() < capacity && !audioStream16kHz.endOfStream()) {
// Read sample
float floatSample;
if (!audioStream16kHzMono.getNextSample(floatSample)) break;
float floatSample = audioStream16kHz.readSample();
int16_t sample = floatSampleToInt16(floatSample);
buffer.push_back(sample);
}
@ -109,7 +102,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
processBuffer(buffer);
sampleCount += buffer.size();
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount());
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHz.getSampleCount());
} while (buffer.size());
}
@ -158,7 +151,7 @@ void sphinxLogCallback(void* user_data, err_lvl_t errorLevel, const char* format
vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream));
audioStream = to16kHz(std::move(audioStream));
// Start recognition
int error = ps_start_utt(&recognizer);
@ -243,7 +236,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
if (error) throw runtime_error("Error populating alignment struct.");
// Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream));
audioStream = to16kHz(std::move(audioStream));
// Create search structure
acmod_t* acousticModel = recognizer.acmod;
@ -304,7 +297,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
}
map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
unique_ptr<AudioStream> audioStream,
boost::optional<std::string> dialog,
ProgressSink& progressSink)
{
@ -329,13 +322,13 @@ map<centiseconds, Phone> detectPhones(
// Get words
vector<string> words = dialog
? extractDialogWords(*dialog)
: recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
: recognizeWords(audioStream->clone(true), *recognizer.get(), wordRecognitionProgressSink);
// Look up words in dictionary
vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict);
// Align the word's phones with speech
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, std::move(audioStream), *recognizer.get(), alignmentProgressSink);
return result;
}
catch (...) {

View File

@ -2,7 +2,6 @@
#include <map>
#include <memory>
#include <functional>
#include "audioInput/AudioStream.h"
#include "Phone.h"
#include "centiseconds.h"
@ -10,6 +9,6 @@
#include <boost/optional/optional.hpp>
std::map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
std::unique_ptr<AudioStream> audioStream,
boost::optional<std::string> dialog,
ProgressSink& progressSink);