diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn index 38863ce251..383b107b7b 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -37,9 +37,11 @@ copy("lib") { "quality_assessment/audioproc_wrapper.py", "quality_assessment/data_access.py", "quality_assessment/eval_scores.py", + "quality_assessment/eval_scores_factory.py", "quality_assessment/eval_scores_unittest.py", "quality_assessment/evaluation.py", "quality_assessment/noise_generation.py", + "quality_assessment/noise_generation_factory.py", "quality_assessment/noise_generation_unittest.py", "quality_assessment/signal_processing.py", "quality_assessment/signal_processing_unittest.py", diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py index ca7e2c3e93..7ba0028a8c 100755 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py @@ -63,6 +63,12 @@ def _instance_arguments_parser(): 'are saved'), default='output') + parser.add_argument('--polqa_path', required=True, + help='path to the POLQA tool') + + parser.add_argument('--air_db_path', required=True, + help='path to the Aechen IR database') + return parser @@ -73,7 +79,9 @@ def main(): parser = _instance_arguments_parser() args = parser.parse_args() - simulator = simulation.ApmModuleSimulator() + simulator = simulation.ApmModuleSimulator( + aechen_ir_database_path=args.air_db_path, + polqa_tool_path=args.polqa_path) simulator.run( config_filepaths=args.config_files, input_filepaths=args.input_files, diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh index 84330d0d16..91cdce6750 100755 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh @@ -7,6 +7,30 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +# Path to the POLQA tool. +if [ -z ${POLQA_PATH} ]; then # Check if defined. + # Default location. + export POLQA_PATH='/var/opt/PolqaOem64' +fi +if [ -d "${POLQA_PATH}" ]; then + echo "POLQA found in ${POLQA_PATH}" +else + echo "POLQA not found in ${POLQA_PATH}" + exit 1 +fi + +# Path to the Aechen IR database. +if [ -z ${AECHEN_IR_DATABASE_PATH} ]; then # Check if defined. + # Default location. + export AECHEN_IR_DATABASE_PATH='/var/opt/AIR_1_4' +fi +if [ -d "${AECHEN_IR_DATABASE_PATH}" ]; then + echo "AIR database found in ${AECHEN_IR_DATABASE_PATH}" +else + echo "AIR database not found in ${AECHEN_IR_DATABASE_PATH}" + exit 1 +fi + # Customize probing signals, noise sources and scores if needed. PROBING_SIGNALS=(probing_signals/*.wav) NOISE_SOURCES=( \ @@ -44,7 +68,9 @@ for probing_signal_filepath in "${PROBING_SIGNALS[@]}" ; do echo "Starting ${probing_signal_name} ${noise_source_name} "` `"(see ${LOG_FILE})" ./apm_quality_assessment.py \ - -i ${probing_signal_filepath}\ + --polqa_path ${POLQA_PATH}\ + --air_db_path ${AECHEN_IR_DATABASE_PATH}\ + -i ${probing_signal_filepath} \ -o ${OUTPUT_PATH} \ -n ${noise_source_name} \ -c "${APM_CONFIGS[@]}" \ diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py index ba46c579d0..843a84f14a 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py @@ -6,6 +6,9 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Class implementing a wrapper for audioproc_f. +""" + import cProfile import logging import os @@ -13,7 +16,10 @@ import subprocess from .data_access import AudioProcConfigFile + class AudioProcWrapper(object): + """Wrapper for audioproc_f. + """ OUTPUT_FILENAME = 'output.wav' _AUDIOPROC_F_BIN_PATH = os.path.abspath('audioproc_f') @@ -31,6 +37,14 @@ class AudioProcWrapper(object): return self._output_signal_filepath def run(self, config_filepath, input_filepath, output_path): + """Run audioproc_f. + + Args: + config_filepath: path to the configuration file specifing the arguments + for audioproc_f. + input_filepath: path to the audio track input file. + output_path: path of the audio track output file. + """ # Init. self._input_signal_filepath = input_filepath self._output_signal_filepath = os.path.join( diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py index acbe03f9af..aeee74746c 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py @@ -6,12 +6,18 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Data access utility functions and classes. +""" + import json import os + def make_directory(path): - """ - Recursively make a directory without rising exceptions if it already exists. + """Recursively make a directory without rising exceptions if already existing. + + Args: + path: path to the directory to be created. """ if os.path.exists(path): return @@ -19,8 +25,7 @@ def make_directory(path): class Metadata(object): - """ - Data access class to save and load metadata. + """Data access class to save and load metadata. """ def __init__(self): @@ -30,8 +35,12 @@ class Metadata(object): @classmethod def load_audio_in_ref_paths(cls, metadata_path): - """ - Metadata loader for input and reference audio track paths. + """Metadata loader for input and reference audio track paths. + + Args: + metadata_path: path to the directory containing the metadata file. + + Returns: pair of metadata file paths for the input and output audio tracks. """ metadata_filepath = os.path.join(metadata_path, cls._AUDIO_IN_REF_FILENAME) with open(metadata_filepath) as f: @@ -42,8 +51,7 @@ class Metadata(object): @classmethod def save_audio_in_ref_paths(cls, output_path, audio_in_filepath, audio_ref_filepath): - """ - Metadata saver for input and reference audio track paths. + """Metadata saver for input and reference audio track paths. """ output_filepath = os.path.join(output_path, cls._AUDIO_IN_REF_FILENAME) with open(output_filepath, 'w') as f: @@ -51,9 +59,9 @@ class Metadata(object): class AudioProcConfigFile(object): - """ - Data access class to save and load audioproc_f argument lists to control - the APM flags. + """Data access to load/save audioproc_f argument lists. + + The arguments stored in the config files are used to control the APM flags. """ def __init__(self): @@ -71,8 +79,7 @@ class AudioProcConfigFile(object): class ScoreFile(object): - """ - Data access class to save and load float scalar scores. + """Data access class to save and load float scalar scores. """ def __init__(self): diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py index 6787850f23..202a3b73a1 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py @@ -6,11 +6,15 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Evaluation score abstract class and implementations. +""" + import logging import os -from .data_access import ScoreFile -from .signal_processing import SignalProcessingUtils +from . import data_access +from . import signal_processing + class EvaluationScore(object): @@ -27,10 +31,12 @@ class EvaluationScore(object): @classmethod def register_class(cls, class_to_register): - """ + """Register an EvaluationScore implementation. + Decorator to automatically register the classes that extend EvaluationScore. """ cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register + return class_to_register @property def output_filepath(self): @@ -41,28 +47,28 @@ class EvaluationScore(object): return self._score def set_reference_signal_filepath(self, filepath): - """ - Set the path to the audio track used as reference signal. + """ Set the path to the audio track used as reference signal. """ self._reference_signal_filepath = filepath def set_tested_signal_filepath(self, filepath): - """ - Set the path to the audio track used as test signal. + """ Set the path to the audio track used as test signal. """ self._tested_signal_filepath = filepath def _load_reference_signal(self): assert self._reference_signal_filepath is not None - self._reference_signal = SignalProcessingUtils.load_wav( + self._reference_signal = signal_processing.SignalProcessingUtils.load_wav( self._reference_signal_filepath) def _load_tested_signal(self): assert self._tested_signal_filepath is not None - self._tested_signal = SignalProcessingUtils.load_wav( + self._tested_signal = signal_processing.SignalProcessingUtils.load_wav( self._tested_signal_filepath) def run(self, output_path): + """Extracts the score for the set input-reference pair. + """ self._output_filepath = os.path.join(output_path, 'score-{}.txt'.format( self.NAME)) try: @@ -79,16 +85,15 @@ class EvaluationScore(object): raise NotImplementedError() def _load_score(self): - return ScoreFile.load(self._output_filepath) + return data_access.ScoreFile.load(self._output_filepath) def _save_score(self): - return ScoreFile.save(self._output_filepath, self._score) + return data_access.ScoreFile.save(self._output_filepath, self._score) @EvaluationScore.register_class class AudioLevelScore(EvaluationScore): - """ - Compute the difference between the average audio level of the tested and + """Compute the difference between the average audio level of the tested and the reference signals. Unit: dB @@ -109,8 +114,7 @@ class AudioLevelScore(EvaluationScore): @EvaluationScore.register_class class PolqaScore(EvaluationScore): - """ - Compute the POLQA score. + """Compute the POLQA score. Unit: MOS Ideal: 4.5 @@ -119,8 +123,9 @@ class PolqaScore(EvaluationScore): NAME = 'polqa' - def __init__(self): + def __init__(self, polqa_tool_path): EvaluationScore.__init__(self) + self._polqa_tool_path = polqa_tool_path def _run(self, output_path): # TODO(alessio): implement. diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py new file mode 100644 index 0000000000..b33d3f96b2 --- /dev/null +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py @@ -0,0 +1,37 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +"""EvaluationScore factory class. +""" + +import logging + +from . import eval_scores + + +class EvaluationScoreWorkerFactory(object): + """Factory class used to instantiate evaluation score workers. + + It can be used by instanciating a factory, passing parameters to the + constructor. These parameters are used to instantiate evaluation score + workers. + """ + + def __init__(self, polqa_tool_path): + self._polqa_tool_path = polqa_tool_path + + def GetInstance(self, evaluation_score_class): + """Creates an EvaluationScore instance given a class object. + """ + logging.debug( + 'factory producing a %s evaluation score', evaluation_score_class) + if evaluation_score_class == eval_scores.PolqaScore: + return eval_scores.PolqaScore(self._polqa_tool_path) + else: + # By default, no arguments in the constructor. + return evaluation_score_class() diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py index 7ac7dd3615..1abe786a50 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py @@ -6,10 +6,14 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the evaluation scores. +""" + import unittest from . import eval_scores + class TestEvalScores(unittest.TestCase): def test_registered_classes(self): diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py index 78446a1e18..016690a2c4 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py @@ -6,8 +6,12 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Evaluator of the APM module. +""" + import logging + class ApmModuleEvaluator(object): def __init__(self): @@ -16,6 +20,10 @@ class ApmModuleEvaluator(object): @classmethod def run(cls, evaluation_score_workers, apm_output_filepath, reference_input_filepath, output_path): + """Runs the evaluation. + + Iterates over the given evaluation score workers. + """ # Init. scores = {} diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation.py index ef1eec99aa..a6a8a0653c 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation.py @@ -7,25 +7,32 @@ # be found in the AUTHORS file in the root of the source tree. """Noise generators producing pairs of signals intended to be used to test the - APM module. Each pair consists of a noisy and a reference signal. The former - is used as input for APM, and it is generated by adding noise to a signal. - The reference is the expected APM output when using the generated input. + APM module. Each pair consists of a noisy and a reference signal. The former + is used as input for APM, and it is generated by adding noise to a signal. + The reference is the expected APM output when using the generated input. - Throughout this file, the following naming convention is used: - - input signal: the clean signal (e.g., speech), - - noise signal: the noise to be summed up to the input signal (e.g., white - noise, Gaussian noise), - - noisy signal: input + noise. - The noise signal may or may not be a function of the clean signal. For - instance, white noise is independently generated, whereas reverberation is - obtained by convolving the input signal with an impulse response. + Throughout this file, the following naming convention is used: + - input signal: the clean signal (e.g., speech), + - noise signal: the noise to be summed up to the input signal (e.g., white + noise, Gaussian noise), + - noisy signal: input + noise. + The noise signal may or may not be a function of the clean signal. For + instance, white noise is independently generated, whereas reverberation is + obtained by convolving the input signal with an impulse response. """ import logging import os +import sys + +try: + import scipy.io +except ImportError: + logging.critical('Cannot import the third-party Python package scipy') + sys.exit(1) from . import data_access -from .signal_processing import SignalProcessingUtils +from . import signal_processing class NoiseGenerator(object): """Abstract class responsible for the generation of noisy signals. @@ -52,10 +59,12 @@ class NoiseGenerator(object): @classmethod def register_class(cls, class_to_register): - """ Decorator to automatically register the classes that extend - NoiseGenerator. + """Register an NoiseGenerator implementation. + + Decorator to automatically register the classes that extend NoiseGenerator. """ cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register + return class_to_register @property def config_names(self): @@ -76,8 +85,9 @@ class NoiseGenerator(object): def generate( self, input_signal_filepath, input_noise_cache_path, base_output_path): """Generate a set of noisy input and reference audiotrack file pairs. - This method initializes an empty set of pairs and calls the _generate() - method implemented in a concrete class. + + This method initializes an empty set of pairs and calls the _generate() + method implemented in a concrete class. """ self.clear() return self._generate( @@ -96,7 +106,7 @@ class NoiseGenerator(object): def _add_noise_snr_pairs(self, base_output_path, noisy_mix_filepaths, snr_value_pairs): - """ Add noisy-reference signal pairs. + """Adds noisy-reference signal pairs. Args: base_output_path: noisy tracks base output path. @@ -142,9 +152,9 @@ class NoiseGenerator(object): # Identity generator. @NoiseGenerator.register_class class IdentityGenerator(NoiseGenerator): - """ - Generator that adds no noise, therefore both the noisy and the reference - signals are the input signal. + """Generator that adds no noise. + + Both the noisy and the reference signals are the input signal. """ NAME = 'identity' @@ -165,8 +175,7 @@ class IdentityGenerator(NoiseGenerator): @NoiseGenerator.register_class class WhiteNoiseGenerator(NoiseGenerator): - """ - Additive white noise generator. + """Additive white noise generator. """ NAME = 'white' @@ -189,12 +198,16 @@ class WhiteNoiseGenerator(NoiseGenerator): def _generate( self, input_signal_filepath, input_noise_cache_path, base_output_path): # Load the input signal. - input_signal = SignalProcessingUtils.load_wav(input_signal_filepath) - input_signal = SignalProcessingUtils.normalize(input_signal) + input_signal = signal_processing.SignalProcessingUtils.load_wav( + input_signal_filepath) + input_signal = signal_processing.SignalProcessingUtils.normalize( + input_signal) # Create the noise track. - noise_signal = SignalProcessingUtils.generate_white_noise(input_signal) - noise_signal = SignalProcessingUtils.normalize(noise_signal) + noise_signal = signal_processing.SignalProcessingUtils.generate_white_noise( + input_signal) + noise_signal = signal_processing.SignalProcessingUtils.normalize( + noise_signal) # Create the noisy mixes (once for each unique SNR value). noisy_mix_filepaths = {} @@ -207,11 +220,12 @@ class WhiteNoiseGenerator(NoiseGenerator): # Create and save if not done. if not os.path.exists(noisy_signal_filepath): # Create noisy signal. - noisy_signal = SignalProcessingUtils.mix_signals( + noisy_signal = signal_processing.SignalProcessingUtils.mix_signals( input_signal, noise_signal, snr) # Save. - SignalProcessingUtils.save_wav(noisy_signal_filepath, noisy_signal) + signal_processing.SignalProcessingUtils.save_wav( + noisy_signal_filepath, noisy_signal) # Add file to the collection of mixes. noisy_mix_filepaths[snr] = noisy_signal_filepath @@ -230,8 +244,7 @@ class WhiteNoiseGenerator(NoiseGenerator): # TODO(alessiob): remove comment when class implemented. # @NoiseGenerator.register_class class NarrowBandNoiseGenerator(NoiseGenerator): - """ - Additive narrow-band noise generator. + """Additive narrow-band noise generator. """ NAME = 'narrow_band' @@ -247,8 +260,7 @@ class NarrowBandNoiseGenerator(NoiseGenerator): @NoiseGenerator.register_class class EnvironmentalNoiseGenerator(NoiseGenerator): - """ - Additive environmental noise generator. + """Additive environmental noise generator. """ NAME = 'environmental' @@ -258,6 +270,7 @@ class EnvironmentalNoiseGenerator(NoiseGenerator): _NOISE_TRACKS_PATH = os.path.join(os.getcwd(), 'noise_tracks') # TODO(alessiob): allow the user to have custom noise tracks. + # TODO(alessiob): exploit NoiseGeneratorFactory.GetInstance(). _NOISE_TRACKS = [ 'city.wav' ] @@ -277,12 +290,26 @@ class EnvironmentalNoiseGenerator(NoiseGenerator): def _generate( self, input_signal_filepath, input_noise_cache_path, base_output_path): + """Generate environmental noise. + + For each noise track and pair of SNR values, the following 2 audio tracks + are created: the noisy signal and the reference signal. The former is + obtained by mixing the (clean) input signal to the corresponding noise + track enforcing the target SNR. + + Args: + input_signal_filepath: (clean) input signal file path. + input_noise_cache_path: path for the cached noise track files. + base_output_path: base output path. + """ # Init. snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair]) # Load the input signal. - input_signal = SignalProcessingUtils.load_wav(input_signal_filepath) - input_signal = SignalProcessingUtils.normalize(input_signal) + input_signal = signal_processing.SignalProcessingUtils.load_wav( + input_signal_filepath) + input_signal = signal_processing.SignalProcessingUtils.normalize( + input_signal) noisy_mix_filepaths = {} for noise_track_filename in self._NOISE_TRACKS: @@ -294,8 +321,10 @@ class EnvironmentalNoiseGenerator(NoiseGenerator): logging.error('cannot find the <%s> noise track', noise_track_filename) continue - noise_signal = SignalProcessingUtils.load_wav(noise_track_filepath) - noise_signal = SignalProcessingUtils.normalize(noise_signal) + noise_signal = signal_processing.SignalProcessingUtils.load_wav( + noise_track_filepath) + noise_signal = signal_processing.SignalProcessingUtils.normalize( + noise_signal) # Create the noisy mixes (once for each unique SNR value). noisy_mix_filepaths[noise_track_name] = {} @@ -307,11 +336,12 @@ class EnvironmentalNoiseGenerator(NoiseGenerator): # Create and save if not done. if not os.path.exists(noisy_signal_filepath): # Create noisy signal. - noisy_signal = SignalProcessingUtils.mix_signals( + noisy_signal = signal_processing.SignalProcessingUtils.mix_signals( input_signal, noise_signal, snr) # Save. - SignalProcessingUtils.save_wav(noisy_signal_filepath, noisy_signal) + signal_processing.SignalProcessingUtils.save_wav( + noisy_signal_filepath, noisy_signal) # Add file to the collection of mixes. noisy_mix_filepaths[noise_track_name][snr] = noisy_signal_filepath @@ -321,19 +351,128 @@ class EnvironmentalNoiseGenerator(NoiseGenerator): base_output_path, noisy_mix_filepaths, self._SNR_VALUE_PAIRS) -# TODO(alessiob): remove comment when class implemented. -# @NoiseGenerator.register_class +@NoiseGenerator.register_class class EchoNoiseGenerator(NoiseGenerator): - """ - Echo noise generator. + """Echo noise generator. """ NAME = 'echo' - def __init__(self): + _IMPULSE_RESPONSES = { + 'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo. + 'booth': 'air_binaural_booth_0_0_1.mat', # Short echo. + } + _MAX_IMPULSE_RESPONSE_LENGTH = None + + # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs. + # The reference (second value of each pair) always has a lower amount of noise + # - i.e., the SNR is 5 dB higher. + _SNR_VALUE_PAIRS = [ + [3, 8], # Smallest noise. + [-3, 2], # Largest noise. + ] + + _NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav' + _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav' + + def __init__(self, aechen_ir_database_path): NoiseGenerator.__init__(self) + self._aechen_ir_database_path = aechen_ir_database_path def _generate( self, input_signal_filepath, input_noise_cache_path, base_output_path): - # TODO(alessiob): implement. - pass + """Generates echo noise. + + For each impulse response, one noise track is created. For each impulse + response and pair of SNR values, the following 2 audio tracks are + created: the noisy signal and the reference signal. The former is + obtained by mixing the (clean) input signal to the corresponding noise + track enforcing the target SNR. + + Args: + input_signal_filepath: (clean) input signal file path. + input_noise_cache_path: path for the cached noise track files. + base_output_path: base output path. + """ + # Init. + snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair]) + + # Load the input signal. + input_signal = signal_processing.SignalProcessingUtils.load_wav( + input_signal_filepath) + + noisy_mix_filepaths = {} + for impulse_response_name in self._IMPULSE_RESPONSES: + noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format( + impulse_response_name) + noise_track_filepath = os.path.join( + input_noise_cache_path, noise_track_filename) + noise_signal = None + try: + # Load noise track. + noise_signal = signal_processing.SignalProcessingUtils.load_wav( + noise_track_filepath) + except IOError: # File not found. + # Generate noise track by applying the impulse response. + impulse_response_filepath = os.path.join( + self._aechen_ir_database_path, + self._IMPULSE_RESPONSES[impulse_response_name]) + noise_signal = self._generate_noise_track( + noise_track_filepath, input_signal, impulse_response_filepath) + assert noise_signal is not None + + # Create the noisy mixes (once for each unique SNR value). + noisy_mix_filepaths[impulse_response_name] = {} + for snr in snr_values: + noisy_signal_filepath = os.path.join( + input_noise_cache_path, + self._NOISY_SIGNAL_FILENAME_TEMPLATE.format( + impulse_response_name, snr)) + + # Create and save if not done. + if not os.path.exists(noisy_signal_filepath): + # Create noisy signal. + noisy_signal = signal_processing.SignalProcessingUtils.mix_signals( + input_signal, noise_signal, snr, bln_pad_shortest=True) + + # Save. + signal_processing.SignalProcessingUtils.save_wav( + noisy_signal_filepath, noisy_signal) + + # Add file to the collection of mixes. + noisy_mix_filepaths[impulse_response_name][snr] = noisy_signal_filepath + + # Add all the noise-SNR pairs. + self._add_noise_snr_pairs(base_output_path, noisy_mix_filepaths, + self._SNR_VALUE_PAIRS) + + def _generate_noise_track(self, noise_track_filepath, input_signal, + impulse_response_filepath): + """Generates noise track. + + Generate a signal by convolving input_signal with the impulse response in + impulse_response_filepath; then save to noise_track_filepath. + + Args: + noise_track_filepath: output file path for the noise track. + input_signal: (clean) input signal samples. + impulse_response_filepath: impulse response file path. + """ + # Load impulse response. + data = scipy.io.loadmat(impulse_response_filepath) + impulse_response = data['h_air'].flatten() + if self._MAX_IMPULSE_RESPONSE_LENGTH is not None: + logging.info('truncating impulse response from %d to %d samples', + len(impulse_response), self._MAX_IMPULSE_RESPONSE_LENGTH) + impulse_response = impulse_response[:self._MAX_IMPULSE_RESPONSE_LENGTH] + + # Apply impulse response. + processed_signal = ( + signal_processing.SignalProcessingUtils.apply_impulse_response( + input_signal, impulse_response)) + + # Save. + signal_processing.SignalProcessingUtils.save_wav( + noise_track_filepath, processed_signal) + + return processed_signal diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_factory.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_factory.py new file mode 100644 index 0000000000..acb9f07675 --- /dev/null +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_factory.py @@ -0,0 +1,37 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +"""NoiseGenerator factory class. +""" + +import logging + +from . import noise_generation + + +class NoiseGeneratorFactory(object): + """Factory class used to instantiate noise generator workers. + + It can be used by instanciating a factory, passing parameters to the + constructor. These parameters are used to instantiate noise generator + workers. + """ + + def __init__(self, aechen_ir_database_path): + self._aechen_ir_database_path = aechen_ir_database_path + + def GetInstance(self, noise_generator_class): + """Creates an NoiseGenerator instance given a class object. + """ + logging.debug( + 'factory producing a %s noise generator', noise_generator_class) + if noise_generator_class == noise_generation.EchoNoiseGenerator: + return noise_generation.EchoNoiseGenerator(self._aechen_ir_database_path) + else: + # By default, no arguments in the constructor. + return noise_generator_class() diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py index 55fd1fd2a8..2b750913ac 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py @@ -6,14 +6,19 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the noise_generation module. +""" + import os import shutil import tempfile import unittest from . import noise_generation +from . import noise_generation_factory from . import signal_processing + class TestNoiseGen(unittest.TestCase): def setUp(self): @@ -36,6 +41,13 @@ class TestNoiseGen(unittest.TestCase): self.assertIsInstance(registered_classes, dict) self.assertGreater(len(registered_classes), 0) + # Instance noise generator factory. + noise_generator_factory = noise_generation_factory.NoiseGeneratorFactory( + aechen_ir_database_path='') + # TODO(alessiob): Replace with a mock of NoiseGeneratorFactory that takes + # no arguments in the ctor. For those generators that need parameters, it + # will return a mock generator (see the first comment in the next for loop). + # Use a sample input file as clean input signal. input_signal_filepath = os.path.join( os.getcwd(), 'probing_signals', 'tone-880.wav') @@ -47,9 +59,17 @@ class TestNoiseGen(unittest.TestCase): # Try each registered noise generator. for noise_generator_name in registered_classes: + # Exclude EchoNoiseGenerator. + # TODO(alessiob): Mock EchoNoiseGenerator, the mock should rely on + # hard-coded impulse responses. This requires a mock for + # NoiseGeneratorFactory. The latter knows whether returning the actual + # generator or a mock object (as in the case of EchoNoiseGenerator). + if noise_generator_name == 'echo': + continue + # Instance noise generator. - noise_generator_class = registered_classes[noise_generator_name] - noise_generator = noise_generator_class() + noise_generator = noise_generator_factory.GetInstance( + registered_classes[noise_generator_name]) # Generate the noisy input - reference pairs. noise_generator.generate( @@ -78,8 +98,14 @@ class TestNoiseGen(unittest.TestCase): def _CheckNoiseGeneratorPairsSignalDurations( self, noise_generator, input_signal): - """Checks that the noisy input and the reference tracks are audio files - with duration >= to that of the input signal. + """Check duration of the signals generated by a noise generator. + + Checks that the noisy input and the reference tracks are audio files + with duration equal to or greater than that of the input signal. + + Args: + noise_generator: NoiseGenerator instance. + input_signal: AudioSegment instance. """ input_signal_length = ( signal_processing.SignalProcessingUtils.count_samples(input_signal)) @@ -111,6 +137,9 @@ class TestNoiseGen(unittest.TestCase): def _CheckNoiseGeneratorPairsOutputPaths(self, noise_generator): """Checks that the output path created by the generator exists. + + Args: + noise_generator: NoiseGenerator instance. """ # Iterate over the noisy signal - reference pairs. for noise_config_name in noise_generator.config_names: diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py index 1dbff3d581..76b103d091 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py @@ -8,11 +8,27 @@ import array import logging +import sys + +try: + import numpy as np +except ImportError: + logging.critical('Cannot import the third-party Python package numpy') + sys.exit(1) + +try: + import pydub + import pydub.generators +except ImportError: + logging.critical('Cannot import the third-party Python package pydub') + sys.exit(1) + +try: + import scipy.signal +except ImportError: + logging.critical('Cannot import the third-party Python package scipy') + sys.exit(1) -import numpy as np -import pydub -import pydub.generators -import scipy.signal class SignalProcessingException(Exception): pass @@ -25,7 +41,8 @@ class SignalProcessingUtils(object): @classmethod def load_wav(cls, filepath, channels=1): - """ + """Load wav file. + Return: AudioSegment instance. """ @@ -34,7 +51,8 @@ class SignalProcessingUtils(object): @classmethod def save_wav(cls, output_filepath, signal): - """ + """Save wav file. + Args: output_filepath: string, output file path. signal: AudioSegment instance. @@ -43,8 +61,7 @@ class SignalProcessingUtils(object): @classmethod def count_samples(cls, signal): - """ - Number of samples per channel. + """Number of samples per channel. Args: signal: AudioSegment instance. @@ -56,7 +73,8 @@ class SignalProcessingUtils(object): @classmethod def generate_white_noise(cls, signal): - """ + """Generate white noise. + Generate white noise with the same duration and in the same format as a given signal. @@ -75,6 +93,8 @@ class SignalProcessingUtils(object): @classmethod def apply_impulse_response(cls, signal, impulse_response): + """Apply an impulse response to a signal. + """ # Get samples. assert signal.channels == 1, ( 'multiple-channel recordings not supported') @@ -128,7 +148,8 @@ class SignalProcessingUtils(object): @classmethod def mix_signals(cls, signal, noise, target_snr=0.0, bln_pad_shortest=False): - """ + """Mix two signals with a target SNR. + Mix two signals up to a desired SNR by scaling noise (noise). If the target SNR is +/- infinite, a copy of signal/noise is returned. diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py index 82b138c4a4..1765a2c515 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py @@ -6,6 +6,9 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the signal_processing module. +""" + import unittest import numpy as np @@ -13,6 +16,7 @@ import pydub from . import signal_processing + class TestSignalProcessing(unittest.TestCase): def testMixSignals(self): diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py index ac20edc9d7..f7962001b3 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py @@ -6,24 +6,42 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +"""APM module simulator. +""" + import logging import os from . import audioproc_wrapper from . import data_access from . import eval_scores +from . import eval_scores_factory from . import evaluation from . import noise_generation +from . import noise_generation_factory + class ApmModuleSimulator(object): + """APM module simulator class. + """ _NOISE_GENERATOR_CLASSES = noise_generation.NoiseGenerator.REGISTERED_CLASSES _EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES - def __init__(self): + def __init__(self, aechen_ir_database_path, polqa_tool_path): + # Init. self._audioproc_wrapper = audioproc_wrapper.AudioProcWrapper() self._evaluator = evaluation.ApmModuleEvaluator() + # Instance factory objects. + self._noise_generator_factory = ( + noise_generation_factory.NoiseGeneratorFactory( + aechen_ir_database_path=aechen_ir_database_path)) + self._evaluation_score_factory = ( + eval_scores_factory.EvaluationScoreWorkerFactory( + polqa_tool_path=polqa_tool_path)) + + # Properties for each run. self._base_output_path = None self._noise_generators = None self._evaluation_score_workers = None @@ -38,12 +56,15 @@ class ApmModuleSimulator(object): self._base_output_path = os.path.abspath(output_dir) # Instance noise generators. - self._noise_generators = [ - self._NOISE_GENERATOR_CLASSES[name]() for name in noise_generator_names] + self._noise_generators = [self._noise_generator_factory.GetInstance( + noise_generator_class=self._NOISE_GENERATOR_CLASSES[name]) for name in ( + noise_generator_names)] # Instance evaluation score workers. self._evaluation_score_workers = [ - self._EVAL_SCORE_WORKER_CLASSES[name]() for name in eval_score_names] + self._evaluation_score_factory.GetInstance( + evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name]) for ( + name) in eval_score_names] # Set APM configuration file paths. self._config_filepaths = self._get_paths_collection(config_filepaths)