Total Harmonic Distortion plus noise (THD+N) score in APM-QA.
In order to compute a THD score, a pure tone must be used as the input signal, and its frequency must be known. For this reason, this CL makes a number of changes to the APM-QA pipeline. In more detail, the input signal metadata is loaded and passed to the THD evaluation score instance. This makes the eval_scores module less reusable, but that is fine since the module has been specifically designed for the APM-QA module. BUG=webrtc:7494 Review-Url: https://codereview.webrtc.org/3010413002 Cr-Commit-Position: refs/heads/master@{#19970}
This commit is contained in:
parent
a42055116d
commit
5d26edcc02
@ -33,6 +33,12 @@ reference one used for evaluation.
|
|||||||
- Go to `out/Default/py_quality_assessment` and check that
|
- Go to `out/Default/py_quality_assessment` and check that
|
||||||
`apm_quality_assessment.py` exists
|
`apm_quality_assessment.py` exists
|
||||||
|
|
||||||
|
## Unit tests
|
||||||
|
|
||||||
|
- Compile WebRTC
|
||||||
|
- Go to `out/Default/py_quality_assessment`
|
||||||
|
- Run `python -m unittest -p "*_unittest.py" discover`
|
||||||
|
|
||||||
## First time setup
|
## First time setup
|
||||||
|
|
||||||
- Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
|
- Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
|
||||||
|
|||||||
@ -31,8 +31,33 @@ class Metadata(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||||
|
|
||||||
|
@classmethod
def LoadFileMetadata(cls, filepath):
  """Reads the JSON metadata stored alongside a file.

  The metadata file path is obtained by appending the generic metadata
  suffix (cls._GENERIC_METADATA_SUFFIX) to |filepath|.

  Args:
    filepath: path of the file the metadata is linked to (the suffix is
              appended to it before opening).

  Returns:
    The decoded JSON content (a dict is expected).
  """
  metadata_filepath = filepath + cls._GENERIC_METADATA_SUFFIX
  with open(metadata_filepath) as metadata_file:
    metadata = json.load(metadata_file)
  return metadata
|
||||||
|
|
||||||
|
@classmethod
def SaveFileMetadata(cls, filepath, metadata):
  """Writes JSON metadata alongside a file.

  The metadata file path is obtained by appending the generic metadata
  suffix (cls._GENERIC_METADATA_SUFFIX) to |filepath|; an existing metadata
  file at that path is overwritten.

  Args:
    filepath: path of the file the metadata is linked to (the suffix is
              appended to it before opening).
    metadata: a dict (must be JSON serializable).
  """
  metadata_filepath = filepath + cls._GENERIC_METADATA_SUFFIX
  with open(metadata_filepath, 'w') as metadata_file:
    json.dump(metadata, metadata_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||||
"""Loads the input and the reference audio track paths.
|
"""Loads the input and the reference audio track paths.
|
||||||
|
|||||||
@ -14,6 +14,13 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
logging.critical('Cannot import the third-party Python package numpy')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
from . import data_access
|
from . import data_access
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
@ -27,6 +34,7 @@ class EvaluationScore(object):
|
|||||||
|
|
||||||
def __init__(self, score_filename_prefix):
|
def __init__(self, score_filename_prefix):
|
||||||
self._score_filename_prefix = score_filename_prefix
|
self._score_filename_prefix = score_filename_prefix
|
||||||
|
self._input_signal_metadata = None
|
||||||
self._reference_signal = None
|
self._reference_signal = None
|
||||||
self._reference_signal_filepath = None
|
self._reference_signal_filepath = None
|
||||||
self._tested_signal = None
|
self._tested_signal = None
|
||||||
@ -56,8 +64,16 @@ class EvaluationScore(object):
|
|||||||
def score(self):
|
def score(self):
|
||||||
return self._score
|
return self._score
|
||||||
|
|
||||||
|
def SetInputSignalMetadata(self, metadata):
  """Stores the metadata describing the input signal.

  The object is kept by reference (it is not copied).

  Args:
    metadata: dict instance.
  """
  self._input_signal_metadata = metadata
|
||||||
|
|
||||||
def SetReferenceSignalFilepath(self, filepath):
|
def SetReferenceSignalFilepath(self, filepath):
|
||||||
""" Sets the path to the audio track used as reference signal.
|
"""Sets the path to the audio track used as reference signal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filepath: path to the reference audio track.
|
filepath: path to the reference audio track.
|
||||||
@ -65,7 +81,7 @@ class EvaluationScore(object):
|
|||||||
self._reference_signal_filepath = filepath
|
self._reference_signal_filepath = filepath
|
||||||
|
|
||||||
def SetTestedSignalFilepath(self, filepath):
|
def SetTestedSignalFilepath(self, filepath):
|
||||||
""" Sets the path to the audio track used as test signal.
|
"""Sets the path to the audio track used as test signal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filepath: path to the test audio track.
|
filepath: path to the test audio track.
|
||||||
@ -242,3 +258,84 @@ class PolqaScore(EvaluationScore):
|
|||||||
# Build and return a dictionary with field names (header) as keys and the
|
# Build and return a dictionary with field names (header) as keys and the
|
||||||
# corresponding field values as values.
|
# corresponding field values as values.
|
||||||
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
|
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
|
||||||
|
|
||||||
|
|
||||||
|
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
  """Total harmonic distortion plus noise (THD+N) score.

  See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".

  The score is the ratio between the level of everything that is left in the
  tested signal once the fundamental has been removed (harmonics and noise)
  and the amplitude of the fundamental itself. It can only be computed when
  the input signal is a pure tone with known frequency and when the
  "identity" test data generator (with "default" config) is used; this is
  checked via the input signal metadata.

  Unit: -.
  Ideal: 0.
  Worst case: +inf
  """

  NAME = 'thd'

  def __init__(self, score_filename_prefix):
    EvaluationScore.__init__(self, score_filename_prefix)
    # Frequency of the pure tone; set by _CheckInputSignal().
    self._input_frequency = None

  def _Run(self, output_path):
    """Computes and saves the THD+N score of the tested signal.

    Args:
      output_path: output path (not used directly by this method).

    Raises:
      EvaluationScoreException: if the input signal metadata is missing or
        invalid, if the tested signal is not mono or is empty, or if the pure
        tone frequency is not in the (0, Nyquist) range.
    """
    # TODO(aleloi): Integrate changes made locally.
    self._CheckInputSignal()

    self._LoadTestedSignal()
    if self._tested_signal.channels != 1:
      raise exceptions.EvaluationScoreException(
          'unsupported number of channels')
    # Work on floats to avoid any int16 overflow in the computations below.
    samples = np.asarray(
        signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
            self._tested_signal), dtype=np.float64)

    # Init.
    num_samples = len(samples)
    if num_samples == 0:
      raise exceptions.EvaluationScoreException(
          'The THD score cannot be computed on an empty signal')
    try:
      f0_freq = float(self._input_frequency)
    except (TypeError, ValueError):
      raise exceptions.EvaluationScoreException(
          'Invalid input signal metadata to compute the THD score')
    sample_rate = float(self._tested_signal.frame_rate)
    max_freq = sample_rate / 2.0  # Nyquist frequency.
    if not 0.0 < f0_freq < max_freq:
      # A non-positive frequency or one at/above Nyquist cannot be analyzed.
      raise exceptions.EvaluationScoreException(
          'The pure tone frequency must be positive and lower than half of '
          'the sample rate')
    scaling = 2.0 / num_samples
    # Sample instants: the k-th sample is taken at k / sample_rate seconds.
    t = np.arange(num_samples) / sample_rate

    # Analyze the fundamental: project the signal on the sine and the cosine
    # at f0 to get both its amplitude and its phase.
    omega_t = 2.0 * np.pi * f0_freq * t
    sin_f0 = np.sin(omega_t)
    cos_f0 = np.cos(omega_t)
    x_1 = np.sum(samples * sin_f0) * scaling
    y_1 = np.sum(samples * cos_f0) * scaling
    fundamental_amplitude = np.sqrt(x_1**2 + y_1**2)

    # Remove the fundamental (with its actual phase) from the tested signal.
    output_without_fundamental = samples - (x_1 * sin_f0 + y_1 * cos_f0)
    # NOTE(review): the |np.pi * scaling| normalization is kept as in the
    # original implementation - confirm that it is the intended one.
    distortion_and_noise = np.sqrt(np.sum(
        output_without_fundamental**2) * np.pi * scaling)

    # TODO(alessiob): Add the harmonics-only THD score if needed.

    # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
    # docstring above if accordingly.
    if fundamental_amplitude == 0.0:
      # No fundamental at all: worst case.
      thd_plus_noise = float('inf')
    else:
      thd_plus_noise = float(distortion_and_noise / fundamental_amplitude)

    self._score = thd_plus_noise
    self._SaveScore()

  def _CheckInputSignal(self):
    """Validates the input signal metadata and reads the tone frequency.

    Sets self._input_frequency from the 'frequency' metadata field.

    Raises:
      EvaluationScoreException: if no metadata has been set, if a required
        field is missing, if the input signal is not a pure tone or if the
        test data generator is not "identity" with "default" config.
    """
    # Check input signal and get properties.
    try:
      if self._input_signal_metadata['signal'] != 'pure_tone':
        raise exceptions.EvaluationScoreException(
            'The THD score requires a pure tone as input signal')
      self._input_frequency = self._input_signal_metadata['frequency']
      if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
          self._input_signal_metadata['test_data_gen_config'] != 'default'):
        raise exceptions.EvaluationScoreException(
            'The THD score cannot be used with any test data generator other '
            'than "identity"')
    except TypeError:
      # Raised when the metadata is None (i.e., it has not been set).
      raise exceptions.EvaluationScoreException(
          'The THD score requires an input signal with associated metadata')
    except KeyError:
      raise exceptions.EvaluationScoreException(
          'Invalid input signal metadata to compute the THD score')
|
||||||
|
|||||||
@ -52,6 +52,9 @@ class TestEvalScores(unittest.TestCase):
|
|||||||
shutil.rmtree(self._output_path)
|
shutil.rmtree(self._output_path)
|
||||||
|
|
||||||
def testRegisteredClasses(self):
|
def testRegisteredClasses(self):
|
||||||
|
# Evaluation score names to exclude (tested separately).
|
||||||
|
exceptions = ['thd']
|
||||||
|
|
||||||
# Preliminary check.
|
# Preliminary check.
|
||||||
self.assertTrue(os.path.exists(self._output_path))
|
self.assertTrue(os.path.exists(self._output_path))
|
||||||
|
|
||||||
@ -69,11 +72,14 @@ class TestEvalScores(unittest.TestCase):
|
|||||||
|
|
||||||
# Try each registered evaluation score worker.
|
# Try each registered evaluation score worker.
|
||||||
for eval_score_name in registered_classes:
|
for eval_score_name in registered_classes:
|
||||||
|
if eval_score_name in exceptions:
|
||||||
|
continue
|
||||||
|
|
||||||
# Instance evaluation score worker.
|
# Instance evaluation score worker.
|
||||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||||
registered_classes[eval_score_name])
|
registered_classes[eval_score_name])
|
||||||
|
|
||||||
# Set reference and test, then run.
|
# Set fake input metadata and reference and test file paths, then run.
|
||||||
eval_score_worker.SetReferenceSignalFilepath(
|
eval_score_worker.SetReferenceSignalFilepath(
|
||||||
self._fake_reference_signal_filepath)
|
self._fake_reference_signal_filepath)
|
||||||
eval_score_worker.SetTestedSignalFilepath(
|
eval_score_worker.SetTestedSignalFilepath(
|
||||||
@ -83,3 +89,43 @@ class TestEvalScores(unittest.TestCase):
|
|||||||
# Check output.
|
# Check output.
|
||||||
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
|
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
|
||||||
self.assertTrue(isinstance(score, float))
|
self.assertTrue(isinstance(score, float))
|
||||||
|
|
||||||
|
def testTotalHarmonicDistorsionScore(self):
  """Checks that the THD score grows as the pure tone gets more distorted."""
  # Init.
  pure_tone_freq = 5000.0
  eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
  eval_score_worker.SetInputSignalMetadata({
      'signal': 'pure_tone',
      'frequency': pure_tone_freq,
      'test_data_gen_name': 'identity',
      'test_data_gen_config': 'default',
  })
  template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

  # Create 3 test signals: pure tone, pure tone + white noise, white noise
  # only.
  pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
      template, pure_tone_freq)
  white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
      template)
  noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
      pure_tone, white_noise)

  # Compute scores for increasingly distorted pure tone signals.
  tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
  scores = []
  for tested_signal in (pure_tone, noisy_tone, white_noise):
    # Save signal.
    signal_processing.SignalProcessingUtils.SaveWav(
        tmp_filepath, tested_signal)

    # Compute score.
    eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
    eval_score_worker.Run(self._output_path)
    scores.append(eval_score_worker.score)

    # Remove output file to avoid caching.
    os.remove(eval_score_worker.output_filepath)

  # Validate scores (lowest score with a pure tone). Separate assertions are
  # used so that a failure reports the offending pair of scores.
  self.assertEqual(3, len(scores))
  self.assertLess(scores[0], scores[1])
  self.assertLess(scores[1], scores[2])
|
||||||
|
|||||||
@ -20,14 +20,15 @@ class ApmModuleEvaluator(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Run(cls, evaluation_score_workers, apm_output_filepath,
|
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||||
reference_input_filepath, output_path):
|
apm_output_filepath, reference_input_filepath, output_path):
|
||||||
"""Runs the evaluation.
|
"""Runs the evaluation.
|
||||||
|
|
||||||
Iterates over the given evaluation score workers.
|
Iterates over the given evaluation score workers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
evaluation_score_workers: list of EvaluationScore instances.
|
evaluation_score_workers: list of EvaluationScore instances.
|
||||||
|
apm_input_metadata: dictionary with metadata of the APM input.
|
||||||
apm_output_filepath: path to the audio track file with the APM output.
|
apm_output_filepath: path to the audio track file with the APM output.
|
||||||
reference_input_filepath: path to the reference audio track file.
|
reference_input_filepath: path to the reference audio track file.
|
||||||
output_path: output path.
|
output_path: output path.
|
||||||
@ -40,6 +41,7 @@ class ApmModuleEvaluator(object):
|
|||||||
|
|
||||||
for evaluation_score_worker in evaluation_score_workers:
|
for evaluation_score_worker in evaluation_score_workers:
|
||||||
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
|
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
|
||||||
|
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||||
reference_input_filepath)
|
reference_input_filepath)
|
||||||
evaluation_score_worker.SetTestedSignalFilepath(
|
evaluation_score_worker.SetTestedSignalFilepath(
|
||||||
|
|||||||
@ -32,3 +32,9 @@ class InputSignalCreatorException(Exception):
|
|||||||
"""Input signal creator exeception.
|
"""Input signal creator exeception.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationScoreException(Exception):
  """Evaluation score exception.

  Raised when an evaluation score cannot be computed (e.g., the input signal
  metadata is missing or not valid for the score).
  """
  pass
|
||||||
|
|||||||
@ -18,26 +18,36 @@ class InputSignalCreator(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Create(cls, name, params):
|
def Create(cls, name, raw_params):
|
||||||
"""Creates a input signal.
|
"""Creates a input signal and its metadata.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Input signal creator name.
|
name: Input signal creator name.
|
||||||
params: Tuple of parameters to pass to the specific signal creator.
|
raw_params: Tuple of parameters to pass to the specific signal creator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AudioSegment instance.
|
(AudioSegment, dict) tuple.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
signal = {}
|
||||||
|
params = {}
|
||||||
|
|
||||||
if name == 'pure_tone':
|
if name == 'pure_tone':
|
||||||
return cls._CreatePureTone(float(params[0]), int(params[1]))
|
params['frequency'] = float(raw_params[0])
|
||||||
|
params['duration'] = int(raw_params[1])
|
||||||
|
signal = cls._CreatePureTone(params['frequency'], params['duration'])
|
||||||
|
else:
|
||||||
|
raise exceptions.InputSignalCreatorException(
|
||||||
|
'Invalid input signal creator name')
|
||||||
|
|
||||||
|
# Complete metadata.
|
||||||
|
params['signal'] = name
|
||||||
|
|
||||||
|
return signal, params
|
||||||
except (TypeError, AssertionError) as e:
|
except (TypeError, AssertionError) as e:
|
||||||
raise exceptions.InputSignalCreatorException(
|
raise exceptions.InputSignalCreatorException(
|
||||||
'Invalid signal creator parameters: {}'.format(e))
|
'Invalid signal creator parameters: {}'.format(e))
|
||||||
|
|
||||||
raise exceptions.InputSignalCreatorException(
|
|
||||||
'Invalid input signal creator name')
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _CreatePureTone(cls, frequency, duration):
|
def _CreatePureTone(cls, frequency, duration):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -148,6 +148,13 @@ class SignalProcessingUtils(object):
|
|||||||
duration=len(template),
|
duration=len(template),
|
||||||
volume=0.0)
|
volume=0.0)
|
||||||
|
|
||||||
|
@classmethod
def AudioSegmentToRawData(cls, signal):
  """Returns the samples of an audio segment as a numpy int16 array.

  Args:
    signal: instance exposing get_array_of_samples() (e.g., AudioSegment).

  Returns:
    numpy.ndarray of int16 with the samples of |signal|.

  Raises:
    SignalProcessingException: if the samples type code is not 'h' (i.e.,
      samples are not signed 16 bit integers).
  """
  # Read the samples once: the call returns a new array every time.
  samples = signal.get_array_of_samples()
  if samples.typecode != 'h':
    raise exceptions.SignalProcessingException('Unsupported samples type')
  return np.array(samples, np.int16)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def DetectHardClipping(cls, signal, threshold=2):
|
def DetectHardClipping(cls, signal, threshold=2):
|
||||||
"""Detects hard clipping.
|
"""Detects hard clipping.
|
||||||
@ -169,13 +176,7 @@ class SignalProcessingUtils(object):
|
|||||||
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
||||||
raise exceptions.SignalProcessingException(
|
raise exceptions.SignalProcessingException(
|
||||||
'hard-clipping detection only supported for 16 bit samples')
|
'hard-clipping detection only supported for 16 bit samples')
|
||||||
|
samples = cls.AudioSegmentToRawData(signal)
|
||||||
# Get raw samples, check type, cast.
|
|
||||||
samples = signal.get_array_of_samples()
|
|
||||||
if samples.typecode != 'h':
|
|
||||||
raise exceptions.SignalProcessingException(
|
|
||||||
'hard-clipping detection only supported for 16 bit samples')
|
|
||||||
samples = np.array(signal.get_array_of_samples(), np.int16)
|
|
||||||
|
|
||||||
# Detect adjacent clipped samples.
|
# Detect adjacent clipped samples.
|
||||||
samples_type_info = np.iinfo(samples.dtype)
|
samples_type_info = np.iinfo(samples.dtype)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from . import echo_path_simulation
|
|||||||
from . import echo_path_simulation_factory
|
from . import echo_path_simulation_factory
|
||||||
from . import eval_scores
|
from . import eval_scores
|
||||||
from . import eval_scores_factory
|
from . import eval_scores_factory
|
||||||
|
from . import exceptions
|
||||||
from . import input_mixer
|
from . import input_mixer
|
||||||
from . import test_data_generation
|
from . import test_data_generation
|
||||||
from . import test_data_generation_factory
|
from . import test_data_generation_factory
|
||||||
@ -248,9 +249,20 @@ class ApmModuleSimulator(object):
|
|||||||
test_data_cache_path=test_data_cache_path,
|
test_data_cache_path=test_data_cache_path,
|
||||||
base_output_path=output_path)
|
base_output_path=output_path)
|
||||||
|
|
||||||
|
# Extract metadata linked to the clean input file (if any).
|
||||||
|
apm_input_metadata = None
|
||||||
|
try:
|
||||||
|
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||||
|
clean_capture_input_filepath)
|
||||||
|
except IOError as e:
|
||||||
|
apm_input_metadata = {}
|
||||||
|
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||||
|
apm_input_metadata['test_data_gen_config'] = None
|
||||||
|
|
||||||
# For each test data pair, simulate a call and evaluate.
|
# For each test data pair, simulate a call and evaluate.
|
||||||
for config_name in test_data_generators.config_names:
|
for config_name in test_data_generators.config_names:
|
||||||
logging.info(' - test data generator config: <%s>', config_name)
|
logging.info(' - test data generator config: <%s>', config_name)
|
||||||
|
apm_input_metadata['test_data_gen_config'] = config_name
|
||||||
|
|
||||||
# Paths to the test data generator output.
|
# Paths to the test data generator output.
|
||||||
# Note that the reference signal does not depend on the render input
|
# Note that the reference signal does not depend on the render input
|
||||||
@ -278,9 +290,11 @@ class ApmModuleSimulator(object):
|
|||||||
render_input_filepath=render_input_filepath,
|
render_input_filepath=render_input_filepath,
|
||||||
output_path=evaluation_output_path)
|
output_path=evaluation_output_path)
|
||||||
|
|
||||||
|
try:
|
||||||
# Evaluate.
|
# Evaluate.
|
||||||
self._evaluator.Run(
|
self._evaluator.Run(
|
||||||
evaluation_score_workers=self._evaluation_score_workers,
|
evaluation_score_workers=self._evaluation_score_workers,
|
||||||
|
apm_input_metadata=apm_input_metadata,
|
||||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||||
reference_input_filepath=reference_signal_filepath,
|
reference_input_filepath=reference_signal_filepath,
|
||||||
output_path=evaluation_output_path)
|
output_path=evaluation_output_path)
|
||||||
@ -295,6 +309,9 @@ class ApmModuleSimulator(object):
|
|||||||
capture_filepath=apm_input_filepath,
|
capture_filepath=apm_input_filepath,
|
||||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||||
apm_reference_filepath=reference_signal_filepath)
|
apm_reference_filepath=reference_signal_filepath)
|
||||||
|
except exceptions.EvaluationScoreException as e:
|
||||||
|
logging.warning('the evaluation failed: %s', e.message)
|
||||||
|
continue
|
||||||
|
|
||||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||||
render_input_filepaths):
|
render_input_filepaths):
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
"""Unit tests for the simulation module.
|
"""Unit tests for the simulation module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
@ -33,8 +34,9 @@ class TestApmModuleSimulator(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Create temporary folder and fake audio track."""
|
"""Create temporary folders and fake audio track."""
|
||||||
self._output_path = tempfile.mkdtemp()
|
self._output_path = tempfile.mkdtemp()
|
||||||
|
self._tmp_path = tempfile.mkdtemp()
|
||||||
|
|
||||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||||
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||||
@ -46,6 +48,7 @@ class TestApmModuleSimulator(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Recursively delete temporary folders."""
|
"""Recursively delete temporary folders."""
|
||||||
shutil.rmtree(self._output_path)
|
shutil.rmtree(self._output_path)
|
||||||
|
shutil.rmtree(self._tmp_path)
|
||||||
|
|
||||||
def testSimulation(self):
|
def testSimulation(self):
|
||||||
# Instance dependencies to inject and mock.
|
# Instance dependencies to inject and mock.
|
||||||
@ -87,3 +90,39 @@ class TestApmModuleSimulator(unittest.TestCase):
|
|||||||
min_number_of_simulations)
|
min_number_of_simulations)
|
||||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||||
min_number_of_simulations)
|
min_number_of_simulations)
|
||||||
|
|
||||||
|
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
  """Checks the THD score with a generated pure tone as input signal.

  The first simulation uses the "identity" test data generator and must not
  log any warning; the second one uses "white_noise", which is not allowed
  with the THD score, and a warning is expected.
  """
  # Patch logging.warning only for the duration of this test so that the mock
  # does not leak into other tests.
  with mock.patch.object(logging, 'warning') as warning_mock:
    # Instance simulator.
    simulator = simulation.ApmModuleSimulator(
        aechen_ir_database_path='',
        polqa_tool_bin_path=os.path.join(
            os.path.dirname(__file__), 'fake_polqa'),
        ap_wrapper=audioproc_wrapper.AudioProcWrapper(),
        evaluator=evaluation.ApmModuleEvaluator())

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
    eval_scores = ['thd']

    # Should work.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    self.assertFalse(warning_mock.called)

    # Warning expected.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=['white_noise'],  # Not allowed with THD.
        eval_score_names=eval_scores,
        output_dir=self._output_path)
    warning_mock.assert_called_with('the evaluation failed: %s', (
        'The THD score cannot be used with any test data generator other than '
        '"identity"'))
|
||||||
|
|||||||
@ -147,11 +147,12 @@ class TestDataGenerator(object):
|
|||||||
raise exceptions.InputSignalCreatorException(
|
raise exceptions.InputSignalCreatorException(
|
||||||
'Cannot parse input signal file name')
|
'Cannot parse input signal file name')
|
||||||
|
|
||||||
signal = input_signal_creator.InputSignalCreator.Create(
|
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||||
filename_parts[0], filename_parts[1].split('_'))
|
filename_parts[0], filename_parts[1].split('_'))
|
||||||
|
|
||||||
signal_processing.SignalProcessingUtils.SaveWav(
|
signal_processing.SignalProcessingUtils.SaveWav(
|
||||||
input_signal_filepath, signal)
|
input_signal_filepath, signal)
|
||||||
|
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
|
||||||
|
|
||||||
def _Generate(
|
def _Generate(
|
||||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user