Refactored audio streams

* All streams are now mono (simplifies reasoning about samples)
* Streams can be cloned
* Streams can be seeked within
This commit is contained in:
Daniel Wolf 2016-03-07 21:28:31 +01:00
parent 419b0ec469
commit b78e418a8f
13 changed files with 214 additions and 200 deletions

View File

@ -100,7 +100,7 @@ set(SOURCE_FILES
src/phoneExtraction.cpp src/phoneExtraction.cpp
src/platformTools.cpp src/platformTools.cpp
src/tools.cpp src/tools.cpp
src/audioInput/ChannelDownmixer.cpp src/audioInput/AudioStream.cpp
src/audioInput/SampleRateConverter.cpp src/audioInput/SampleRateConverter.cpp
src/audioInput/WaveFileReader.cpp src/audioInput/WaveFileReader.cpp
src/audioInput/waveFileWriting.cpp src/audioInput/waveFileWriting.cpp

View File

@ -0,0 +1,5 @@
#include "AudioStream.h"
bool AudioStream::endOfStream() {
return getSampleIndex() >= getSampleCount();
}

View File

@ -1,9 +1,16 @@
#pragma once #pragma once
#include <memory>
// A mono stream of floating-point samples.
class AudioStream { class AudioStream {
public: public:
virtual int getFrameRate() = 0; virtual ~AudioStream() {}
virtual int getFrameCount() = 0; virtual std::unique_ptr<AudioStream> clone(bool reset) = 0;
virtual int getChannelCount() = 0; virtual int getSampleRate() = 0;
virtual bool getNextSample(float &sample) = 0; virtual int getSampleCount() = 0;
virtual int getSampleIndex() = 0;
virtual void seek(int sampleIndex) = 0;
bool endOfStream();
virtual float readSample() = 0;
}; };

View File

@ -1,31 +0,0 @@
#include "ChannelDownmixer.h"
ChannelDownmixer::ChannelDownmixer(std::unique_ptr<AudioStream> inputStream) :
inputStream(std::move(inputStream)),
inputChannelCount(this->inputStream->getChannelCount())
{}
int ChannelDownmixer::getFrameRate() {
return inputStream->getFrameRate();
}
int ChannelDownmixer::getFrameCount() {
return inputStream->getFrameCount();
}
int ChannelDownmixer::getChannelCount() {
return 1;
}
bool ChannelDownmixer::getNextSample(float &sample) {
float sum = 0;
for (int channelIndex = 0; channelIndex < inputChannelCount; channelIndex++) {
float currentSample;
if (!inputStream->getNextSample(currentSample)) return false;
sum += currentSample;
}
sample = sum / inputChannelCount;
return true;
}

View File

@ -1,18 +0,0 @@
#pragma once
#include "AudioStream.h"
#include <memory>
// Converts a multi-channel audio stream to mono.
class ChannelDownmixer : public AudioStream {
public:
ChannelDownmixer(std::unique_ptr<AudioStream> inputStream);
virtual int getFrameRate() override;
virtual int getFrameCount() override;
virtual int getChannelCount() override;
virtual bool getNextSample(float &sample) override;
private:
std::unique_ptr<AudioStream> inputStream;
int inputChannelCount;
};

View File

@ -1,80 +1,94 @@
#include <cmath> #include <cmath>
#include "SampleRateConverter.h" #include "SampleRateConverter.h"
#include <stdexcept>
#include <algorithm>
using std::runtime_error; using std::runtime_error;
SampleRateConverter::SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate) : SampleRateConverter::SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate) :
inputStream(std::move(inputStream)), inputStream(std::move(inputStream)),
downscalingFactor(static_cast<double>(this->inputStream->getFrameRate()) / outputFrameRate), downscalingFactor(static_cast<double>(this->inputStream->getSampleRate()) / outputFrameRate),
outputFrameRate(outputFrameRate), outputFrameRate(outputFrameRate),
outputFrameCount(std::lround(this->inputStream->getFrameCount() / downscalingFactor)), outputFrameCount(std::lround(this->inputStream->getSampleCount() / downscalingFactor)),
lastInputSample(0), lastInputSample(0),
lastInputSampleIndex(-1), lastInputSampleIndex(-1),
nextOutputSampleIndex(0) nextOutputSampleIndex(0)
{ {
if (this->inputStream->getChannelCount() != 1) { if (this->inputStream->getSampleRate() < outputFrameRate) {
throw runtime_error("Only mono input streams are supported.");
}
if (this->inputStream->getFrameRate() < outputFrameRate) {
throw runtime_error("Upsampling not supported."); throw runtime_error("Upsampling not supported.");
} }
} }
int SampleRateConverter::getFrameRate() { SampleRateConverter::SampleRateConverter(const SampleRateConverter& rhs, bool reset) :
SampleRateConverter(rhs.inputStream->clone(reset), outputFrameRate)
{
nextOutputSampleIndex = reset ? 0 : rhs.nextOutputSampleIndex;
}
std::unique_ptr<AudioStream> SampleRateConverter::clone(bool reset) {
return std::make_unique<SampleRateConverter>(*this, reset);
}
int SampleRateConverter::getSampleRate() {
return outputFrameRate; return outputFrameRate;
} }
int SampleRateConverter::getFrameCount() { int SampleRateConverter::getSampleCount() {
return outputFrameCount; return outputFrameCount;
} }
int SampleRateConverter::getChannelCount() { int SampleRateConverter::getSampleIndex() {
return 1; return nextOutputSampleIndex;
} }
bool SampleRateConverter::getNextSample(float &sample) { void SampleRateConverter::seek(int sampleIndex) {
if (nextOutputSampleIndex >= outputFrameCount) return false; if (sampleIndex < 0 || sampleIndex >= outputFrameCount) throw std::invalid_argument("sampleIndex out of range.");
double start = nextOutputSampleIndex * downscalingFactor; nextOutputSampleIndex = sampleIndex;
double end = (nextOutputSampleIndex + 1) * downscalingFactor; }
float SampleRateConverter::readSample() {
if (nextOutputSampleIndex >= outputFrameCount) throw std::out_of_range("End of stream.");
double inputStart = nextOutputSampleIndex * downscalingFactor;
double inputEnd = (nextOutputSampleIndex + 1) * downscalingFactor;
sample = mean(start, end);
nextOutputSampleIndex++; nextOutputSampleIndex++;
return true; return mean(inputStart, inputEnd);
} }
float SampleRateConverter::mean(double start, double end) { float SampleRateConverter::mean(double inputStart, double inputEnd) {
// Calculate weighted sum... // Calculate weighted sum...
double sum = 0; double sum = 0;
// ... first sample (weight <= 1) // ... first sample (weight <= 1)
int startIndex = static_cast<int>(start); int startIndex = static_cast<int>(inputStart);
sum += getInputSample(startIndex) * ((startIndex + 1) - start); sum += getInputSample(startIndex) * ((startIndex + 1) - inputStart);
// ... middle samples (weight 1 each) // ... middle samples (weight 1 each)
int endIndex = static_cast<int>(end); int endIndex = static_cast<int>(inputEnd);
for (int index = startIndex + 1; index < endIndex; index++) { for (int index = startIndex + 1; index < endIndex; index++) {
sum += getInputSample(index); sum += getInputSample(index);
} }
// ... last sample (weight < 1) // ... last sample (weight < 1)
sum += getInputSample(endIndex) * (end - endIndex); sum += getInputSample(endIndex) * (inputEnd - endIndex);
return static_cast<float>(sum / (end - start)); return static_cast<float>(sum / (inputEnd - inputStart));
} }
float SampleRateConverter::getInputSample(int sampleIndex) { float SampleRateConverter::getInputSample(int sampleIndex) {
sampleIndex = std::min(sampleIndex, inputStream->getSampleCount() - 1);
if (sampleIndex < 0) return 0.0f;
if (sampleIndex == lastInputSampleIndex) { if (sampleIndex == lastInputSampleIndex) {
return lastInputSample; return lastInputSample;
} }
if (sampleIndex == lastInputSampleIndex + 1) {
lastInputSampleIndex++;
// Read the next sample.
// If the input stream has no more samples (at the very end),
// we'll just reuse the last sample we have.
inputStream->getNextSample(lastInputSample);
return lastInputSample;
}
throw runtime_error("Can only return the last sample or the one following it."); if (sampleIndex != inputStream->getSampleIndex()) {
inputStream->seek(sampleIndex);
}
lastInputSample = inputStream->readSample();
lastInputSampleIndex = sampleIndex;
return lastInputSample;
} }

View File

@ -1,22 +1,21 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include "AudioStream.h" #include "AudioStream.h"
class SampleRateConverter : public AudioStream { class SampleRateConverter : public AudioStream {
public: public:
SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate); SampleRateConverter(std::unique_ptr<AudioStream> inputStream, int outputFrameRate);
virtual int getFrameRate() override; SampleRateConverter(const SampleRateConverter& rhs, bool reset);
virtual int getFrameCount() override; std::unique_ptr<AudioStream> clone(bool reset) override;
virtual int getChannelCount() override; int getSampleRate() override;
virtual bool getNextSample(float &sample) override; int getSampleCount() override;
int getSampleIndex() override;
void seek(int sampleIndex) override;
float readSample() override;
private: private:
// The stream we're reading from
std::unique_ptr<AudioStream> inputStream; std::unique_ptr<AudioStream> inputStream;
double downscalingFactor; // input frame rate / output frame rate
// input frame rate / output frame rate
double downscalingFactor;
int outputFrameRate; int outputFrameRate;
int outputFrameCount; int outputFrameCount;

View File

@ -24,10 +24,12 @@ enum class Codec {
Float = 0x03 Float = 0x03
}; };
WaveFileReader::WaveFileReader(boost::filesystem::path filePath) { WaveFileReader::WaveFileReader(boost::filesystem::path filePath) :
// Open file filePath(filePath),
file.exceptions(std::ifstream::failbit | std::ifstream::badbit); file(),
file.open(filePath, std::ios::binary); sampleIndex(0)
{
openFile();
// Read header // Read header
uint32_t rootChunkId = read<uint32_t>(file); uint32_t rootChunkId = read<uint32_t>(file);
@ -42,111 +44,148 @@ WaveFileReader::WaveFileReader(boost::filesystem::path filePath) {
// Read chunks until we reach the data chunk // Read chunks until we reach the data chunk
bool reachedDataChunk = false; bool reachedDataChunk = false;
int bytesPerSample = 0; bytesPerSample = 0;
do { do {
uint32_t chunkId = read<uint32_t>(file); uint32_t chunkId = read<uint32_t>(file);
int chunkSize = read<uint32_t>(file); int chunkSize = read<uint32_t>(file);
switch (chunkId) { switch (chunkId) {
case fourcc('f', 'm', 't', ' '): { case fourcc('f', 'm', 't', ' '): {
// Read relevant data // Read relevant data
Codec codec = (Codec) read<uint16_t>(file); Codec codec = (Codec)read<uint16_t>(file);
channelCount = read<uint16_t>(file); channelCount = read<uint16_t>(file);
frameRate = read<uint32_t>(file); frameRate = read<uint32_t>(file);
read<uint32_t>(file); // Bytes per second read<uint32_t>(file); // Bytes per second
int frameSize = read<uint16_t>(file); int frameSize = read<uint16_t>(file);
int bitsPerSample = read<uint16_t>(file); int bitsPerSample = read<uint16_t>(file);
// We're read 16 bytes so far. Skip the remainder. // We've read 16 bytes so far. Skip the remainder.
file.seekg(roundToEven(chunkSize) - 16, file.cur); file.seekg(roundToEven(chunkSize) - 16, file.cur);
// Determine sample format // Determine sample format
switch (codec) { switch (codec) {
case Codec::PCM: case Codec::PCM:
// Determine sample size. // Determine sample size.
// According to the WAVE standard, sample sizes that are not multiples of 8 bits // According to the WAVE standard, sample sizes that are not multiples of 8 bits
// (e.g. 12 bits) can be treated like the next-larger byte size. // (e.g. 12 bits) can be treated like the next-larger byte size.
if (bitsPerSample == 8) { if (bitsPerSample == 8) {
sampleFormat = SampleFormat::UInt8; sampleFormat = SampleFormat::UInt8;
bytesPerSample = 1; bytesPerSample = 1;
} else if (bitsPerSample <= 16) { } else if (bitsPerSample <= 16) {
sampleFormat = SampleFormat::Int16; sampleFormat = SampleFormat::Int16;
bytesPerSample = 2; bytesPerSample = 2;
} else if (bitsPerSample <= 24) { } else if (bitsPerSample <= 24) {
sampleFormat = SampleFormat::Int24; sampleFormat = SampleFormat::Int24;
bytesPerSample = 3; bytesPerSample = 3;
} else { } else {
throw runtime_error( throw runtime_error(
format("Unsupported sample format: {}-bit integer samples.", bitsPerSample)); format("Unsupported sample format: {}-bit integer samples.", bitsPerSample));
} }
if (bytesPerSample != frameSize / channelCount) { if (bytesPerSample != frameSize / channelCount) {
throw runtime_error("Unsupported sample organization."); throw runtime_error("Unsupported sample organization.");
}
break;
case Codec::Float:
if (bitsPerSample == 32) {
sampleFormat = SampleFormat::Float32;
bytesPerSample = 4;
} else {
throw runtime_error(format("Unsupported sample format: {}-bit floating-point samples.", bitsPerSample));
}
break;
default:
throw runtime_error("Unsupported sample format. Only uncompressed formats are supported.");
} }
break; break;
} case Codec::Float:
case fourcc('d', 'a', 't', 'a'): { if (bitsPerSample == 32) {
reachedDataChunk = true; sampleFormat = SampleFormat::Float32;
remainingSamples = chunkSize / bytesPerSample; bytesPerSample = 4;
frameCount = remainingSamples / channelCount; } else {
break; throw runtime_error(format("Unsupported sample format: {}-bit floating-point samples.", bitsPerSample));
} }
default: {
// Skip unknown chunk
file.seekg(roundToEven(chunkSize), file.cur);
break; break;
default:
throw runtime_error("Unsupported sample format. Only uncompressed formats are supported.");
} }
break;
}
case fourcc('d', 'a', 't', 'a'): {
reachedDataChunk = true;
dataOffset = file.tellg();
sampleCount = chunkSize / bytesPerSample;
frameCount = sampleCount / channelCount;
break;
}
default: {
// Skip unknown chunk
file.seekg(roundToEven(chunkSize), file.cur);
break;
}
} }
} while (!reachedDataChunk); } while (!reachedDataChunk);
} }
int WaveFileReader::getFrameRate() { WaveFileReader::WaveFileReader(const WaveFileReader& rhs, bool reset) :
filePath(rhs.filePath),
file(),
bytesPerSample(rhs.bytesPerSample),
sampleFormat(rhs.sampleFormat),
frameRate(rhs.frameRate),
frameCount(rhs.frameCount),
channelCount(rhs.channelCount),
sampleCount(rhs.sampleCount),
dataOffset(rhs.dataOffset),
sampleIndex(-1)
{
openFile();
seek(reset ? 0 : rhs.sampleIndex);
}
std::unique_ptr<AudioStream> WaveFileReader::clone(bool reset) {
return std::make_unique<WaveFileReader>(*this, reset);
}
void WaveFileReader::openFile() {
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
file.open(filePath, std::ios::binary);
}
int WaveFileReader::getSampleRate() {
return frameRate; return frameRate;
} }
int WaveFileReader::getFrameCount() { int WaveFileReader::getSampleCount() {
return frameCount; return frameCount;
} }
int WaveFileReader::getChannelCount() { int WaveFileReader::getSampleIndex() {
return channelCount; return sampleIndex;
} }
bool WaveFileReader::getNextSample(float &sample) { void WaveFileReader::seek(int sampleIndex) {
if (remainingSamples == 0) return false; if (sampleIndex < 0 || sampleIndex >= sampleCount) throw std::invalid_argument("sampleIndex out of range.");
remainingSamples--;
switch (sampleFormat) { file.seekg(dataOffset + sampleIndex * channelCount * bytesPerSample);
this->sampleIndex = sampleIndex;
}
float WaveFileReader::readSample() {
if (sampleIndex + channelCount > sampleCount) throw std::out_of_range("End of stream.");
sampleIndex += channelCount;
float sum = 0;
for (int channelIndex = 0; channelIndex < channelCount; channelIndex++) {
switch (sampleFormat) {
case SampleFormat::UInt8: { case SampleFormat::UInt8: {
uint8_t raw = read<uint8_t>(file); uint8_t raw = read<uint8_t>(file);
sample = toNormalizedFloat(raw, 0, UINT8_MAX); sum += toNormalizedFloat(raw, 0, UINT8_MAX);
break; break;
} }
case SampleFormat::Int16: { case SampleFormat::Int16: {
int16_t raw = read<int16_t>(file); int16_t raw = read<int16_t>(file);
sample = toNormalizedFloat(raw, INT16_MIN, INT16_MAX); sum += toNormalizedFloat(raw, INT16_MIN, INT16_MAX);
break; break;
} }
case SampleFormat::Int24: { case SampleFormat::Int24: {
int raw = read<int, 24>(file); int raw = read<int, 24>(file);
if (raw & 0x800000) raw |= 0xFF000000; // Fix two's complement if (raw & 0x800000) raw |= 0xFF000000; // Fix two's complement
sample = toNormalizedFloat(raw, INT24_MIN, INT24_MAX); sum += toNormalizedFloat(raw, INT24_MIN, INT24_MAX);
break; break;
} }
case SampleFormat::Float32: { case SampleFormat::Float32: {
sample = read<float>(file); sum += read<float>(file);
break; break;
} }
}
} }
return true;
return sum / channelCount;
} }

View File

@ -1,8 +1,5 @@
#pragma once #pragma once
#include <string>
#include <cstdint>
#include <fstream>
#include <boost/filesystem/path.hpp> #include <boost/filesystem/path.hpp>
#include <boost/filesystem/fstream.hpp> #include <boost/filesystem/fstream.hpp>
#include "AudioStream.h" #include "AudioStream.h"
@ -17,16 +14,26 @@ enum class SampleFormat {
class WaveFileReader : public AudioStream { class WaveFileReader : public AudioStream {
public: public:
WaveFileReader(boost::filesystem::path filePath); WaveFileReader(boost::filesystem::path filePath);
virtual int getFrameRate() override ; WaveFileReader(const WaveFileReader& rhs, bool reset);
virtual int getFrameCount() override; std::unique_ptr<AudioStream> clone(bool reset) override;
virtual int getChannelCount() override; int getSampleRate() override ;
virtual bool getNextSample(float &sample) override; int getSampleCount() override;
int getSampleIndex() override;
void seek(int sampleIndex) override;
float readSample() override;
private: private:
void openFile();
private:
boost::filesystem::path filePath;
boost::filesystem::ifstream file; boost::filesystem::ifstream file;
int bytesPerSample;
SampleFormat sampleFormat; SampleFormat sampleFormat;
int frameRate; int frameRate;
int frameCount; int frameCount;
int channelCount; int channelCount;
int remainingSamples; int sampleCount;
size_t dataOffset;
int sampleIndex;
}; };

View File

@ -13,9 +13,9 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
// Write RIFF chunk // Write RIFF chunk
write<uint32_t>(fourcc('R', 'I', 'F', 'F'), file); write<uint32_t>(fourcc('R', 'I', 'F', 'F'), file);
uint32_t formatChunkSize = 16; uint32_t formatChunkSize = 16;
uint16_t channelCount = static_cast<uint16_t>(inputStream->getChannelCount()); uint16_t channelCount = 1;
uint16_t frameSize = static_cast<uint16_t>(channelCount * sizeof(float)); uint16_t frameSize = static_cast<uint16_t>(channelCount * sizeof(float));
uint32_t dataChunkSize = static_cast<uint32_t>(inputStream->getFrameCount() * frameSize); uint32_t dataChunkSize = static_cast<uint32_t>(inputStream->getSampleCount() * frameSize);
uint32_t riffChunkSize = 4 + (8 + formatChunkSize) + (8 + dataChunkSize); uint32_t riffChunkSize = 4 + (8 + formatChunkSize) + (8 + dataChunkSize);
write<uint32_t>(riffChunkSize, file); write<uint32_t>(riffChunkSize, file);
write<uint32_t>(fourcc('W', 'A', 'V', 'E'), file); write<uint32_t>(fourcc('W', 'A', 'V', 'E'), file);
@ -26,7 +26,7 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
uint16_t codec = 0x03; // 32-bit float uint16_t codec = 0x03; // 32-bit float
write<uint16_t>(codec, file); write<uint16_t>(codec, file);
write<uint16_t>(channelCount, file); write<uint16_t>(channelCount, file);
uint32_t frameRate = static_cast<uint16_t>(inputStream->getFrameRate()); uint32_t frameRate = static_cast<uint16_t>(inputStream->getSampleRate());
write<uint32_t>(frameRate, file); write<uint32_t>(frameRate, file);
uint32_t bytesPerSecond = frameRate * frameSize; uint32_t bytesPerSecond = frameRate * frameSize;
write<uint32_t>(bytesPerSecond, file); write<uint32_t>(bytesPerSecond, file);
@ -37,8 +37,8 @@ void createWaveFile(std::unique_ptr<AudioStream> inputStream, std::string fileNa
// Write data chunk // Write data chunk
write<uint32_t>(fourcc('d', 'a', 't', 'a'), file); write<uint32_t>(fourcc('d', 'a', 't', 'a'), file);
write<uint32_t>(dataChunkSize, file); write<uint32_t>(dataChunkSize, file);
float sample; while (!inputStream->endOfStream()) {
while (inputStream->getNextSample(sample)) { float sample = inputStream->readSample();
write<float>(sample, file); write<float>(sample, file);
} }
} }

View File

@ -107,7 +107,7 @@ int main(int argc, char *argv[]) {
{ {
ProgressBar progressBar; ProgressBar progressBar;
phones = detectPhones( phones = detectPhones(
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, createAudioStream(inputFileName.getValue()),
dialog.getValue(), dialog.getValue(),
progressBar); progressBar);
} }

View File

@ -3,7 +3,6 @@
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
#include "phoneExtraction.h" #include "phoneExtraction.h"
#include "audioInput/SampleRateConverter.h" #include "audioInput/SampleRateConverter.h"
#include "audioInput/ChannelDownmixer.h"
#include "platformTools.h" #include "platformTools.h"
#include "tools.h" #include "tools.h"
#include <format.h> #include <format.h>
@ -33,17 +32,12 @@ using std::regex;
using std::regex_replace; using std::regex_replace;
using std::chrono::duration; using std::chrono::duration;
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) { unique_ptr<AudioStream> to16kHz(unique_ptr<AudioStream> stream) {
// Downmix, if required
if (stream->getChannelCount() != 1) {
stream.reset(new ChannelDownmixer(std::move(stream)));
}
// Downsample, if required // Downsample, if required
if (stream->getFrameRate() < 16000) { if (stream->getSampleRate() < 16000) {
throw invalid_argument("Audio sample rate must not be below 16kHz."); throw invalid_argument("Audio sample rate must not be below 16kHz.");
} }
if (stream->getFrameRate() != 16000) { if (stream->getSampleRate() != 16000) {
stream.reset(new SampleRateConverter(std::move(stream), 16000)); stream.reset(new SampleRateConverter(std::move(stream), 16000));
} }
@ -88,7 +82,7 @@ int16_t floatSampleToInt16(float sample) {
return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN); return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
} }
void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) { void processAudioStream(AudioStream& audioStream16kHz, function<void(const vector<int16_t>&)> processBuffer, ProgressSink& progressSink) {
// Process entire sound file // Process entire sound file
vector<int16_t> buffer; vector<int16_t> buffer;
const int capacity = 1600; // 0.1 second capacity const int capacity = 1600; // 0.1 second capacity
@ -97,10 +91,9 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
do { do {
// Read to buffer // Read to buffer
buffer.clear(); buffer.clear();
while (buffer.size() < capacity) { while (buffer.size() < capacity && !audioStream16kHz.endOfStream()) {
// Read sample // Read sample
float floatSample; float floatSample = audioStream16kHz.readSample();
if (!audioStream16kHzMono.getNextSample(floatSample)) break;
int16_t sample = floatSampleToInt16(floatSample); int16_t sample = floatSampleToInt16(floatSample);
buffer.push_back(sample); buffer.push_back(sample);
} }
@ -109,7 +102,7 @@ void processAudioStream(AudioStream& audioStream16kHzMono, function<void(const v
processBuffer(buffer); processBuffer(buffer);
sampleCount += buffer.size(); sampleCount += buffer.size();
progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHzMono.getFrameCount()); progressSink.reportProgress(static_cast<double>(sampleCount) / audioStream16kHz.getSampleCount());
} while (buffer.size()); } while (buffer.size());
} }
@ -158,7 +151,7 @@ void sphinxLogCallback(void* user_data, err_lvl_t errorLevel, const char* format
vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires // Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream)); audioStream = to16kHz(std::move(audioStream));
// Start recognition // Start recognition
int error = ps_start_utt(&recognizer); int error = ps_start_utt(&recognizer);
@ -243,7 +236,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
if (error) throw runtime_error("Error populating alignment struct."); if (error) throw runtime_error("Error populating alignment struct.");
// Convert audio stream to the exact format PocketSphinx requires // Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream)); audioStream = to16kHz(std::move(audioStream));
// Create search structure // Create search structure
acmod_t* acousticModel = recognizer.acmod; acmod_t* acousticModel = recognizer.acmod;
@ -304,7 +297,7 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
} }
map<centiseconds, Phone> detectPhones( map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, unique_ptr<AudioStream> audioStream,
boost::optional<std::string> dialog, boost::optional<std::string> dialog,
ProgressSink& progressSink) ProgressSink& progressSink)
{ {
@ -329,13 +322,13 @@ map<centiseconds, Phone> detectPhones(
// Get words // Get words
vector<string> words = dialog vector<string> words = dialog
? extractDialogWords(*dialog) ? extractDialogWords(*dialog)
: recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink); : recognizeWords(audioStream->clone(true), *recognizer.get(), wordRecognitionProgressSink);
// Look up words in dictionary // Look up words in dictionary
vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict); vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict);
// Align the word's phones with speech // Align the word's phones with speech
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink); map<centiseconds, Phone> result = getPhoneAlignment(wordIds, std::move(audioStream), *recognizer.get(), alignmentProgressSink);
return result; return result;
} }
catch (...) { catch (...) {

View File

@ -2,7 +2,6 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <functional>
#include "audioInput/AudioStream.h" #include "audioInput/AudioStream.h"
#include "Phone.h" #include "Phone.h"
#include "centiseconds.h" #include "centiseconds.h"
@ -10,6 +9,6 @@
#include <boost/optional/optional.hpp> #include <boost/optional/optional.hpp>
std::map<centiseconds, Phone> detectPhones( std::map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, std::unique_ptr<AudioStream> audioStream,
boost::optional<std::string> dialog, boost::optional<std::string> dialog,
ProgressSink& progressSink); ProgressSink& progressSink);