audio_processing VAD annotations in APM-qa.

Added the ability to extract audio_processing VAD annotations in the Quality Assessment tool.
The annotations are saved as compressed NumPy 'annotations.npz' files and
contain the VAD decisions, the speech level, speech probabilities, etc.
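As a sketch of what consumers of these files can expect (the key names match
annotations.py below; the file path is hypothetical):

import numpy as np

# Load the compressed annotations written by AudioAnnotationsExtractor.Save().
data = np.load('annotations.npz')
level = data['level']                   # np.float32 speech level
energy_vad = data['vad_energy_output']  # np.uint8 energy-threshold VAD decisions
webrtc_vad = data['vad_output']         # np.uint8 common_audio VAD decisions
apm_probs = data['vad_probs']           # np.float64 APM VAD voice probabilities
apm_rms = data['vad_rms']               # np.float64 APM VAD RMS values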

TBR=alessiob@webrtc.org

Bug: webrtc:7494
Change-Id: I0e54bb67132ae4e180f89959b8bca3ea7f259458
Reviewed-on: https://webrtc-review.googlesource.com/17840
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20581}
Alex Loiko authored on 2017-11-07 10:51:20 +01:00, committed by Commit Bot
parent 360742078b
commit 3e83b7fe8d
8 changed files with 270 additions and 57 deletions

modules/audio_processing/test/py_quality_assessment/BUILD.gn

@@ -99,6 +99,7 @@ group("unit_tests") {
testonly = true
visibility = [ ":*" ] # Only targets in this file can depend on this.
deps = [
":apm_vad",
":fake_polqa",
":lib_unit_tests",
":scripts_unit_tests",
@@ -130,6 +131,18 @@ rtc_executable("vad") {
]
}
rtc_executable("apm_vad") {
sources = [
"quality_assessment/apm_vad.cc",
]
deps = [
"../..",
"../../../..:webrtc_common",
"../../../../common_audio",
"../../../../rtc_base:rtc_base_approved",
]
}
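With this target in place, the new tool can presumably be built like the existing vad binary, e.g. ninja -C out/Default apm_vad in a typical GN/ninja checkout (the output directory name is an assumption).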
copy("lib_unit_tests") {
testonly = true
sources = [

modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py

@@ -10,7 +10,6 @@
"""
from __future__ import division
import enum
import logging
import os
import shutil
@@ -33,10 +32,30 @@ class AudioAnnotationsExtractor(object):
"""Extracts annotations from audio files.
"""
@enum.unique
class VadType(enum.Enum):
ENERGY_THRESHOLD = 0 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC = 1
# TODO(aleloi): change to enum.IntFlag when py 3.6 is available.
class VadType(object):
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
def __init__(self, value):
if (not isinstance(value, int)) or not 0 <= value <= 7:
raise exceptions.InitializationException(
'Invalid vad type: ' + str(value))
self._value = value
def Contains(self, vad_type):
return self._value | vad_type == self._value
def __str__(self):
vads = []
if self.Contains(self.ENERGY_THRESHOLD):
vads.append("energy")
if self.Contains(self.WEBRTC_COMMON_AUDIO):
vads.append("common_audio")
if self.Contains(self.WEBRTC_APM):
vads.append("apm")
return "VadType({})".format(", ".join(vads))
_OUTPUT_FILENAME = 'annotations.npz'
@@ -52,25 +71,31 @@ class AudioAnnotationsExtractor(object):
_VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
os.path.abspath(__file__)), os.pardir, os.pardir)
_VAD_WEBRTC_BIN_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
_VAD_WEBRTC_APM_PATH = os.path.join(
_VAD_WEBRTC_PATH, 'apm_vad')
def __init__(self, vad_type):
self._signal = None
self._level = None
self._level_frame_size = None
self._vad_output = None
self._common_audio_vad = None
self._energy_vad = None
self._apm_vad_probs = None
self._apm_vad_rms = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None
self._c_decay = None
self._vad_type = vad_type
if self._vad_type not in self.VadType:
raise exceptions.InitializationException(
'Invalid vad type: ' + self._vad_type)
logging.info('VAD used for annotations: ' + str(self._vad_type))
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
assert os.path.exists(self._VAD_WEBRTC_BIN_PATH), self._VAD_WEBRTC_BIN_PATH
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
self._VAD_WEBRTC_APM_PATH
@classmethod
def GetOutputFileName(cls):
@@ -86,8 +111,16 @@ class AudioAnnotationsExtractor(object):
def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self):
return self._vad_output
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return (self._energy_vad, )
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return (self._common_audio_vad, )
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
raise exceptions.InitializationException(
'Invalid vad type: ' + str(vad_type))
def GetVadFrameSize(self):
return self._vad_frame_size
@@ -115,15 +148,18 @@
self._LevelEstimation()
# Ideal VAD output, it requires clean speech with high SNR as input.
if self._vad_type == self.VadType.ENERGY_THRESHOLD:
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
# Naive VAD based on level thresholding.
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
self._vad_output = np.uint8(self._level > vad_threshold)
self._energy_vad = np.uint8(self._level > vad_threshold)
self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
elif self._vad_type == self.VadType.WEBRTC:
# WebRTC VAD.
self._RunWebRtcVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
# WebRTC common_audio/ VAD.
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
def Save(self, output_path):
np.savez_compressed(
@@ -131,9 +167,13 @@ class AudioAnnotationsExtractor(object):
level=self._level,
level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._vad_output,
vad_output=self._common_audio_vad,
vad_energy_output=self._energy_vad,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms)
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms
)
def _LevelEstimation(self):
# Read samples.
@@ -155,8 +195,8 @@
self._level[i], self._level[i - 1], self._c_attack if (
self._level[i] > self._level[i - 1]) else self._c_decay)
def _RunWebRtcVad(self, wav_file_path, sample_rate):
self._vad_output = None
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
self._common_audio_vad = None
self._vad_frame_size = None
# Create temporary output path.
@@ -167,7 +207,7 @@
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_BIN_PATH,
self._VAD_WEBRTC_COMMON_AUDIO_PATH,
'-i', wav_file_path,
'-o', output_file_path
], cwd=self._VAD_WEBRTC_PATH)
@@ -186,16 +226,45 @@
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
self._vad_output = np.zeros(num_frames, np.uint8)
self._common_audio_vad = np.zeros(num_frames, np.uint8)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._vad_output[i * 8 + j] = int(byte & 1)
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH,
'-i', wav_file_path,
'-o_probs', output_file_path_probs,
'-o_rms', output_file_path_rms
], cwd=self._VAD_WEBRTC_PATH)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
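Taken together, the new bit-mask API can be exercised as in this sketch
(illustrative only; the wav path and output directory are hypothetical):

VadType = annotations.AudioAnnotationsExtractor.VadType
# Bit mask: ENERGY_THRESHOLD | WEBRTC_APM == 1 | 4 selects two of the three VADs.
extractor = annotations.AudioAnnotationsExtractor(
    vad_type=VadType.ENERGY_THRESHOLD | VadType.WEBRTC_APM)
extractor.Extract('speech.wav')
(energy_decisions,) = extractor.GetVadOutput(VadType.ENERGY_THRESHOLD)
apm_probs, apm_rms = extractor.GetVadOutput(VadType.WEBRTC_APM)
extractor.Save('out_dir')  # Writes out_dir/annotations.npz.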

modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py

@@ -49,22 +49,49 @@ class TestAnnotationsExtraction(unittest.TestCase):
self._tmp_path))
def testFrameSizes(self):
for vad_type in annotations.AudioAnnotationsExtractor.VadType:
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testVoiceActivityDetectors(self):
for vad_type in annotations.AudioAnnotationsExtractor.VadType:
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
max_vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
for vad_type_value in range(0, max_vad_type+1):
vad_type = vad_type_class(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
vad_output = e.GetVadOutput()
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95)
if vad_type.Contains(vad_type_class.ENERGY_THRESHOLD):
# pylint: disable=unbalanced-tuple-unpacking
(vad_output, ) = e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(vad_type_class.WEBRTC_COMMON_AUDIO):
# pylint: disable=unbalanced-tuple-unpacking
(vad_output,) = e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(vad_type_class.WEBRTC_APM):
# pylint: disable=unbalanced-tuple-unpacking
(vad_probs, vad_rms) = e.GetVadOutput(vad_type_class.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
0.95)
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
@@ -84,13 +111,26 @@ class TestAnnotationsExtraction(unittest.TestCase):
plt.show()
def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor(
vad_type=annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD)
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
e = annotations.AudioAnnotationsExtractor(vad_type)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path)
data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName()))
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(e.GetVadOutput(), data['vad_output'])
self.assertEqual(np.uint8, data['vad_output'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD)[0],
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO)[0], data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_APM)[0], data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_APM)[1], data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)

modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_vad.cc

@@ -0,0 +1,93 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
#include <array>
#include <fstream>
#include <memory>
#include "common_audio/wav_file.h"
#include "modules/audio_processing/vad/voice_activity_detector.h"
#include "rtc_base/flags.h"
#include "rtc_base/logging.h"
namespace webrtc {
namespace test {
namespace {
constexpr uint8_t kAudioFrameLengthMilliseconds = 10;
constexpr int kMaxSampleRate = 48000;
constexpr size_t kMaxFrameLen =
kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
DEFINE_string(i, "", "Input wav file");
DEFINE_string(o_probs, "", "VAD probabilities output file");
DEFINE_string(o_rms, "", "VAD output file");
int main(int argc, char* argv[]) {
if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true))
return 1;
// Open wav input file and check properties.
WavReader wav_reader(FLAG_i);
if (wav_reader.num_channels() != 1) {
LOG(LS_ERROR) << "Only mono wav files supported";
return 1;
}
if (wav_reader.sample_rate() > kMaxSampleRate) {
LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate << ")";
return 1;
}
const size_t audio_frame_len = rtc::CheckedDivExact(
kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
if (audio_frame_len > kMaxFrameLen) {
LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
return 1;
}
// Create output file and write header.
std::ofstream out_probs_file(FLAG_o_probs, std::ofstream::binary);
std::ofstream out_rms_file(FLAG_o_rms, std::ofstream::binary);
// Run VAD and write decisions.
VoiceActivityDetector vad;
std::array<int16_t, kMaxFrameLen> samples;
while (true) {
// Process frame.
const auto read_samples =
wav_reader.ReadSamples(audio_frame_len, samples.data());
if (read_samples < audio_frame_len) {
break;
}
vad.ProcessChunk(samples.data(), audio_frame_len, wav_reader.sample_rate());
// Write output.
auto probs = vad.chunkwise_voice_probabilities();
auto rms = vad.chunkwise_rms();
RTC_CHECK_EQ(probs.size(), rms.size());
RTC_CHECK_EQ(sizeof(double), 8);
for (const auto& p : probs) {
out_probs_file.write(reinterpret_cast<const char*>(&p), 8);
}
for (const auto& r : rms) {
out_rms_file.write(reinterpret_cast<const char*>(&r), 8);
}
}
out_probs_file.close();
out_rms_file.close();
return 0;
}
} // namespace
} // namespace test
} // namespace webrtc
int main(int argc, char* argv[]) {
return webrtc::test::main(argc, argv);
}
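The tool streams each probability and RMS value as a raw 8-byte double, which
is why annotations.py can read the files back with np.fromfile. A minimal
standalone driver, as a sketch (binary and file paths are hypothetical):

import subprocess

import numpy as np

# Run the apm_vad binary with the flags defined above (-i, -o_probs, -o_rms).
subprocess.call(['./apm_vad',
                 '-i', 'speech.wav',      # Mono wav file, at most 48 kHz.
                 '-o_probs', 'probs.dat',
                 '-o_rms', 'rms.dat'])
probs = np.fromfile('probs.dat', np.double)  # One voice probability per 10 ms chunk.
rms = np.fromfile('rms.dat', np.double)      # One RMS value per 10 ms chunk.
assert len(probs) == len(rms)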

modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py

@@ -47,7 +47,9 @@ class ApmModuleSimulator(object):
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor(
vad_type=annotations.AudioAnnotationsExtractor.VadType.WEBRTC)
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(

modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc

@@ -44,11 +44,11 @@ int main(int argc, char* argv[]) {
LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate << ")";
return 1;
}
const size_t kAudioFrameLen = rtc::CheckedDivExact(
const size_t audio_frame_length = rtc::CheckedDivExact(
kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
if (kAudioFrameLen > kMaxFrameLen) {
if (audio_frame_length > kMaxFrameLen) {
LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
return 2;
return 1;
}
// Create output file and write header.
@@ -64,11 +64,11 @@ int main(int argc, char* argv[]) {
while (true) {
// Process frame.
const auto read_samples =
wav_reader.ReadSamples(kAudioFrameLen, samples.data());
if (read_samples < kAudioFrameLen)
wav_reader.ReadSamples(audio_frame_length, samples.data());
if (read_samples < audio_frame_length)
break;
const auto is_speech = vad->VoiceActivity(samples.data(), kAudioFrameLen,
wav_reader.sample_rate());
const auto is_speech = vad->VoiceActivity(
samples.data(), audio_frame_length, wav_reader.sample_rate());
// Write output.
buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
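For clarity, this packs eight 1-bit decisions into each byte, least significant
bit first; the parsing loop in annotations.py above undoes it. A round-trip
sketch of that packing scheme (the framing bytes around the payload are
omitted, since annotations.py handles them separately):

def PackVadDecisions(decisions):
  # Pack 1-bit VAD decisions LSB-first, eight per byte, zero-padding the last
  # byte; mirrors 'buff | (1 << next)' above.
  packed = []
  for i in range(0, len(decisions), 8):
    byte = 0
    for j, bit in enumerate(decisions[i:i + 8]):
      byte |= (bit & 1) << j
    packed.append(byte)
  return packed

def UnpackVadDecisions(packed, num_frames):
  # Inverse operation, as in annotations.py: test the LSB, then shift right.
  decisions = []
  for byte in packed:
    for _ in range(8):
      decisions.append(byte & 1)
      byte >>= 1
  return decisions[:num_frames]

assert UnpackVadDecisions(PackVadDecisions([1, 0, 1, 1]), 4) == [1, 0, 1, 1]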

modules/audio_processing/vad/voice_activity_detector.cc

@@ -17,7 +17,6 @@
namespace webrtc {
namespace {
const size_t kMaxLength = 320;
const size_t kNumChannels = 1;
const double kDefaultVoiceValue = 1.0;
@@ -40,7 +39,6 @@ void VoiceActivityDetector::ProcessChunk(const int16_t* audio,
size_t length,
int sample_rate_hz) {
RTC_DCHECK_EQ(length, sample_rate_hz / 100);
RTC_DCHECK_LE(length, kMaxLength);
// Resample to the required rate.
const int16_t* resampled_ptr = audio;
if (sample_rate_hz != kSampleRateHz) {

modules/audio_processing/vad/voice_activity_detector.h

@@ -29,9 +29,7 @@ class VoiceActivityDetector {
VoiceActivityDetector();
~VoiceActivityDetector();
// Processes each audio chunk and estimates the voice probability. The maximum
// supported sample rate is 32kHz.
// TODO(aluebs): Change |length| to size_t.
// Processes each audio chunk and estimates the voice probability.
void ProcessChunk(const int16_t* audio, size_t length, int sample_rate_hz);
// Returns a vector of voice probabilities for each chunk. It can be empty for