Support for external VAD program in APM-QA

There is now an 'ExternalVad' class in the AnnotationsExtractor. The
Extractor takes an extra list of these in addition to the other
VADs. The external VAD runs an external program to generate the
annotations. Annotations are loaded and saved to a compressed Numpy format.

Also made a small fix to name a mixed file in a way so that files will not
be overwritten.

Also did some minor changes to the unittests.
TBR=alessiob@webrtc.org

Bug: webrtc:7494
Change-Id: I7816b04466be16cd635ac6ceab18cd7aad5325a4
Reviewed-on: https://webrtc-review.googlesource.com/23623
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20819}
This commit is contained in:
Alex Loiko 2017-11-21 13:21:28 +01:00 committed by Commit Bot
parent c61ce0d0cd
commit 10dd7ed81a
9 changed files with 209 additions and 41 deletions

View File

@ -66,6 +66,7 @@ copy("lib") {
"quality_assessment/exceptions.py",
"quality_assessment/export.py",
"quality_assessment/export_unittest.py",
"quality_assessment/external_vad.py",
"quality_assessment/input_mixer.py",
"quality_assessment/input_signal_creator.py",
"quality_assessment/results.css",
@ -149,6 +150,7 @@ copy("lib_unit_tests") {
"quality_assessment/annotations_unittest.py",
"quality_assessment/echo_path_simulation_unittest.py",
"quality_assessment/eval_scores_unittest.py",
"quality_assessment/fake_external_vad.py",
"quality_assessment/input_mixer_unittest.py",
"quality_assessment/signal_processing_unittest.py",
"quality_assessment/simulation_unittest.py",

View File

@ -27,6 +27,7 @@ import quality_assessment.echo_path_simulation as echo_path_simulation
import quality_assessment.eval_scores as eval_scores
import quality_assessment.evaluation as evaluation
import quality_assessment.eval_scores_factory as eval_scores_factory
import quality_assessment.external_vad as external_vad
import quality_assessment.test_data_generation as test_data_generation
import quality_assessment.test_data_generation_factory as \
test_data_generation_factory
@ -113,6 +114,14 @@ def _InstanceArgumentsParser():
'copy of the clean speech input file.'),
default=False)
parser.add_argument('--external_vad_paths', nargs='+', required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'), default=[])
parser.add_argument('--external_vad_names', nargs='+', required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'), default=[])
return parser
@ -128,6 +137,12 @@ def _ValidateArguments(args, parser):
'also required')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
def main():
# TODO(alessiob): level = logging.INFO once debugged.
@ -145,7 +160,9 @@ def main():
evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
evaluator=evaluation.ApmModuleEvaluator())
evaluator=evaluation.ApmModuleEvaluator(),
external_vads=external_vad.ExternalVad.ConstructVadDict(
args.external_vad_paths, args.external_vad_names))
simulator.Run(
config_filepaths=args.config_files,
capture_input_filepaths=args.capture_input_files,

View File

@ -24,6 +24,7 @@ except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import external_vad
from . import exceptions
from . import signal_processing
@ -76,7 +77,7 @@ class AudioAnnotationsExtractor(object):
_VAD_WEBRTC_APM_PATH = os.path.join(
_VAD_WEBRTC_PATH, 'apm_vad')
def __init__(self, vad_type):
def __init__(self, vad_type, external_vads=None):
self._signal = None
self._level = None
self._level_frame_size = None
@ -92,6 +93,19 @@ class AudioAnnotationsExtractor(object):
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
if external_vads is None:
external_vads = {}
self._external_vads = external_vads
assert len(self._external_vads) == len(external_vads), (
'The external VAD names must be unique.')
for vad in external_vads.values():
if not isinstance(vad, external_vad.ExternalVad):
raise exceptions.InitializationException(
'Invalid vad type: ' + str(type(vad)))
logging.info('External VAD used for annotation: ' +
str(vad.name))
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
@ -113,9 +127,9 @@ class AudioAnnotationsExtractor(object):
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return (self._energy_vad, )
return self._energy_vad
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return (self._common_audio_vad, )
return self._common_audio_vad
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
@ -132,7 +146,7 @@ class AudioAnnotationsExtractor(object):
# Load signal.
self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
if self._signal.channels != 1:
raise NotImplementedError('multiple-channel annotations not implemented')
raise NotImplementedError('Multiple-channel annotations not implemented')
# Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 * (
@ -160,8 +174,14 @@ class AudioAnnotationsExtractor(object):
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
for extvad_name in self._external_vads:
self._external_vads[extvad_name].Run(filepath)
def Save(self, output_path):
ext_kwargs = {'extvad_conf-' + ext_vad:
self._external_vads[ext_vad].GetVadOutput()
for ext_vad in self._external_vads}
# pylint: disable=star-args
np.savez_compressed(
file=os.path.join(output_path, self._OUTPUT_FILENAME),
level=self._level,
@ -172,7 +192,8 @@ class AudioAnnotationsExtractor(object):
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms
vad_rms=self._apm_vad_rms,
**ext_kwargs
)
def _LevelEstimation(self):

View File

@ -19,6 +19,7 @@ import unittest
import numpy as np
from . import annotations
from . import external_vad
from . import input_signal_creator
from . import signal_processing
@ -29,6 +30,11 @@ class TestAnnotationsExtraction(unittest.TestCase):
_CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
_VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
_VAD_TYPE_CLASS.WEBRTC_APM)
def setUp(self):
"""Create temporary folder."""
@ -49,11 +55,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
self._tmp_path))
def testFrameSizes(self):
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
@ -62,35 +64,31 @@ class TestAnnotationsExtraction(unittest.TestCase):
e.GetVadFrameSizeMs())
def testVoiceActivityDetectors(self):
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
max_vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
for vad_type_value in range(0, max_vad_type+1):
vad_type = vad_type_class(vad_type_value)
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
if vad_type.Contains(vad_type_class.ENERGY_THRESHOLD):
# pylint: disable=unbalanced-tuple-unpacking
(vad_output, ) = e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD)
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(vad_type_class.WEBRTC_COMMON_AUDIO):
# pylint: disable=unbalanced-tuple-unpacking
(vad_output,) = e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(vad_type_class.WEBRTC_APM):
# pylint: disable=unbalanced-tuple-unpacking
(vad_probs, vad_rms) = e.GetVadOutput(vad_type_class.WEBRTC_APM)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
# pylint: disable=unpacking-non-sequence
(vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
0.95)
0.5)
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
if self._DEBUG_PLOT_VAD:
@ -111,11 +109,7 @@ class TestAnnotationsExtraction(unittest.TestCase):
plt.show()
def testSaveLoad(self):
vad_type_class = annotations.AudioAnnotationsExtractor.VadType
vad_type = (vad_type_class.ENERGY_THRESHOLD |
vad_type_class.WEBRTC_COMMON_AUDIO |
vad_type_class.WEBRTC_APM)
e = annotations.AudioAnnotationsExtractor(vad_type)
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path)
@ -123,14 +117,37 @@ class TestAnnotationsExtraction(unittest.TestCase):
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD),
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO), data['vad_output'])
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_APM)[0], data['vad_probs'])
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(vad_type_class.WEBRTC_APM)[1], data['vad_rms'])
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)
def testEmptyExternalShouldNotCrash(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
annotations.AudioAnnotationsExtractor(vad_type_value, {})
def testFakeExternalSaveLoad(self):
def FakeExternalFactory():
return external_vad.ExternalVad(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
'fake'
)
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
e = annotations.AudioAnnotationsExtractor(
vad_type_value,
{'fake': FakeExternalFactory()})
e.Extract(self._wav_file_path)
e.Save(self._tmp_path)
data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName()))
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
data['extvad_conf-fake'])

View File

@ -0,0 +1,77 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
from __future__ import division
import logging
import os
import subprocess
import shutil
import sys
import tempfile
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import signal_processing
class ExternalVad(object):
def __init__(self, path_to_binary, name):
"""Args:
path_to_binary: path to binary that accepts '-i <wav>', '-o
<float probabilities>'. There must be one float value per
10ms audio
name: a name to identify the external VAD. Used for saving
the output as extvad_output-<name>.
"""
self._path_to_binary = path_to_binary
self.name = name
assert os.path.exists(self._path_to_binary), (
self._path_to_binary)
self._vad_output = None
def Run(self, wav_file_path):
_signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
if _signal.channels != 1:
raise NotImplementedError('Multiple-channel'
' annotations not implemented')
if _signal.frame_rate != 48000:
raise NotImplementedError('Frame rates '
'other than 48000 not implemented')
tmp_path = tempfile.mkdtemp()
try:
output_file_path = os.path.join(
tmp_path, self.name + '_vad.tmp')
subprocess.call([
self._path_to_binary,
'-i', wav_file_path,
'-o', output_file_path
])
self._vad_output = np.fromfile(output_file_path, np.float32)
except Exception as e:
logging.error('Error while running the ' + self.name +
' VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def GetVadOutput(self):
assert self._vad_output is not None
return self._vad_output
@classmethod
def ConstructVadDict(cls, vad_paths, vad_names):
external_vads = {}
for path, name in zip(vad_paths, vad_names):
external_vads[name] = ExternalVad(path, name)
return external_vads

View File

@ -0,0 +1,24 @@
#!/usr/bin/python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import argparse
import numpy as np
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True)
parser.add_argument('-o', required=True)
args = parser.parse_args()
array = np.arange(100, dtype=np.float32)
array.tofile(open(args.o, 'w'))
if __name__ == '__main__':
main()

View File

@ -65,8 +65,10 @@ class ApmInputMixer(object):
# This ensures that if the internal parameters of the echo path simulator
# change, no erroneous cache hit occurs.
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
mix_filepath = os.path.join(output_path, 'mix_capture_{}.wav'.format(
echo_file_name))
capture_input_file_name, _ = os.path.splitext(
os.path.split(capture_input_filepath)[1])
mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
capture_input_file_name, echo_file_name))
# Create the mix if not done yet.
mix = None

View File

@ -41,7 +41,9 @@ class ApmModuleSimulator(object):
_PREFIX_SCORE = 'score-'
def __init__(self, test_data_generator_factory, evaluation_score_factory,
ap_wrapper, evaluator):
ap_wrapper, evaluator, external_vads=None):
if external_vads is None:
external_vads = {}
self._test_data_generator_factory = test_data_generator_factory
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
@ -49,7 +51,9 @@ class ApmModuleSimulator(object):
self._annotator = annotations.AudioAnnotationsExtractor(
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM)
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
external_vads
)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(

View File

@ -26,6 +26,7 @@ import pydub
from . import audioproc_wrapper
from . import eval_scores_factory
from . import evaluation
from . import external_vad
from . import signal_processing
from . import simulation
from . import test_data_generation_factory
@ -75,7 +76,10 @@ class TestApmModuleSimulator(unittest.TestCase):
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator)
evaluator=evaluator,
external_vads={'fake': external_vad.ExternalVad(os.path.join(
os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
)
# What to simulate.
config_files = ['apm_configs/default.json']