APM-QA clean speech annotations.

Extract and save some simple annotations for the clean speech input.
The annotations are the estimated level, the VAD decision (assuming clean speech) and the speech level.

TBR=

Bug: webrtc:7494
Change-Id: Id73358e228fac721a77fc8a61a3474a5d52bdc84
Reviewed-on: https://webrtc-review.googlesource.com/12321
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20327}
This commit is contained in:
Alessio Bazzica 2017-10-17 15:59:24 +02:00 committed by Commit Bot
parent 6592f2cfd2
commit 2bdeb226d5
8 changed files with 299 additions and 60 deletions

View File

@ -54,6 +54,7 @@ copy("lib") {
testonly = true
sources = [
"quality_assessment/__init__.py",
"quality_assessment/annotations.py",
"quality_assessment/audioproc_wrapper.py",
"quality_assessment/collect_data.py",
"quality_assessment/data_access.py",
@ -120,6 +121,7 @@ rtc_executable("fake_polqa") {
copy("lib_unit_tests") {
testonly = true
sources = [
"quality_assessment/annotations_unittest.py",
"quality_assessment/echo_path_simulation_unittest.py",
"quality_assessment/eval_scores_unittest.py",
"quality_assessment/input_mixer_unittest.py",

View File

@ -0,0 +1,119 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Extraction of annotations from audio files.
"""
from __future__ import division
import logging
import os
import sys
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import signal_processing
class AudioAnnotationsExtractor(object):
  """Extracts annotations from audio files.

  Extract() computes three per-sample annotations for a mono signal: a
  smoothed level envelope, a VAD decision obtained by thresholding the level
  and the speech level (the level multiplied by the VAD decision). Save()
  writes each of them to its own .npy file.
  """

  # Names of the files written by Save().
  _LEVEL_FILENAME = 'level.npy'
  _VAD_FILENAME = 'vad.npy'
  _SPEECH_LEVEL_FILENAME = 'speech_level.npy'

  # Level estimation params. The time constants in ms indicate the time it takes
  # for the level estimate to go down/up by 1 dB if the signal is zero.
  _LEVEL_ATTACK_MS = 5.0
  _LEVEL_DECAY_MS = 20.0
  _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)

  # VAD params. The value is the percentile of the level used as threshold.
  _VAD_THRESHOLD = 1

  def __init__(self):
    # All the fields are set by Extract().
    self._signal = None
    self._level = None
    self._vad = None
    self._speech_level = None
    self._c_attack = None
    self._c_decay = None

  @classmethod
  def GetLevelFileName(cls):
    """Returns the name of the file in which Save() writes the level."""
    return cls._LEVEL_FILENAME

  @classmethod
  def GetVadFileName(cls):
    """Returns the name of the file in which Save() writes the VAD."""
    return cls._VAD_FILENAME

  @classmethod
  def GetSpeechLevelFileName(cls):
    """Returns the name of the file in which Save() writes the speech level."""
    return cls._SPEECH_LEVEL_FILENAME

  def GetLevel(self):
    """Returns the level envelope (None if Extract() was never called)."""
    return self._level

  def GetVad(self):
    """Returns the VAD decisions (None if Extract() was never called)."""
    return self._vad

  def GetSpeechLevel(self):
    """Returns the speech level (None if Extract() was never called)."""
    return self._speech_level

  def Extract(self, filepath):
    """Computes level, VAD and speech level for the given audio file.

    Args:
      filepath: path to the audio file, loaded via
        signal_processing.SignalProcessingUtils.LoadWav().

    Raises:
      NotImplementedError: if the loaded signal is not mono.
    """
    # Load signal.
    self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
    if self._signal.channels != 1:
      raise NotImplementedError('multiple-channel annotations not implemented')

    # Smoothing params. Each coefficient is chosen so that, applied once per
    # sample, it yields a 1 dB reduction after the corresponding time constant.
    sample_duration_ms = 1000.0 / self._signal.frame_rate
    self._c_attack = 0 if self._LEVEL_ATTACK_MS == 0 else (
        self._ONE_DB_REDUCTION ** (sample_duration_ms / self._LEVEL_ATTACK_MS))
    self._c_decay = 0 if self._LEVEL_DECAY_MS == 0 else (
        self._ONE_DB_REDUCTION ** (sample_duration_ms / self._LEVEL_DECAY_MS))

    # Compute level.
    self._LevelEstimation()

    # Naive VAD based on level thresholding. It assumes ideal clean speech
    # with high SNR.
    # TODO(alessiob): Maybe replace with a VAD based on stationary-noise
    # detection.
    if self._level.size == 0:
      # A percentile cannot be computed on an empty signal.
      self._vad = np.zeros(0, dtype=np.uint8)
    else:
      vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
      self._vad = np.uint8(self._level > vad_threshold)

    # Speech level based on VAD output.
    self._speech_level = self._level * self._vad

  def Save(self, output_path):
    """Writes level, VAD and speech level as .npy files in |output_path|.

    Args:
      output_path: path to the directory in which the files are written.
    """
    np.save(os.path.join(output_path, self._LEVEL_FILENAME), self._level)
    np.save(os.path.join(output_path, self._VAD_FILENAME), self._vad)
    np.save(os.path.join(output_path, self._SPEECH_LEVEL_FILENAME),
            self._speech_level)

  def _LevelEstimation(self):
    """Estimates the level of the loaded signal and stores it."""
    # Read samples.
    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
        self._signal)
    self._level = self._SmoothEnvelope(
        samples, self._c_attack, self._c_decay)

  @staticmethod
  def _SmoothEnvelope(samples, c_attack, c_decay):
    """Returns the smoothed absolute-value envelope of |samples|.

    Args:
      samples: sequence of audio samples.
      c_attack: smoothing coefficient used when the envelope goes up.
      c_decay: smoothing coefficient used when the envelope does not go up.

    Returns:
      A floating point numpy array with one level value per sample.
    """
    samples = np.asarray(samples)
    # The smoothing is done in place, hence integer samples must be converted
    # first. Otherwise every smoothed value is truncated to an integer and
    # np.abs() overflows on the most negative value of the integer type.
    # NOTE(review): AudioSegmentToRawData() presumably returns integer PCM
    # samples - confirm.
    if not np.issubdtype(samples.dtype, np.floating):
      samples = samples.astype(np.float32)

    # Envelope.
    level = np.abs(samples)
    num_samples = len(level)
    if num_samples == 0:
      return level

    # Envelope smoothing. The first sample is smoothed against a zero level.
    level[0] = (1 - c_attack) * level[0] + c_attack * 0.0
    for i in range(1, num_samples):
      k = c_attack if level[i] > level[i - 1] else c_decay
      level[i] = (1 - k) * level[i] + k * level[i - 1]
    return level

View File

@ -0,0 +1,67 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the annotations module.
"""
import logging
import os
import shutil
import tempfile
import unittest
import numpy as np
from . import annotations
from . import input_signal_creator
from . import signal_processing
class TestAnnotationsExtraction(unittest.TestCase):
  """Unit tests for the annotations module.
  """

  # When False, tearDown() keeps the temporary folder and logs its path.
  _CLEAN_TMP_OUTPUT = False

  def setUp(self):
    """Create temporary folder and write a pure tone wav file into it."""
    self._tmp_path = tempfile.mkdtemp()
    self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
    # NOTE(review): the parameters are presumably frequency (Hz) and duration
    # (ms) of the tone - confirm against input_signal_creator.
    pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
        'pure_tone', [440, 1000])
    signal_processing.SignalProcessingUtils.SaveWav(
        self._wav_file_path, pure_tone)

  def tearDown(self):
    """Recursively delete temporary folder."""
    if self._CLEAN_TMP_OUTPUT:
      shutil.rmtree(self._tmp_path)
    else:
      logging.warning(self.id() + ' did not clean the temporary path ' + (
          self._tmp_path))

  def testExtraction(self):
    """Checks that the VAD is active on at least 95% of a pure tone."""
    e = annotations.AudioAnnotationsExtractor()
    e.Extract(self._wav_file_path)
    vad = e.GetVad()
    # Use a unittest assertion rather than a bare assert: the latter is
    # stripped when Python runs with -O and bypasses unittest reporting.
    self.assertGreater(len(vad), 0)
    self.assertGreaterEqual(float(np.sum(vad)) / len(vad), 0.95)

  def testSaveLoad(self):
    """Checks that the saved annotations equal those held in memory."""
    e = annotations.AudioAnnotationsExtractor()
    e.Extract(self._wav_file_path)
    e.Save(self._tmp_path)

    np.testing.assert_array_equal(
        e.GetLevel(),
        np.load(os.path.join(self._tmp_path, e.GetLevelFileName())))
    np.testing.assert_array_equal(
        e.GetVad(),
        np.load(os.path.join(self._tmp_path, e.GetVadFileName())))
    np.testing.assert_array_equal(
        e.GetSpeechLevel(),
        np.load(os.path.join(self._tmp_path, e.GetSpeechLevelFileName())))

View File

@ -165,6 +165,8 @@ class SignalProcessingUtils(object):
@classmethod
def Fft(cls, signal, normalize=True):
if signal.channels != 1:
raise NotImplementedError('multiple-channel FFT not implemented')
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
if normalize:
x /= max(abs(np.max(x)), 1.0)
@ -188,7 +190,7 @@ class SignalProcessingUtils(object):
True if hard clipping is detected, False otherwise.
"""
if signal.channels != 1:
raise NotImplementedError('mutliple-channel clipping not implemented')
raise NotImplementedError('multiple-channel clipping not implemented')
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')

View File

@ -12,12 +12,15 @@
import logging
import os
from . import annotations
from . import data_access
from . import echo_path_simulation
from . import echo_path_simulation_factory
from . import eval_scores
from . import exceptions
from . import input_mixer
from . import input_signal_creator
from . import signal_processing
from . import test_data_generation
@ -43,6 +46,7 @@ class ApmModuleSimulator(object):
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor()
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(
@ -52,6 +56,7 @@ class ApmModuleSimulator(object):
# Properties for each run.
self._base_output_path = None
self._output_cache_path = None
self._test_data_generators = None
self._evaluation_score_workers = None
self._config_filepaths = None
@ -116,6 +121,9 @@ class ApmModuleSimulator(object):
'invalid echo path simulator')
self._base_output_path = os.path.abspath(output_dir)
# Output path used to cache the data shared across simulations.
self._output_cache_path = os.path.join(self._base_output_path, '_cache')
# Instance test data generators.
self._test_data_generators = [self._test_data_generator_factory.GetInstance(
test_data_generators_class=(
@ -164,14 +172,28 @@ class ApmModuleSimulator(object):
# Try different capture-render pairs.
for capture_input_name in self._capture_input_filepaths:
# Output path for the capture signal annotations.
capture_annotations_cache_path = os.path.join(
self._output_cache_path,
self._PREFIX_CAPTURE + capture_input_name)
data_access.MakeDirectory(capture_annotations_cache_path)
# Capture.
capture_input_filepath = self._capture_input_filepaths[
capture_input_name]
if not os.path.exists(capture_input_filepath):
# If the input signal file does not exist, try to create using the
# available input signal creators.
self._CreateInputSignal(capture_input_filepath)
assert os.path.exists(capture_input_filepath)
self._ExtractCaptureAnnotations(
capture_input_filepath, capture_annotations_cache_path)
# Render and simulated echo path (optional).
render_input_filepath = None if without_render_input else (
self._render_input_filepaths[capture_input_name])
render_input_name = '(none)' if without_render_input else (
self._ExtractFileName(render_input_filepath))
# Instance echo path simulator (if needed).
echo_path_simulator = (
echo_path_simulation_factory.EchoPathSimulatorFactory.GetInstance(
self._echo_path_simulator_class, render_input_filepath))
@ -184,10 +206,8 @@ class ApmModuleSimulator(object):
test_data_generators.NAME, echo_path_simulator.NAME)
# Output path for the generated test data.
# The path is used to cache the signals shared across simulations.
test_data_cache_path = os.path.join(
self._base_output_path, '_cache',
self._PREFIX_CAPTURE + capture_input_name,
capture_annotations_cache_path,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(test_data_cache_path)
logging.debug('test data cache path: <%s>', test_data_cache_path)
@ -216,6 +236,38 @@ class ApmModuleSimulator(object):
echo_test_data_cache_path, output_path,
config_filepath, echo_path_simulator)
@staticmethod
def _CreateInputSignal(input_signal_filepath):
  """Generates an input signal file that does not exist yet.

  The file name without extension is split on '-': the first part names the
  input signal creator and the second part, split on '_', holds the creator
  parameters. The signal returned by the creator is written in
  |input_signal_filepath| and its metadata is saved for the same path.

  Args:
    input_signal_filepath: Path to the input signal audio file to write.

  Raises:
    InputSignalCreatorException: if the file name has no '-' separator.
  """
  basename = os.path.basename(input_signal_filepath)
  name, _ = os.path.splitext(basename)
  parts = name.split('-')
  if len(parts) < 2:
    raise exceptions.InputSignalCreatorException(
        'Cannot parse input signal file name')

  creator_name = parts[0]
  creator_params = parts[1].split('_')
  signal, metadata = input_signal_creator.InputSignalCreator.Create(
      creator_name, creator_params)

  signal_processing.SignalProcessingUtils.SaveWav(
      input_signal_filepath, signal)
  data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
def _ExtractCaptureAnnotations(self, input_filepath, output_path):
  """Extracts the annotations of a capture signal and saves them.

  Args:
    input_filepath: path passed to the Extract() method of the annotator.
    output_path: path passed to the Save() method of the annotator.
  """
  annotator = self._annotator
  annotator.Extract(input_filepath)
  annotator.Save(output_path)
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path, config_filepath,

View File

@ -102,6 +102,39 @@ class TestApmModuleSimulator(unittest.TestCase):
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
def testInputSignalCreation(self):
  """Checks that missing capture input files exist after a simulation."""
  # Build the simulator dependencies one by one, then the simulator.
  generator_factory = test_data_generation_factory.TestDataGeneratorFactory(
      aechen_ir_database_path='',
      noise_tracks_path='')
  fake_polqa_path = os.path.join(os.path.dirname(__file__), 'fake_polqa')
  score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
      polqa_tool_bin_path=fake_polqa_path)
  apm_wrapper = audioproc_wrapper.AudioProcWrapper(
      audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
  simulator = simulation.ApmModuleSimulator(
      test_data_generator_factory=generator_factory,
      evaluation_score_factory=score_factory,
      ap_wrapper=apm_wrapper,
      evaluator=evaluation.ApmModuleEvaluator())

  # Capture input files that must not exist before the simulation runs.
  input_files = [
      os.path.join(self._tmp_path, file_name)
      for file_name in ('pure_tone-440_1000.wav', 'pure_tone-1000_500.wav')
  ]
  found_before = any(os.path.exists(path) for path in input_files)
  self.assertFalse(found_before)

  # Running the simulation is expected to create the input files.
  simulator.Run(
      config_filepaths=['apm_configs/default.json'],
      capture_input_filepaths=input_files,
      test_data_generator_names=['identity'],
      eval_score_names=['audio_level_peak'],
      output_dir=self._output_path)
  found_after = all(os.path.exists(path) for path in input_files)
  self.assertTrue(found_after)
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
@ -143,3 +176,21 @@ class TestApmModuleSimulator(unittest.TestCase):
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))

View File

@ -33,7 +33,6 @@ except ImportError:
from . import data_access
from . import exceptions
from . import input_signal_creator
from . import signal_processing
@ -110,12 +109,6 @@ class TestDataGenerator(object):
base_output_path: base path where output is written.
"""
self.Clear()
# If the input signal file does not exist, try to create using the
# available input signal creators.
if not os.path.exists(input_signal_filepath):
self._CreateInputSignal(input_signal_filepath)
self._Generate(
input_signal_filepath, test_data_cache_path, base_output_path)
@ -126,34 +119,6 @@ class TestDataGenerator(object):
self._apm_output_paths = {}
self._reference_signal_filepaths = {}
@classmethod
def _CreateInputSignal(cls, input_signal_filepath):
  """Writes an input signal file derived from its own file name.

  The extension is dropped from the file name and the remainder is split on
  '-'. The first field selects the input signal creator; the second field is
  split on '_' to obtain the creator parameters. The created signal is written
  in |input_signal_filepath| together with its metadata.

  Args:
    input_signal_filepath: Path to the input signal audio file to write.

  Raises:
    InputSignalCreatorException: if the file name has fewer than two fields.
  """
  stem = os.path.splitext(os.path.basename(input_signal_filepath))[0]
  fields = stem.split('-')
  if len(fields) < 2:
    raise exceptions.InputSignalCreatorException(
        'Cannot parse input signal file name')

  creator_params = fields[1].split('_')
  signal, metadata = input_signal_creator.InputSignalCreator.Create(
      fields[0], creator_params)

  signal_processing.SignalProcessingUtils.SaveWav(
      input_signal_filepath, signal)
  data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Abstract method to be implemented in each concrete class.

View File

@ -52,25 +52,6 @@ class TestTestDataGenerators(unittest.TestCase):
shutil.rmtree(self._test_data_cache_path)
shutil.rmtree(self._fake_air_db_path)
def testInputSignalCreation(self):
  """Checks that Generate() creates a missing input signal file."""
  # Init.
  identity_generator = test_data_generation.IdentityTestDataGenerator('tmp')
  signal_path = os.path.join(
      self._test_data_cache_path, 'pure_tone-440_1000.wav')

  # The file must not exist before the generator runs and must exist after.
  self.assertFalse(os.path.exists(signal_path))
  identity_generator.Generate(
      input_signal_filepath=signal_path,
      test_data_cache_path=self._test_data_cache_path,
      base_output_path=self._base_output_path)
  self.assertTrue(os.path.exists(signal_path))

  # Check input signal properties.
  # NOTE(review): len() of the loaded signal is presumably its duration in
  # ms, matching the second parameter in the file name - confirm.
  loaded_signal = signal_processing.SignalProcessingUtils.LoadWav(signal_path)
  self.assertEqual(1000, len(loaded_signal))
def testTestDataGenerators(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))