Remove py_quality_assessment and old TODOs in conversational_speech

Bug: webrtc:379542219
Change-Id: I7a6c087ce42f854d9b440da018248323b2435b55
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/368500
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43418}
This commit is contained in:
Alessio Bazzica 2024-11-18 14:50:32 +01:00 committed by WebRTC LUCI CQ
parent 7f775bc94c
commit 331ca30635
52 changed files with 8 additions and 6839 deletions

View File

@ -310,7 +310,6 @@ if (rtc_include_tests) {
":audioproc_unittest_proto",
"aec_dump:aec_dump_unittests",
"test/conversational_speech",
"test/py_quality_assessment",
]
}
}

View File

@ -70,6 +70,7 @@ rtc_library("unittest") {
"../../../../api:array_view",
"../../../../common_audio",
"../../../../rtc_base:logging",
"../../../../rtc_base:safe_conversions",
"../../../../test:fileutils",
"../../../../test:test_support",
"//testing/gtest",

View File

@ -53,6 +53,7 @@
#include "modules/audio_processing/test/conversational_speech/timing.h"
#include "modules/audio_processing/test/conversational_speech/wavreader_factory.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "test/gmock.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
@ -101,17 +102,15 @@ std::unique_ptr<MockWavReaderFactory> CreateMockWavReaderFactory() {
void CreateSineWavFile(absl::string_view filepath,
const MockWavReaderFactory::Params& params,
float frequency = 440.0f) {
// Create samples.
constexpr double two_pi = 2.0 * M_PI;
float frequency_hz = 440.0f) {
const double phase_step = 2 * M_PI * frequency_hz / params.sample_rate;
double phase = 0.0;
std::vector<int16_t> samples(params.num_samples);
for (std::size_t i = 0; i < params.num_samples; ++i) {
// TODO(alessiob): the produced tone is not pure, improve.
samples[i] = std::lround(
32767.0f * std::sin(two_pi * i * frequency / params.sample_rate));
for (size_t i = 0; i < params.num_samples; ++i) {
samples[i] = rtc::saturated_cast<int16_t>(32767.0f * std::sin(phase));
phase += phase_step;
}
// Write samples.
WavWriter wav_writer(filepath, params.sample_rate, params.num_channels);
wav_writer.WriteSamples(samples.data(), params.num_samples);
}

View File

@ -27,7 +27,6 @@ class MockWavReader : public WavReaderInterface {
MockWavReader(int sample_rate, size_t num_channels, size_t num_samples);
~MockWavReader();
// TODO(alessiob): use ON_CALL to return random samples if needed.
MOCK_METHOD(size_t, ReadFloatSamples, (rtc::ArrayView<float>), (override));
MOCK_METHOD(size_t, ReadInt16Samples, (rtc::ArrayView<int16_t>), (override));

View File

@ -187,13 +187,6 @@ std::unique_ptr<std::map<std::string, SpeakerOutputFilePaths>> Simulate(
const auto& audiotrack_readers = multiend_call.audiotrack_readers();
auto audiotracks = PreloadAudioTracks(audiotrack_readers);
// TODO(alessiob): When speaker_names.size() == 2, near-end and far-end
// across the 2 speakers are symmetric; hence, the code below could be
// replaced by only creating the near-end or the far-end. However, this would
// require to split the unit tests and document the behavior in README.md.
// In practice, it should not be an issue since the files are not expected to
be significant.
// Write near-end and far-end output tracks.
for (const auto& speaking_turn : multiend_call.speaking_turns()) {
const std::string& active_speaker_name = speaking_turn.speaker_name;

View File

@ -1,170 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import("../../../../webrtc.gni")
if (!build_with_chromium) {
group("py_quality_assessment") {
testonly = true
deps = [
":scripts",
":unit_tests",
]
}
copy("scripts") {
testonly = true
sources = [
"README.md",
"apm_quality_assessment.py",
"apm_quality_assessment.sh",
"apm_quality_assessment_boxplot.py",
"apm_quality_assessment_export.py",
"apm_quality_assessment_gencfgs.py",
"apm_quality_assessment_optimize.py",
]
outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ]
deps = [
":apm_configs",
":lib",
":output",
"../../../../resources/audio_processing/test/py_quality_assessment:probing_signals",
"../../../../rtc_tools:audioproc_f",
]
}
copy("apm_configs") {
testonly = true
sources = [ "apm_configs/default.json" ]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [
"$root_build_dir/py_quality_assessment/apm_configs/{{source_file_part}}",
]
} # apm_configs
copy("lib") {
testonly = true
sources = [
"quality_assessment/__init__.py",
"quality_assessment/annotations.py",
"quality_assessment/audioproc_wrapper.py",
"quality_assessment/collect_data.py",
"quality_assessment/data_access.py",
"quality_assessment/echo_path_simulation.py",
"quality_assessment/echo_path_simulation_factory.py",
"quality_assessment/eval_scores.py",
"quality_assessment/eval_scores_factory.py",
"quality_assessment/evaluation.py",
"quality_assessment/exceptions.py",
"quality_assessment/export.py",
"quality_assessment/export_unittest.py",
"quality_assessment/external_vad.py",
"quality_assessment/input_mixer.py",
"quality_assessment/input_signal_creator.py",
"quality_assessment/results.css",
"quality_assessment/results.js",
"quality_assessment/signal_processing.py",
"quality_assessment/simulation.py",
"quality_assessment/test_data_generation.py",
"quality_assessment/test_data_generation_factory.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ]
deps = [ "../../../../resources/audio_processing/test/py_quality_assessment:noise_tracks" ]
}
copy("output") {
testonly = true
sources = [ "output/README.md" ]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs =
[ "$root_build_dir/py_quality_assessment/output/{{source_file_part}}" ]
}
group("unit_tests") {
testonly = true
visibility = [ ":*" ] # Only targets in this file can depend on this.
deps = [
":apm_vad",
":fake_polqa",
":lib_unit_tests",
":scripts_unit_tests",
":vad",
]
}
rtc_executable("fake_polqa") {
testonly = true
sources = [ "quality_assessment/fake_polqa.cc" ]
visibility = [ ":*" ] # Only targets in this file can depend on this.
output_dir = "${root_out_dir}/py_quality_assessment/quality_assessment"
deps = [
"../../../../rtc_base:checks",
"//third_party/abseil-cpp/absl/strings",
]
}
rtc_executable("vad") {
testonly = true
sources = [ "quality_assessment/vad.cc" ]
deps = [
"../../../../common_audio",
"../../../../rtc_base:logging",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
]
}
rtc_executable("apm_vad") {
testonly = true
sources = [ "quality_assessment/apm_vad.cc" ]
deps = [
"../..",
"../../../../common_audio",
"../../../../rtc_base:logging",
"../../vad",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
]
}
rtc_executable("sound_level") {
testonly = true
sources = [ "quality_assessment/sound_level.cc" ]
deps = [
"../..",
"../../../../common_audio",
"../../../../rtc_base:logging",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",
]
}
copy("lib_unit_tests") {
testonly = true
sources = [
"quality_assessment/annotations_unittest.py",
"quality_assessment/echo_path_simulation_unittest.py",
"quality_assessment/eval_scores_unittest.py",
"quality_assessment/fake_external_vad.py",
"quality_assessment/input_mixer_unittest.py",
"quality_assessment/signal_processing_unittest.py",
"quality_assessment/simulation_unittest.py",
"quality_assessment/test_data_generation_unittest.py",
]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ]
}
copy("scripts_unit_tests") {
testonly = true
sources = [ "apm_quality_assessment_unittest.py" ]
visibility = [ ":*" ] # Only targets in this file can depend on this.
outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ]
}
}

View File

@ -1,5 +0,0 @@
aleloi@webrtc.org
alessiob@webrtc.org
henrik.lundin@webrtc.org
ivoc@webrtc.org
peah@webrtc.org

View File

@ -1,125 +0,0 @@
# APM Quality Assessment tool
Python wrapper of APM simulators (e.g., `audioproc_f`) with which quality
assessment can be automated. The tool allows to simulate different noise
conditions, input signals, APM configurations and it computes different scores.
Once the scores are computed, the results can be easily exported to an HTML page
which allows to listen to the APM input and output signals and also the
reference one used for evaluation.
## Dependencies
- OS: Linux
- Python 2.7
- Python libraries: enum34, numpy, scipy, pydub (0.17.0+), pandas (0.20.1+),
pyquery (1.2+), jsmin (2.2+), csscompressor (0.9.4)
- It is recommended that a dedicated Python environment is used
- install `virtualenv`
- `$ sudo apt-get install python-virtualenv`
- setup a new Python environment (e.g., `my_env`)
- `$ cd ~ && virtualenv my_env`
- activate the new Python environment
- `$ source ~/my_env/bin/activate`
- add dependencies via `pip`
- `(my_env)$ pip install enum34 numpy pydub scipy pandas pyquery jsmin \`
`csscompressor`
- PolqaOem64 (see http://www.polqa.info/)
- Tested with POLQA Library v1.180 / P863 v2.400
- Aachen Impulse Response (AIR) Database
- Download https://www2.iks.rwth-aachen.de/air/air_database_release_1_4.zip
- Input probing signals and noise tracks (you can make your own dataset - *1)
## Build
- Compile WebRTC
- Go to `out/Default/py_quality_assessment` and check that
`apm_quality_assessment.py` exists
## Unit tests
- Compile WebRTC
- Go to `out/Default/py_quality_assessment`
- Run `python -m unittest discover -p "*_unittest.py"`
## First time setup
- Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
- e.g., `$ export POLQA_PATH=/var/opt/PolqaOem64`
- Deploy the AIR Database and set the `AECHEN_IR_DATABASE_PATH` environment
variable
- e.g., `$ export AECHEN_IR_DATABASE_PATH=/var/opt/AIR_1_4`
- Deploy probing signal tracks into
- `out/Default/py_quality_assessment/probing_signals` (*1)
- Deploy noise tracks into
- `out/Default/py_quality_assessment/noise_tracks` (*1, *2)
(*1) You can use custom files as long as they are mono tracks sampled at 48kHz
encoded in the 16 bit signed format (it is recommended that the tracks are
converted and exported with Audacity).
## Usage (scores computation)
- Go to `out/Default/py_quality_assessment`
- Check the `apm_quality_assessment.sh` as an example script to parallelize the
experiments
- Adjust the script according to your preferences (e.g., output path)
- Run `apm_quality_assessment.sh`
- The script will end by opening the browser and showing ALL the computed
scores
## Usage (export reports)
Showing all the results at once can be confusing. You therefore may want to
export separate reports. In this case, you can use the
`apm_quality_assessment_export.py` script as follows:
- Set `--output_dir, -o` to the same value used in `apm_quality_assessment.sh`
- Use regular expressions to select/filter out scores by
- APM configurations: `--config_names, -c`
- capture signals: `--capture_names, -i`
- render signals: `--render_names, -r`
- echo simulator: `--echo_simulator_names, -e`
- test data generators: `--test_data_generators, -t`
- scores: `--eval_scores, -s`
- Assign a suffix to the report name using `-f <suffix>`
For instance:
```
$ ./apm_quality_assessment_export.py \
-o output/ \
-c "(^default$)|(.*AE.*)" \
-t \(white_noise\) \
-s \(polqa\) \
-f echo
```
## Usage (boxplot)
After generating stats, it can help to visualize how a score depends on a
certain APM simulator parameter. The `apm_quality_assessment_boxplot.py` script
helps with that, producing plots similar to [this
one](https://matplotlib.org/mpl_examples/pylab_examples/boxplot_demo_06.png).
Suppose some scores come from running the APM simulator `audioproc_f` with
or without the level controller: `--lc=1` or `--lc=0`. Then two boxplots
side by side can be generated with
```
$ ./apm_quality_assessment_boxplot.py \
-o /path/to/output
-v <score_name>
-n /path/to/dir/with/apm_configs
-z lc
```
## Troubleshooting
The input wav file must be:
- sampled at a sample rate that is a multiple of 100 (required by POLQA)
- in the 16 bit format (required by `audioproc_f`)
- encoded in the Microsoft WAV signed 16 bit PCM format (Audacity default
when exporting)
Depending on the license, the POLQA tool may take “breaks” as a way to limit the
throughput. When this happens, the APM Quality Assessment tool is slowed down.
For more details about this limitation, check Section 10.9.1 in the POLQA manual
v.1.18.
In case of issues with the POLQA score computation, check
`py_quality_assessment/eval_scores.py` and adapt
`PolqaScore._parse_output_file()`.
The code can be also fixed directly into the build directory (namely,
`out/Default/py_quality_assessment/eval_scores.py`).

View File

@ -1 +0,0 @@
{"-all_default": null}

View File

@ -1,217 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Perform APM module quality assessment on one or more input files using one or
more APM simulator configuration files and one or more test data generators.
Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
-c cfg1.json [cfg2.json ...]
-n white [echo ...]
-e audio_level [polqa ...]
-o /path/to/output
"""
import argparse
import logging
import os
import sys
import quality_assessment.audioproc_wrapper as audioproc_wrapper
import quality_assessment.echo_path_simulation as echo_path_simulation
import quality_assessment.eval_scores as eval_scores
import quality_assessment.evaluation as evaluation
import quality_assessment.eval_scores_factory as eval_scores_factory
import quality_assessment.external_vad as external_vad
import quality_assessment.test_data_generation as test_data_generation
import quality_assessment.test_data_generation_factory as \
test_data_generation_factory
import quality_assessment.simulation as simulation
_ECHO_PATH_SIMULATOR_NAMES = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
_TEST_DATA_GENERATOR_CLASSES = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_TEST_DATA_GENERATORS_NAMES = _TEST_DATA_GENERATOR_CLASSES.keys()
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
_EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()
_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'
_POLQA_BIN_NAME = 'PolqaOem64'
def _InstanceArgumentsParser():
"""Arguments parser factory.
"""
parser = argparse.ArgumentParser(description=(
'Perform APM module quality assessment on one or more input files using '
'one or more APM simulator configuration files and one or more '
'test data generators.'))
parser.add_argument('-c',
'--config_files',
nargs='+',
required=False,
help=('path to the configuration files defining the '
'arguments with which the APM simulator tool is '
'called'),
default=[_DEFAULT_CONFIG_FILE])
parser.add_argument(
'-i',
'--capture_input_files',
nargs='+',
required=True,
help='path to the capture input wav files (one or more)')
parser.add_argument('-r',
'--render_input_files',
nargs='+',
required=False,
help=('path to the render input wav files; either '
'omitted or one file for each file in '
'--capture_input_files (files will be paired by '
'index)'),
default=None)
parser.add_argument('-p',
'--echo_path_simulator',
required=False,
help=('custom echo path simulator name; required if '
'--render_input_files is specified'),
choices=_ECHO_PATH_SIMULATOR_NAMES,
default=echo_path_simulation.NoEchoPathSimulator.NAME)
parser.add_argument('-t',
'--test_data_generators',
nargs='+',
required=False,
help='custom list of test data generators to use',
choices=_TEST_DATA_GENERATORS_NAMES,
default=_TEST_DATA_GENERATORS_NAMES)
parser.add_argument('--additive_noise_tracks_path', required=False,
help='path to the wav files for the additive',
default=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH)
parser.add_argument('-e',
'--eval_scores',
nargs='+',
required=False,
help='custom list of evaluation scores to use',
choices=_EVAL_SCORE_WORKER_NAMES,
default=_EVAL_SCORE_WORKER_NAMES)
parser.add_argument('-o',
'--output_dir',
required=False,
help=('base path to the output directory in which the '
'output wav files and the evaluation outcomes '
'are saved'),
default='output')
parser.add_argument('--polqa_path',
required=True,
help='path to the POLQA tool')
parser.add_argument('--air_db_path',
required=True,
help='path to the Aechen IR database')
parser.add_argument('--apm_sim_path', required=False,
help='path to the APM simulator tool',
default=audioproc_wrapper. \
AudioProcWrapper. \
DEFAULT_APM_SIMULATOR_BIN_PATH)
parser.add_argument('--echo_metric_tool_bin_path',
required=False,
help=('path to the echo metric binary '
'(required for the echo eval score)'),
default=None)
parser.add_argument(
'--copy_with_identity_generator',
required=False,
help=('If true, the identity test data generator makes a '
'copy of the clean speech input file.'),
default=False)
parser.add_argument('--external_vad_paths',
nargs='+',
required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'),
default=[])
parser.add_argument('--external_vad_names',
nargs='+',
required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'),
default=[])
return parser
def _ValidateArguments(args, parser):
if args.capture_input_files and args.render_input_files and (len(
args.capture_input_files) != len(args.render_input_files)):
parser.error(
'--render_input_files and --capture_input_files must be lists '
'having the same length')
sys.exit(1)
if args.render_input_files and not args.echo_path_simulator:
parser.error(
'when --render_input_files is set, --echo_path_simulator is '
'also required')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
def main():
    # Entry point: parses/validates the CLI arguments, wires up the APM
    # quality assessment pipeline and runs all the requested simulations.
    # TODO(alessiob): level = logging.INFO once debugged.
    logging.basicConfig(level=logging.DEBUG)
    parser = _InstanceArgumentsParser()
    args = parser.parse_args()
    _ValidateArguments(args, parser)

    # Assemble the simulator from its collaborators: test data generators,
    # evaluation score workers, the APM simulator wrapper, the evaluator and
    # any external VADs (all provided by project-local quality_assessment
    # modules).
    simulator = simulation.ApmModuleSimulator(
        test_data_generator_factory=(
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path=args.air_db_path,
                noise_tracks_path=args.additive_noise_tracks_path,
                copy_with_identity=args.copy_with_identity_generator)),
        evaluation_score_factory=eval_scores_factory.
        EvaluationScoreWorkerFactory(
            # The POLQA binary is expected inside the user-provided directory.
            polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
            echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
        ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
        evaluator=evaluation.ApmModuleEvaluator(),
        external_vads=external_vad.ExternalVad.ConstructVadDict(
            args.external_vad_paths, args.external_vad_names))
    # Run every combination of config file, capture input, test data
    # generator and evaluation score; results are written under output_dir.
    simulator.Run(config_filepaths=args.config_files,
                  capture_input_filepaths=args.capture_input_files,
                  render_input_filepaths=args.render_input_files,
                  echo_path_simulator_name=args.echo_path_simulator,
                  test_data_generator_names=args.test_data_generators,
                  eval_score_names=args.eval_scores,
                  output_dir=args.output_dir)
    sys.exit(0)
if __name__ == '__main__':
main()

View File

@ -1,91 +0,0 @@
#!/bin/bash
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
# Path to the POLQA tool.
if [ -z ${POLQA_PATH} ]; then # Check if defined.
# Default location.
export POLQA_PATH='/var/opt/PolqaOem64'
fi
if [ -d "${POLQA_PATH}" ]; then
echo "POLQA found in ${POLQA_PATH}"
else
echo "POLQA not found in ${POLQA_PATH}"
exit 1
fi
# Path to the Aechen IR database.
if [ -z ${AECHEN_IR_DATABASE_PATH} ]; then # Check if defined.
# Default location.
export AECHEN_IR_DATABASE_PATH='/var/opt/AIR_1_4'
fi
if [ -d "${AECHEN_IR_DATABASE_PATH}" ]; then
echo "AIR database found in ${AECHEN_IR_DATABASE_PATH}"
else
echo "AIR database not found in ${AECHEN_IR_DATABASE_PATH}"
exit 1
fi
# Customize probing signals, test data generators and scores if needed.
CAPTURE_SIGNALS=(probing_signals/*.wav)
TEST_DATA_GENERATORS=( \
"identity" \
"white_noise" \
# "environmental_noise" \
# "reverberation" \
)
SCORES=( \
# "polqa" \
"audio_level_peak" \
"audio_level_mean" \
)
OUTPUT_PATH=output
# Generate standard APM config files.
chmod +x apm_quality_assessment_gencfgs.py
./apm_quality_assessment_gencfgs.py
# Customize APM configurations if needed.
APM_CONFIGS=(apm_configs/*.json)
# Add output path if missing.
if [ ! -d ${OUTPUT_PATH} ]; then
mkdir ${OUTPUT_PATH}
fi
# Start one process for each "probing signal"-"test data source" pair.
chmod +x apm_quality_assessment.py
for capture_signal_filepath in "${CAPTURE_SIGNALS[@]}" ; do
probing_signal_name="$(basename $capture_signal_filepath)"
probing_signal_name="${probing_signal_name%.*}"
for test_data_gen_name in "${TEST_DATA_GENERATORS[@]}" ; do
LOG_FILE="${OUTPUT_PATH}/apm_qa-${probing_signal_name}-"`
`"${test_data_gen_name}.log"
echo "Starting ${probing_signal_name} ${test_data_gen_name} "`
`"(see ${LOG_FILE})"
./apm_quality_assessment.py \
--polqa_path ${POLQA_PATH}\
--air_db_path ${AECHEN_IR_DATABASE_PATH}\
-i ${capture_signal_filepath} \
-o ${OUTPUT_PATH} \
-t ${test_data_gen_name} \
-c "${APM_CONFIGS[@]}" \
-e "${SCORES[@]}" > $LOG_FILE 2>&1 &
done
done
# Join Python processes running apm_quality_assessment.py.
wait
# Export results.
chmod +x ./apm_quality_assessment_export.py
./apm_quality_assessment_export.py -o ${OUTPUT_PATH}
# Show results in the browser.
RESULTS_FILE="$(realpath ${OUTPUT_PATH}/results.html)"
sensible-browser "file://${RESULTS_FILE}" > /dev/null 2>&1 &

View File

@ -1,154 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Shows boxplots of given score for different values of selected
parameters. Can be used to compare scores by audioproc_f flag.
Usage: apm_quality_assessment_boxplot.py -o /path/to/output
-v polqa
-n /path/to/dir/with/apm_configs
-z audioproc_f_arg1 [arg2 ...]
Arguments --config_names, --render_names, --echo_simulator_names,
--test_data_generators, --eval_scores can be used to filter the data
used for plotting.
"""
import collections
import logging
import matplotlib.pyplot as plt
import os
import quality_assessment.data_access as data_access
import quality_assessment.collect_data as collect_data
def InstanceArgumentsParser():
    """Arguments parser factory.

    Extends the shared data-collection parser with the boxplot-specific
    options (-v score name, -n config dir, -z parameters to group by).

    Returns:
      An argparse.ArgumentParser instance.
    """
    parser = collect_data.InstanceArgumentsParser()
    # NOTE(review): the two concatenated literals below join as
    # "...selectedparameters..." — a space seems missing; confirm intent
    # before changing the user-visible text.
    parser.description = (
        'Shows boxplot of given score for different values of selected'
        'parameters. Can be used to compare scores by audioproc_f flag')
    parser.add_argument('-v',
                        '--eval_score',
                        required=True,
                        help=('Score name for constructing boxplots'))
    parser.add_argument(
        '-n',
        '--config_dir',
        required=False,
        help=('path to the folder with the configuration files'),
        default='apm_configs')
    parser.add_argument('-z',
                        '--params_to_plot',
                        required=True,
                        nargs='+',
                        help=('audioproc_f parameter values'
                              'by which to group scores (no leading dash)'))
    return parser
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
    """Filters data on the values of one or more parameters.

    Args:
      data_frame: pandas.DataFrame of all used input data.
      filter_params: each config of the input data is assumed to have
        exactly one parameter from `filter_params` defined. Every value
        of the parameters in `filter_params` is a key in the returned
        dict; the associated value is all cells of the data with that
        value of the parameter.
      score_name: Name of score which value is boxplotted. Currently cannot do
        more than one value.
      config_dir: path to dir with APM configs.

    Returns: dictionary, key is a param value, result is all scores for
      that param value (see `filter_params` for explanation).
    """
    results = collections.defaultdict(dict)
    # One row group per distinct APM config name present in the data.
    config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
    for config_name in config_names:
        # Load the JSON config to discover which filter parameter it sets.
        config_json = data_access.AudioProcConfigFile.Load(
            os.path.join(config_dir, config_name + '.json'))
        # Restrict to rows for this config and the requested score.
        data_with_config = data_frame[data_frame.apm_config == config_name]
        data_cell_scores = data_with_config[data_with_config.eval_score_name ==
                                            score_name]
        # Exactly one of `params_to_plot` must match: the 1-tuple unpacking
        # deliberately raises if zero or more than one parameter matches.
        (matching_param, ) = [
            x for x in filter_params if '-' + x in config_json
        ]
        # Add scores for every track to the result.
        for capture_name in data_cell_scores.capture:
            result_score = float(data_cell_scores[data_cell_scores.capture ==
                                                  capture_name].score)
            # Group by the config's value of the matching parameter.
            config_dict = results[config_json['-' + matching_param]]
            if capture_name not in config_dict:
                config_dict[capture_name] = {}
            config_dict[capture_name][matching_param] = result_score
    return results
def _FlattenToScoresList(config_param_score_dict):
"""Extracts a list of scores from input data structure.
Args:
config_param_score_dict: of the form {'capture_name':
{'param_name' : score_value,.. } ..}
Returns: Plain list of all score value present in input data
structure
"""
result = []
for capture_name in config_param_score_dict:
result += list(config_param_score_dict[capture_name].values())
return result
def main():
    # Entry point: collects pre-computed scores, groups them by the values of
    # the selected audioproc_f parameters and shows one boxplot per value.
    # Init.
    # TODO(alessiob): INFO once debugged.
    logging.basicConfig(level=logging.DEBUG)
    parser = InstanceArgumentsParser()
    args = parser.parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug(src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)

    # Filter the data by `args.params_to_plot`
    scores_filtered = FilterScoresByParams(scores_data_frame,
                                           args.params_to_plot,
                                           args.eval_score, args.config_dir)

    # Sort by parameter value so boxplots appear in a deterministic order.
    data_list = sorted(scores_filtered.items())
    data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
    data_labels = [x for (x, _) in data_list]

    # One figure with a single axes holding all the boxplots side by side.
    _, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    axes.boxplot(data_values, labels=data_labels)
    axes.set_ylabel(args.eval_score)
    axes.set_xlabel('/'.join(args.params_to_plot))
    plt.show()
if __name__ == "__main__":
main()

View File

@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Export the scores computed by the apm_quality_assessment.py script into an
HTML file.
"""
import logging
import os
import sys
import quality_assessment.collect_data as collect_data
import quality_assessment.export as export
def _BuildOutputFilename(filename_suffix):
"""Builds the filename for the exported file.
Args:
filename_suffix: suffix for the output file name.
Returns:
A string.
"""
if filename_suffix is None:
return 'results.html'
return 'results-{}.html'.format(filename_suffix)
def main():
    # Entry point: collects pre-computed quality assessment scores and
    # exports them into an HTML report under the output directory.
    # Init.
    logging.basicConfig(
        level=logging.DEBUG)  # TODO(alessio): INFO once debugged.
    parser = collect_data.InstanceArgumentsParser()
    parser.add_argument('-f',
                        '--filename_suffix',
                        help=('suffix of the exported file'))
    parser.description = ('Exports pre-computed APM module quality assessment '
                          'results into HTML tables')
    args = parser.parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug(src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)

    # Export.
    output_filepath = os.path.join(args.output_dir,
                                   _BuildOutputFilename(args.filename_suffix))
    exporter = export.HtmlExport(output_filepath)
    exporter.Export(scores_data_frame)

    logging.info('output file successfully written in %s', output_filepath)
    sys.exit(0)
if __name__ == '__main__':
main()

View File

@ -1,128 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generate .json files with which the APM module can be tested using the
apm_quality_assessment.py script and audioproc_f as APM simulator.
"""
import logging
import os
import quality_assessment.data_access as data_access
OUTPUT_PATH = os.path.abspath('apm_configs')
def _GenerateDefaultOverridden(config_override):
    """Generates one or more APM overridden configurations.

    For each item in config_override, it overrides the default configuration
    and writes a new APM configuration file.

    The default settings are loaded via "-all_default".
    Check "src/modules/audio_processing/test/audioproc_float.cc" and search
    for "if (FLAG_all_default) {".

    For instance, in 55eb6d621489730084927868fed195d3645a9ec9 the default is
    this:
      settings.use_aec = rtc::Optional<bool>(true);
      settings.use_aecm = rtc::Optional<bool>(false);
      settings.use_agc = rtc::Optional<bool>(true);
      settings.use_bf = rtc::Optional<bool>(false);
      settings.use_ed = rtc::Optional<bool>(false);
      settings.use_hpf = rtc::Optional<bool>(true);
      settings.use_le = rtc::Optional<bool>(true);
      settings.use_ns = rtc::Optional<bool>(true);
      settings.use_ts = rtc::Optional<bool>(true);
      settings.use_vad = rtc::Optional<bool>(true);

    Args:
      config_override: dict of APM configuration file names as keys; the
        values are dict instances encoding the audioproc_f flags.
    """
    for config_filename in config_override:
        config = config_override[config_filename]
        # "-all_default" loads the defaults; the other flags override them.
        config['-all_default'] = None
        config_filepath = os.path.join(
            OUTPUT_PATH, 'default-{}.json'.format(config_filename))
        logging.debug('config file <%s> | %s', config_filepath, config)
        # Persist via the project JSON helper.
        data_access.AudioProcConfigFile.Save(config_filepath, config)
        logging.info('config file created: <%s>', config_filepath)
def _GenerateAllDefaultButOne():
    """Disables the flags enabled by default one-by-one.

    Produces one configuration per default-enabled audioproc_f flag, with
    that single flag turned off.
    """
    # (config name, audioproc_f flag to disable) pairs, one per config.
    disabled_flag_pairs = (
        ('no_AEC', '-aec'),
        ('no_AGC', '-agc'),
        ('no_HP_filter', '-hpf'),
        ('no_level_estimator', '-le'),
        ('no_noise_suppressor', '-ns'),
        ('no_transient_suppressor', '-ts'),
        ('no_vad', '-vad'),
    )
    config_sets = {name: {flag: 0} for name, flag in disabled_flag_pairs}
    _GenerateDefaultOverridden(config_sets)
def _GenerateAllDefaultPlusOne():
    """Enables the flags disabled by default one-by-one.

    Produces one configuration per default-disabled audioproc_f flag, with
    that flag turned on.
    """
    config_sets = {}
    # AEC and AECM are exclusive, hence AEC is explicitly disabled here.
    config_sets['with_AECM'] = {'-aec': 0, '-aecm': 1}
    # Remaining configurations each enable exactly one flag.
    enabled_flag_pairs = (
        ('with_AGC_limiter', '-agc_limiter'),
        ('with_AEC_delay_agnostic', '-delay_agnostic'),
        ('with_drift_compensation', '-drift_compensation'),
        ('with_residual_echo_detector', '-ed'),
        ('with_AEC_extended_filter', '-extended_filter'),
        ('with_LC', '-lc'),
        ('with_refined_adaptive_filter', '-refined_adaptive_filter'),
    )
    for name, flag in enabled_flag_pairs:
        config_sets[name] = {flag: 1}
    _GenerateDefaultOverridden(config_sets)
def main():
    """Entry point: writes all default-plus-one and default-minus-one APM
    configuration files."""
    logging.basicConfig(level=logging.INFO)
    for generate_configs in (_GenerateAllDefaultPlusOne,
                             _GenerateAllDefaultButOne):
        generate_configs()
# Allow running this module as a script.
if __name__ == '__main__':
    main()

View File

@ -1,189 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Finds the APM configuration that maximizes a provided metric by
parsing the output generated apm_quality_assessment.py.
"""
from __future__ import division
import collections
import logging
import os
import quality_assessment.data_access as data_access
import quality_assessment.collect_data as collect_data
def _InstanceArgumentsParser():
    """Arguments parser factory. Extends the arguments from 'collect_data'
    with a few extra for selecting what parameters to optimize for.

    Returns:
      An argparse.ArgumentParser instance.
    """
    parser = collect_data.InstanceArgumentsParser()
    # Bug fix: the implicitly concatenated string literals below were missing
    # separating spaces, producing "parametercombinations" and "inconfig_dir".
    parser.description = ('Rudimentary optimization of a function over '
                          'different parameter combinations.')

    parser.add_argument('-n',
                        '--config_dir',
                        required=False,
                        help=('path to the folder with the configuration '
                              'files'),
                        default='apm_configs')

    parser.add_argument('-p',
                        '--params',
                        required=True,
                        nargs='+',
                        help=('parameters to parse from the config files '
                              'in config_dir'))

    parser.add_argument('-z',
                        '--params_not_to_optimize',
                        required=False,
                        nargs='+',
                        default=[],
                        help=('parameters from `params` not to be optimized '
                              'for'))

    return parser
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
                            config_dir):
    """Returns a list of all configurations and scores.

    Args:
      data_frame: A pandas data frame with the scores and config name
                  returned by _FindScores.
      params: The parameter names to parse from configs the config
              directory
      params_not_to_optimize: The parameter names which shouldn't affect
                              the optimal parameter
                              selection. E.g., fixed settings and not
                              tunable parameters.
      config_dir: Path to folder with config files.

    Returns:
      Dictionary of the form
      {param_combination: [{params: {param1: value1, ...},
                            scores: {score1: value1, ...}}]}.
      The key `param_combination` runs over all parameter combinations
      of the parameters in `params` and not in
      `params_not_to_optimize`. A corresponding value is a list of all
      param combinations for params in `params_not_to_optimize` and
      their scores.
    """
    results = collections.defaultdict(list)
    config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
    score_names = data_frame['eval_score_name'].drop_duplicates(
    ).values.tolist()

    # Normalize the scores: for each score type, the maximum observed value
    # is used as normalization constant so different score types become
    # comparable.
    normalization_constants = {}
    for score_name in score_names:
        scores = data_frame[data_frame.eval_score_name == score_name].score
        normalization_constants[score_name] = max(scores)

    params_to_optimize = [p for p in params if p not in params_not_to_optimize]
    # Named-tuple type used as dictionary key; one field per optimized param.
    param_combination = collections.namedtuple("ParamCombination",
                                               params_to_optimize)

    for config_name in config_names:
        config_json = data_access.AudioProcConfigFile.Load(
            os.path.join(config_dir, config_name + ".json"))
        scores = {}
        data_cell = data_frame[data_frame.apm_config == config_name]
        for score_name in score_names:
            data_cell_scores = data_cell[data_cell.eval_score_name ==
                                         score_name].score
            # Mean score for this config, then normalized by the global max.
            scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
            scores[score_name] /= normalization_constants[score_name]

        result = {'scores': scores, 'params': {}}
        config_optimize_params = {}
        # Split the config values: optimized params become the dict key,
        # the others are recorded alongside the scores.
        for param in params:
            if param in params_to_optimize:
                config_optimize_params[param] = config_json['-' + param]
            else:
                result['params'][param] = config_json['-' + param]

        current_param_combination = param_combination(**config_optimize_params)
        results[current_param_combination].append(result)
    return results
def _FindOptimalParameter(configs_and_scores, score_weighting):
"""Finds the config producing the maximal score.
Args:
configs_and_scores: structure of the form returned by
_ConfigurationAndScores
score_weighting: a function to weight together all score values of
the form [{params: {param1: value1, ...}, scores:
{score1: value1, ...}}] into a numeric
value
Returns:
the config that has the largest values of `score_weighting` applied
to its scores.
"""
min_score = float('+inf')
best_params = None
for config in configs_and_scores:
scores_and_params = configs_and_scores[config]
current_score = score_weighting(scores_and_params)
if current_score < min_score:
min_score = current_score
best_params = config
logging.debug("Score: %f", current_score)
logging.debug("Config: %s", str(config))
return best_params
def _ExampleWeighting(scores_and_configs):
"""Example argument to `_FindOptimalParameter`
Args:
scores_and_configs: a list of configs and scores, in the form
described in _FindOptimalParameter
Returns:
numeric value, the sum of all scores
"""
res = 0
for score_config in scores_and_configs:
res += sum(score_config['scores'].values())
return res
def main():
    """Entry point: finds and logs the optimal APM parameter combination."""
    # Init. Fix: drop the stale "TODO(alessiob): INFO once debugged" and use
    # INFO level as intended; per-score details remain available via DEBUG.
    logging.basicConfig(level=logging.INFO)
    parser = _InstanceArgumentsParser()
    args = parser.parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug('Src path <%s>', src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)
    all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
                                         args.params_not_to_optimize,
                                         args.config_dir)

    opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)

    logging.info('Optimal parameter combination: <%s>', opt_param)
    # Fix: "It's" -> "Its" (possessive) in the log message.
    logging.info('Its score values: <%s>', all_scores[opt_param])
# Allow running this module as a script.
if __name__ == "__main__":
    main()

View File

@ -1,28 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the apm_quality_assessment module.
"""
import sys
import unittest
import mock
import apm_quality_assessment
class TestSimulationScript(unittest.TestCase):
"""Unit tests for the apm_quality_assessment module.
"""
def testMain(self):
# Exit with error code if no arguments are passed.
with self.assertRaises(SystemExit) as cm, mock.patch.object(
sys, 'argv', ['apm_quality_assessment.py']):
apm_quality_assessment.main()
self.assertGreater(cm.exception.code, 0)

View File

@ -1 +0,0 @@
You can use this folder for the output generated by the apm_quality_assessment scripts.

View File

@ -1,7 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

View File

@ -1,296 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Extraction of annotations from audio files.
"""
from __future__ import division
import logging
import os
import shutil
import struct
import subprocess
import sys
import tempfile
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import external_vad
from . import exceptions
from . import signal_processing
class AudioAnnotationsExtractor(object):
    """Extracts annotations from audio files.

    Annotations are the frame-wise signal level and the output of one or more
    voice activity detectors (VADs). Call `Extract` on a mono wav file, then
    `Save` to persist all annotations to a compressed .npz file.
    """

    class VadType(object):
        """Bit-field wrapper selecting which VADs to run (int in [0, 7])."""

        ENERGY_THRESHOLD = 1  # TODO(alessiob): Consider switching to P56 standard.
        WEBRTC_COMMON_AUDIO = 2  # common_audio/vad/include/vad.h
        WEBRTC_APM = 4  # modules/audio_processing/vad/vad.h

        def __init__(self, value):
            # NOTE(review): on invalid input, 'Invalid vad type: ' + value
            # concatenates str and int and would raise TypeError instead of
            # the intended exception; should be str(value).
            if (not isinstance(value, int)) or not 0 <= value <= 7:
                raise exceptions.InitializationException('Invalid vad type: ' +
                                                         value)
            self._value = value

        def Contains(self, vad_type):
            # True iff every bit set in `vad_type` is also set in this value.
            return self._value | vad_type == self._value

        def __str__(self):
            vads = []
            if self.Contains(self.ENERGY_THRESHOLD):
                vads.append("energy")
            if self.Contains(self.WEBRTC_COMMON_AUDIO):
                vads.append("common_audio")
            if self.Contains(self.WEBRTC_APM):
                vads.append("apm")
            return "VadType({})".format(", ".join(vads))

    # File name template used by `Save`; the placeholder carries the
    # annotation name prefix.
    _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'

    # Level estimation params.
    _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
    _LEVEL_FRAME_SIZE_MS = 1.0
    # The time constants in ms indicate the time it takes for the level estimate
    # to go down/up by 1 db if the signal is zero.
    _LEVEL_ATTACK_MS = 5.0
    _LEVEL_DECAY_MS = 20.0

    # VAD params.
    _VAD_THRESHOLD = 1
    # Paths to the prebuilt WebRTC VAD binaries, expected two directory
    # levels above this file.
    _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                    os.pardir, os.pardir)
    _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
    _VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')

    def __init__(self, vad_type, external_vads=None):
        """Ctor.

        Args:
          vad_type: int in [0, 7]; bit-wise combination of the VadType flags.
          external_vads: optional dict mapping a unique name to an
                         external_vad.ExternalVad instance.

        Raises:
          exceptions.InitializationException: invalid `vad_type` or external
                                              VAD instance.
        """
        self._signal = None
        self._level = None
        self._level_frame_size = None
        self._common_audio_vad = None
        self._energy_vad = None
        self._apm_vad_probs = None
        self._apm_vad_rms = None
        self._vad_frame_size = None
        self._vad_frame_size_ms = None
        self._c_attack = None
        self._c_decay = None

        self._vad_type = self.VadType(vad_type)
        logging.info('VADs used for annotations: ' + str(self._vad_type))

        if external_vads is None:
            external_vads = {}
        self._external_vads = external_vads
        # NOTE(review): this assert compares the dict with itself (the
        # assignment above aliases the same object), so it can never fail;
        # dict keys are unique by construction anyway.
        assert len(self._external_vads) == len(external_vads), (
            'The external VAD names must be unique.')
        for vad in external_vads.values():
            if not isinstance(vad, external_vad.ExternalVad):
                raise exceptions.InitializationException('Invalid vad type: ' +
                                                         str(type(vad)))
            logging.info('External VAD used for annotation: ' + str(vad.name))

        # The prebuilt VAD binaries must exist next to the package.
        assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
            self._VAD_WEBRTC_COMMON_AUDIO_PATH
        assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
            self._VAD_WEBRTC_APM_PATH

    @classmethod
    def GetOutputFileNameTemplate(cls):
        # Template of the .npz file name written by `Save`.
        return cls._OUTPUT_FILENAME_TEMPLATE

    def GetLevel(self):
        # Frame-wise signal level computed by the last `Extract` call.
        return self._level

    def GetLevelFrameSize(self):
        # Level frame size in samples.
        return self._level_frame_size

    @classmethod
    def GetLevelFrameSizeMs(cls):
        # Level frame size in milliseconds.
        return cls._LEVEL_FRAME_SIZE_MS

    def GetVadOutput(self, vad_type):
        """Returns the output of one VAD (a (probs, rms) pair for WEBRTC_APM)."""
        if vad_type == self.VadType.ENERGY_THRESHOLD:
            return self._energy_vad
        elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
            return self._common_audio_vad
        elif vad_type == self.VadType.WEBRTC_APM:
            return (self._apm_vad_probs, self._apm_vad_rms)
        else:
            raise exceptions.InitializationException('Invalid vad type: ' +
                                                     vad_type)

    def GetVadFrameSize(self):
        # VAD frame size in samples.
        return self._vad_frame_size

    def GetVadFrameSizeMs(self):
        # VAD frame size in milliseconds.
        return self._vad_frame_size_ms

    def Extract(self, filepath):
        """Computes level and VAD annotations for a mono wav file.

        Args:
          filepath: path to the input wav file.

        Raises:
          NotImplementedError: if the input has more than one channel.
        """
        # Load signal.
        self._signal = signal_processing.SignalProcessingUtils.LoadWav(
            filepath)
        if self._signal.channels != 1:
            raise NotImplementedError(
                'Multiple-channel annotations not implemented')

        # Level estimation params.
        self._level_frame_size = int(self._signal.frame_rate / 1000 *
                                     (self._LEVEL_FRAME_SIZE_MS))
        # Attack/decay smoothing coefficients derived from the time constants
        # above (0.0 disables smoothing for a zero time constant).
        self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
                                     self._LEVEL_ATTACK_MS))
        self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
                                     self._LEVEL_DECAY_MS))

        # Compute level.
        self._LevelEstimation()

        # Ideal VAD output, it requires clean speech with high SNR as input.
        if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
            # Naive VAD based on level thresholding.
            vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
            self._energy_vad = np.uint8(self._level > vad_threshold)
            self._vad_frame_size = self._level_frame_size
            self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
        if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
            # WebRTC common_audio/ VAD.
            self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
        if self._vad_type.Contains(self.VadType.WEBRTC_APM):
            # WebRTC modules/audio_processing/ VAD.
            self._RunWebRtcApmVad(filepath)
        for extvad_name in self._external_vads:
            self._external_vads[extvad_name].Run(filepath)

    def Save(self, output_path, annotation_name=""):
        """Saves all annotations to `<annotation_name>annotations.npz`."""
        # One 'extvad_conf-<name>' array per external VAD.
        ext_kwargs = {
            'extvad_conf-' + ext_vad:
            self._external_vads[ext_vad].GetVadOutput()
            for ext_vad in self._external_vads
        }
        np.savez_compressed(file=os.path.join(
            output_path,
            self.GetOutputFileNameTemplate().format(annotation_name)),
                            level=self._level,
                            level_frame_size=self._level_frame_size,
                            level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
                            vad_output=self._common_audio_vad,
                            vad_energy_output=self._energy_vad,
                            vad_frame_size=self._vad_frame_size,
                            vad_frame_size_ms=self._vad_frame_size_ms,
                            vad_probs=self._apm_vad_probs,
                            vad_rms=self._apm_vad_rms,
                            **ext_kwargs)

    def _LevelEstimation(self):
        """Computes a smoothed per-frame envelope of the loaded signal."""
        # Read samples, normalized to [-1, 1].
        samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
            self._signal).astype(np.float32) / 32768.0
        num_frames = len(samples) // self._level_frame_size
        num_samples = num_frames * self._level_frame_size

        # Envelope: per-frame peak of the absolute sample values.
        self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
                                        (num_frames, self._level_frame_size)),
                             axis=1)
        assert len(self._level) == num_frames

        # Envelope smoothing with distinct attack and decay coefficients.
        smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
        self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
        for i in range(1, num_frames):
            self._level[i] = smooth(
                self._level[i], self._level[i - 1], self._c_attack if
                (self._level[i] > self._level[i - 1]) else self._c_decay)

    def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
        """Runs the prebuilt common_audio VAD binary and parses its output."""
        self._common_audio_vad = None
        self._vad_frame_size = None

        # Create temporary output path.
        tmp_path = tempfile.mkdtemp()
        output_file_path = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad.tmp')

        # Call WebRTC VAD.
        try:
            subprocess.call([
                self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
                output_file_path
            ],
                            cwd=self._VAD_WEBRTC_PATH)

            # Read bytes.
            with open(output_file_path, 'rb') as f:
                raw_data = f.read()

            # Parse side information.
            # NOTE(review): struct.unpack('B', raw_data[0]) relies on
            # Python 2 semantics (indexing bytes yields a 1-char str); under
            # Python 3 raw_data[0] is already an int and unpack would raise.
            self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
            self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
            assert self._vad_frame_size_ms in [10, 20, 30]
            extra_bits = struct.unpack('B', raw_data[-1])[0]
            assert 0 <= extra_bits <= 8

            # Init VAD vector.
            num_bytes = len(raw_data)
            num_frames = 8 * (num_bytes -
                              2) - extra_bits  # 8 frames for each byte.
            self._common_audio_vad = np.zeros(num_frames, np.uint8)

            # Read VAD decisions, least-significant bit first in each byte.
            for i, byte in enumerate(raw_data[1:-1]):
                byte = struct.unpack('B', byte)[0]
                for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
                    self._common_audio_vad[i * 8 + j] = int(byte & 1)
                    byte = byte >> 1
        except Exception as e:
            # NOTE(review): `e.message` exists only in Python 2; Python 3
            # would raise AttributeError here.
            logging.error('Error while running the WebRTC VAD (' + e.message +
                          ')')
        finally:
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)

    def _RunWebRtcApmVad(self, wav_file_path):
        """Runs the prebuilt APM VAD binary and reads its prob/RMS outputs."""
        # Create temporary output path.
        tmp_path = tempfile.mkdtemp()
        output_file_path_probs = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
        output_file_path_rms = os.path.join(
            tmp_path,
            os.path.split(wav_file_path)[1] + '_vad_rms.tmp')

        # Call WebRTC VAD.
        try:
            subprocess.call([
                self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
                output_file_path_probs, '-o_rms', output_file_path_rms
            ],
                            cwd=self._VAD_WEBRTC_PATH)

            # Parse annotations: both files contain raw doubles.
            self._apm_vad_probs = np.fromfile(output_file_path_probs,
                                              np.double)
            self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
            assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
        except Exception as e:
            # NOTE(review): `e.message` exists only in Python 2.
            logging.error('Error while running the WebRTC APM VAD (' +
                          e.message + ')')
        finally:
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)

View File

@ -1,160 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the annotations module.
"""
from __future__ import division
import logging
import os
import shutil
import tempfile
import unittest
import numpy as np
from . import annotations
from . import external_vad
from . import input_signal_creator
from . import signal_processing
class TestAnnotationsExtraction(unittest.TestCase):
    """Unit tests for the annotations module.
    """

    # Set to False to keep (and inspect) the temporary output of a test run.
    _CLEAN_TMP_OUTPUT = True
    # Set to True to plot the level and VAD curves while debugging.
    _DEBUG_PLOT_VAD = False
    _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
    # Bit-mask combining every available VAD type.
    _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
                      | _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
                      | _VAD_TYPE_CLASS.WEBRTC_APM)

    def setUp(self):
        """Create temporary folder."""
        # A pure tone is a deterministic input whose VAD decisions are
        # predictable. The params [440, 1000] presumably mean 440 Hz for
        # 1000 ms — confirm against InputSignalCreator.
        self._tmp_path = tempfile.mkdtemp()
        self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
        pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
            'pure_tone', [440, 1000])
        signal_processing.SignalProcessingUtils.SaveWav(
            self._wav_file_path, pure_tone)
        self._sample_rate = pure_tone.frame_rate

    def tearDown(self):
        """Recursively delete temporary folder."""
        if self._CLEAN_TMP_OUTPUT:
            shutil.rmtree(self._tmp_path)
        else:
            logging.warning(self.id() + ' did not clean the temporary path ' +
                            (self._tmp_path))

    def testFrameSizes(self):
        """Checks that frame sizes in samples and ms are consistent."""
        e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
        e.Extract(self._wav_file_path)
        samples_to_ms = lambda n, sr: 1000 * n // sr
        self.assertEqual(
            samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
            e.GetLevelFrameSizeMs())
        self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
                         e.GetVadFrameSizeMs())

    def testVoiceActivityDetectors(self):
        """Runs every VAD combination and sanity-checks the outputs."""
        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            vad_type = self._VAD_TYPE_CLASS(vad_type_value)
            e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
            e.Extract(self._wav_file_path)
            if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
                # pylint: disable=unpacking-non-sequence
                vad_output = e.GetVadOutput(
                    self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
                self.assertGreater(len(vad_output), 0)
                # A pure tone should be flagged active almost everywhere.
                self.assertGreaterEqual(
                    float(np.sum(vad_output)) / len(vad_output), 0.95)

            if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
                # pylint: disable=unpacking-non-sequence
                vad_output = e.GetVadOutput(
                    self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
                self.assertGreater(len(vad_output), 0)
                self.assertGreaterEqual(
                    float(np.sum(vad_output)) / len(vad_output), 0.95)

            if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
                # pylint: disable=unpacking-non-sequence
                (vad_probs,
                 vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
                self.assertGreater(len(vad_probs), 0)
                self.assertGreater(len(vad_rms), 0)
                self.assertGreaterEqual(
                    float(np.sum(vad_probs)) / len(vad_probs), 0.5)
                self.assertGreaterEqual(
                    float(np.sum(vad_rms)) / len(vad_rms), 20000)

            if self._DEBUG_PLOT_VAD:
                # NOTE(review): `vad_output` may be unbound here when only the
                # WEBRTC_APM VAD ran; also, plt.hold() was removed in modern
                # matplotlib, so this debug path needs an old version.
                frame_times_s = lambda num_frames, frame_size_ms: np.arange(
                    num_frames).astype(np.float32) * frame_size_ms / 1000.0
                level = e.GetLevel()
                t_level = frame_times_s(num_frames=len(level),
                                        frame_size_ms=e.GetLevelFrameSizeMs())
                t_vad = frame_times_s(num_frames=len(vad_output),
                                      frame_size_ms=e.GetVadFrameSizeMs())
                import matplotlib.pyplot as plt
                plt.figure()
                plt.hold(True)
                plt.plot(t_level, level)
                plt.plot(t_vad, vad_output * np.max(level), '.')
                plt.show()

    def testSaveLoad(self):
        """Checks that saved annotations round-trip through the .npz file."""
        e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
        e.Extract(self._wav_file_path)
        e.Save(self._tmp_path, "fake-annotation")

        data = np.load(
            os.path.join(
                self._tmp_path,
                e.GetOutputFileNameTemplate().format("fake-annotation")))
        np.testing.assert_array_equal(e.GetLevel(), data['level'])
        self.assertEqual(np.float32, data['level'].dtype)
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
            data['vad_energy_output'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
            data['vad_output'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
            data['vad_probs'])
        np.testing.assert_array_equal(
            e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
            data['vad_rms'])
        self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
        self.assertEqual(np.float64, data['vad_probs'].dtype)
        self.assertEqual(np.float64, data['vad_rms'].dtype)

    def testEmptyExternalShouldNotCrash(self):
        """An empty external-VAD dict must be accepted for any VAD type."""
        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            annotations.AudioAnnotationsExtractor(vad_type_value, {})

    def testFakeExternalSaveLoad(self):
        """Checks save/load with a fake external VAD for every VAD type."""
        def FakeExternalFactory():
            return external_vad.ExternalVad(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'fake_external_vad.py'), 'fake')

        for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
            e = annotations.AudioAnnotationsExtractor(
                vad_type_value, {'fake': FakeExternalFactory()})
            e.Extract(self._wav_file_path)
            e.Save(self._tmp_path, annotation_name="fake-annotation")
            data = np.load(
                os.path.join(
                    self._tmp_path,
                    e.GetOutputFileNameTemplate().format("fake-annotation")))
            self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
            np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
                                           data['extvad_conf-fake'])

View File

@ -1,96 +0,0 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
#include <array>
#include <fstream>
#include <memory>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/wav_file.h"
#include "modules/audio_processing/vad/voice_activity_detector.h"
#include "rtc_base/logging.h"
ABSL_FLAG(std::string, i, "", "Input wav file");
ABSL_FLAG(std::string, o_probs, "", "VAD probabilities output file");
ABSL_FLAG(std::string, o_rms, "", "VAD output file");
namespace webrtc {
namespace test {
namespace {
// Frame duration processed per VAD call.
constexpr uint8_t kAudioFrameLengthMilliseconds = 10;
// Highest accepted input sample rate (Hz).
constexpr int kMaxSampleRate = 48000;
// Upper bound for the per-frame sample buffer, derived from the two
// constants above.
constexpr size_t kMaxFrameLen =
    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
// Reads a mono wav file, runs the APM VoiceActivityDetector over consecutive
// 10 ms frames, and writes the chunk-wise voice probabilities and RMS values
// as raw 8-byte doubles to the two output files.
// Returns 0 on success, 1 on invalid input or output error.
int main(int argc, char* argv[]) {
  absl::ParseCommandLine(argc, argv);
  const std::string input_file = absl::GetFlag(FLAGS_i);
  const std::string output_probs_file = absl::GetFlag(FLAGS_o_probs);
  const std::string output_file = absl::GetFlag(FLAGS_o_rms);

  // The output format assumes 8-byte doubles; check once at compile time
  // instead of re-checking at runtime on every processed frame.
  static_assert(sizeof(double) == 8, "Unsupported double size.");

  // Open wav input file and check properties.
  WavReader wav_reader(input_file);
  if (wav_reader.num_channels() != 1) {
    RTC_LOG(LS_ERROR) << "Only mono wav files supported";
    return 1;
  }
  if (wav_reader.sample_rate() > kMaxSampleRate) {
    RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
                      << ")";
    return 1;
  }
  const size_t audio_frame_len = rtc::CheckedDivExact(
      kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
  if (audio_frame_len > kMaxFrameLen) {
    RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
    return 1;
  }

  // Create output files; fail early if either cannot be opened.
  std::ofstream out_probs_file(output_probs_file, std::ofstream::binary);
  std::ofstream out_rms_file(output_file, std::ofstream::binary);
  if (!out_probs_file.is_open() || !out_rms_file.is_open()) {
    RTC_LOG(LS_ERROR) << "Cannot create the output file(s).";
    return 1;
  }

  // Run VAD and write decisions.
  VoiceActivityDetector vad;
  std::array<int16_t, kMaxFrameLen> samples;
  while (true) {
    // Process frame; stop at EOF or on an incomplete trailing frame.
    const auto read_samples =
        wav_reader.ReadSamples(audio_frame_len, samples.data());
    if (read_samples < audio_frame_len) {
      break;
    }
    vad.ProcessChunk(samples.data(), audio_frame_len, wav_reader.sample_rate());

    // Write output: one probability and one RMS value per processed chunk.
    auto probs = vad.chunkwise_voice_probabilities();
    auto rms = vad.chunkwise_rms();
    RTC_CHECK_EQ(probs.size(), rms.size());
    for (const auto& p : probs) {
      out_probs_file.write(reinterpret_cast<const char*>(&p), 8);
    }
    for (const auto& r : rms) {
      out_rms_file.write(reinterpret_cast<const char*>(&r), 8);
    }
  }
  out_probs_file.close();
  out_rms_file.close();

  return 0;
}
} // namespace
} // namespace test
} // namespace webrtc
// Thin global entry point; all the work happens in webrtc::test::main().
int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}

View File

@ -1,100 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Class implementing a wrapper for APM simulators.
"""
import cProfile
import logging
import os
import subprocess
from . import data_access
from . import exceptions
class AudioProcWrapper(object):
    """Wrapper for APM simulators.

    Runs an audioproc_f-style simulator binary with a configuration loaded
    from a JSON file and profiles its running time.
    """

    # Default simulator location: 'audioproc_f' next to the parent directory
    # of the working directory (resolved relative to the CWD at import time).
    DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
        os.path.join(os.pardir, 'audioproc_f'))
    # Name of the audio file written by the simulator.
    OUTPUT_FILENAME = 'output.wav'

    def __init__(self, simulator_bin_path):
        """Ctor.

        Args:
          simulator_bin_path: path to the APM simulator binary.
        """
        self._simulator_bin_path = simulator_bin_path
        self._config = None
        self._output_signal_filepath = None

        # Profiler instance to measure running time.
        self._profiler = cProfile.Profile()

    @property
    def output_filepath(self):
        # Path of the simulator output file (None before `Run` is called).
        return self._output_signal_filepath

    def Run(self,
            config_filepath,
            capture_input_filepath,
            output_path,
            render_input_filepath=None):
        """Runs APM simulator.

        Args:
          config_filepath: path to the configuration file specifying the
                           arguments for the APM simulator.
          capture_input_filepath: path to the capture audio track input file
                                  (aka forward or near-end).
          output_path: path of the audio track output file.
          render_input_filepath: path to the render audio track input file
                                 (aka reverse or far-end).

        Raises:
          exceptions.FileNotFoundError: if an input file is missing.
        """
        # Init.
        self._output_signal_filepath = os.path.join(output_path,
                                                    self.OUTPUT_FILENAME)
        profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')

        # Skip if the output has already been generated (cached result).
        if os.path.exists(self._output_signal_filepath) and os.path.exists(
                profiling_stats_filepath):
            return

        # Load configuration.
        self._config = data_access.AudioProcConfigFile.Load(config_filepath)

        # Set remaining parameters.
        if not os.path.exists(capture_input_filepath):
            raise exceptions.FileNotFoundError(
                'cannot find capture input file')
        self._config['-i'] = capture_input_filepath
        self._config['-o'] = self._output_signal_filepath
        if render_input_filepath is not None:
            if not os.path.exists(render_input_filepath):
                raise exceptions.FileNotFoundError(
                    'cannot find render input file')
            self._config['-ri'] = render_input_filepath

        # Build arguments list; config values of None become bare flags.
        args = [self._simulator_bin_path]
        for param_name in self._config:
            args.append(param_name)
            if self._config[param_name] is not None:
                args.append(str(self._config[param_name]))
        logging.debug(' '.join(args))

        # Run.
        # NOTE(review): the simulator's exit code is ignored; a failed run is
        # only detectable through missing or invalid output files.
        self._profiler.enable()
        subprocess.call(args)
        self._profiler.disable()

        # Save profiling stats.
        self._profiler.dump_stats(profiling_stats_filepath)

View File

@ -1,243 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Imports a filtered subset of the scores and configurations computed
by apm_quality_assessment.py into a pandas data frame.
"""
import argparse
import glob
import logging
import os
import re
import sys
try:
import pandas as pd
except ImportError:
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
from . import data_access as data_access
from . import simulation as sim
# Compiled regular expressions used to extract score descriptors.
# Each expression matches one path component of the form '<prefix><name>'
# produced by the simulator; the capture group extracts '<name>'.
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
                            r'(.+)')
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
                             r'(.+)')
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
                              r'(.+)')
RE_TEST_DATA_GEN_NAME = re.compile(
    sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
RE_TEST_DATA_GEN_PARAMS = re.compile(
    sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
# The score file name also carries a file extension, captured separately.
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
                           r'(.+)(\..+)')
def InstanceArgumentsParser():
    """Arguments parser factory.

    Returns:
      An argparse.ArgumentParser with the output-directory option plus one
      optional regular-expression filter per score descriptor field.
    """
    parser = argparse.ArgumentParser(
        description=('Override this description in a user script by changing'
                     ' `parser.description` of the returned parser.'))

    parser.add_argument('-o',
                        '--output_dir',
                        required=True,
                        help=('the same base path used with the '
                              'apm_quality_assessment tool'))

    # Optional regex filters: (short flag, long flag, help text).
    regex_filter_options = (
        ('-c', '--config_names',
         'regular expression to filter the APM configuration names'),
        ('-i', '--capture_names',
         'regular expression to filter the capture signal names'),
        ('-r', '--render_names',
         'regular expression to filter the render signal names'),
        ('-e', '--echo_simulator_names',
         'regular expression to filter the echo simulator names'),
        ('-t', '--test_data_generators',
         'regular expression to filter the test data generator names'),
        ('-s', '--eval_scores',
         'regular expression to filter the evaluation score names'),
    )
    for short_flag, long_flag, help_text in regex_filter_options:
        parser.add_argument(short_flag,
                            long_flag,
                            type=re.compile,
                            help=help_text)

    return parser
def _GetScoreDescriptors(score_filepath):
    """Extracts a score descriptor from the given score file path.

    Args:
      score_filepath: path to the score file.

    Returns:
      A tuple of strings (APM configuration name, capture audio track name,
      render audio track name, echo simulator name, test data generator name,
      test data generator parameters as string, evaluation score name).
    """
    # The last seven path components encode the descriptor fields, in the
    # order matched by the regular expressions below.
    fields = score_filepath.split(os.sep)[-7:]
    field_regexps = (RE_CONFIG_NAME, RE_CAPTURE_NAME, RE_RENDER_NAME,
                     RE_ECHO_SIM_NAME, RE_TEST_DATA_GEN_NAME,
                     RE_TEST_DATA_GEN_PARAMS, RE_SCORE_NAME)
    return tuple(
        regexp.match(field).groups(0)[0]
        for field, regexp in zip(fields, field_regexps))
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
test_data_gen_name, score_name, args):
"""Decides whether excluding a score.
A set of optional regular expressions in args is used to determine if the
score should be excluded (depending on its |*_name| descriptors).
Args:
config_name: APM configuration name.
capture_name: capture audio track name.
render_name: render audio track name.
echo_simulator_name: echo simulator name.
test_data_gen_name: test data generator name.
score_name: evaluation score name.
args: parsed arguments.
Returns:
A boolean.
"""
value_regexpr_pairs = [
(config_name, args.config_names),
(capture_name, args.capture_names),
(render_name, args.render_names),
(echo_simulator_name, args.echo_simulator_names),
(test_data_gen_name, args.test_data_generators),
(score_name, args.eval_scores),
]
# Score accepted if each value matches the corresponding regular expression.
for value, regexpr in value_regexpr_pairs:
if regexpr is None:
continue
if not regexpr.match(value):
return True
return False
def FindScores(src_path, args):
    """Given a search path, find scores and return a DataFrame object.

    Args:
      src_path: Search path pattern.
      args: parsed arguments.

    Returns:
      A DataFrame object with one row per accepted score file.
    """
    column_names = (
        'clean_capture_input_filepath',
        'echo_free_capture_filepath',
        'echo_filepath',
        'render_filepath',
        'capture_filepath',
        'apm_output_filepath',
        'apm_reference_filepath',
        'apm_config',
        'capture',
        'render',
        'echo_simulator',
        'test_data_gen',
        'test_data_gen_params',
        'eval_score_name',
        'score',
    )
    rows = []
    for score_filepath in glob.iglob(src_path):
        # Decode the descriptor fields encoded in the score file path.
        (config_name, capture_name, render_name, echo_simulator_name,
         test_data_gen_name, test_data_gen_params,
         score_name) = _GetScoreDescriptors(score_filepath)
        # Drop the score when it does not pass the user-defined filters.
        if _ExcludeScore(config_name, capture_name, render_name,
                         echo_simulator_name, test_data_gen_name, score_name,
                         args):
            logging.info('ignored score: %s %s %s %s %s %s', config_name,
                         capture_name, render_name, echo_simulator_name,
                         test_data_gen_name, score_name)
            continue
        # Load the audio test data metadata stored next to the score file, then
        # the score itself.
        metadata = data_access.Metadata.LoadAudioTestDataPaths(
            os.path.split(score_filepath)[0])
        score = data_access.ScoreFile.Load(score_filepath)
        # One row per score, with its descriptor fields.
        rows.append((
            metadata['clean_capture_input_filepath'],
            metadata['echo_free_capture_filepath'],
            metadata['echo_filepath'],
            metadata['render_filepath'],
            metadata['capture_filepath'],
            metadata['apm_output_filepath'],
            metadata['apm_reference_filepath'],
            config_name,
            capture_name,
            render_name,
            echo_simulator_name,
            test_data_gen_name,
            test_data_gen_params,
            score_name,
            score,
        ))
    return pd.DataFrame(data=rows, columns=column_names)
def ConstructSrcPath(args):
    """Builds the glob pattern matching every score file under the output dir."""
    # One wildcarded path component per simulation dimension, in the fixed
    # order used by the simulator when writing its output tree.
    prefixes = (
        sim.ApmModuleSimulator.GetPrefixApmConfig(),
        sim.ApmModuleSimulator.GetPrefixCapture(),
        sim.ApmModuleSimulator.GetPrefixRender(),
        sim.ApmModuleSimulator.GetPrefixEchoSimulator(),
        sim.ApmModuleSimulator.GetPrefixTestDataGenerator(),
        sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters(),
        sim.ApmModuleSimulator.GetPrefixScore(),
    )
    return os.path.join(args.output_dir,
                        *[prefix + '*' for prefix in prefixes])

View File

@ -1,154 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Data access utility functions and classes.
"""
import json
import os
def MakeDirectory(path):
    """Makes a directory recursively without raising exceptions if existing.

    Args:
      path: path to the directory to be created.
    """
    # `exist_ok` avoids the check-then-create race of calling
    # os.path.exists() first: another process could create (or remove) `path`
    # between the check and os.makedirs().
    os.makedirs(path, exist_ok=True)
class Metadata(object):
    """Data access class to save and load metadata."""

    # Suffix appended to a data file path to obtain its metadata file path.
    _GENERIC_METADATA_SUFFIX = '.mdata'
    # Name of the per-directory audio test data metadata file.
    _AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'

    def __init__(self):
        pass

    @classmethod
    def LoadFileMetadata(cls, filepath):
        """Loads generic metadata linked to a file.

        Args:
          filepath: path to the metadata file to read.

        Returns:
          A dict.
        """
        metadata_filepath = filepath + cls._GENERIC_METADATA_SUFFIX
        with open(metadata_filepath) as f:
            return json.load(f)

    @classmethod
    def SaveFileMetadata(cls, filepath, metadata):
        """Saves generic metadata linked to a file.

        Args:
          filepath: path to the metadata file to write.
          metadata: a dict.
        """
        metadata_filepath = filepath + cls._GENERIC_METADATA_SUFFIX
        with open(metadata_filepath, 'w') as f:
            json.dump(metadata, f)

    @classmethod
    def LoadAudioTestDataPaths(cls, metadata_path):
        """Loads the input and the reference audio track paths.

        Args:
          metadata_path: path to the directory containing the metadata file.

        Returns:
          A dict mapping audio track names to file paths.
        """
        metadata_filepath = os.path.join(metadata_path,
                                         cls._AUDIO_TEST_DATA_FILENAME)
        with open(metadata_filepath) as f:
            return json.load(f)

    @classmethod
    def SaveAudioTestDataPaths(cls, output_path, **filepaths):
        """Saves the input and the reference audio track paths.

        Args:
          output_path: path to the directory containing the metadata file.

        Keyword Args:
          filepaths: collection of audio track file paths to save.
        """
        output_filepath = os.path.join(output_path,
                                       cls._AUDIO_TEST_DATA_FILENAME)
        with open(output_filepath, 'w') as f:
            json.dump(filepaths, f)
class AudioProcConfigFile(object):
    """Data access to load/save APM simulator argument lists.

    The arguments stored in the config files are used to control the APM flags.
    """

    def __init__(self):
        pass

    @classmethod
    def Load(cls, filepath):
        """Loads a configuration file for an APM simulator.

        Args:
          filepath: path to the configuration file.

        Returns:
          A dict containing the configuration.
        """
        with open(filepath) as f:
            return json.loads(f.read())

    @classmethod
    def Save(cls, filepath, config):
        """Saves a configuration file for an APM simulator.

        Args:
          filepath: path to the configuration file.
          config: a dict containing the configuration.
        """
        with open(filepath, 'w') as f:
            f.write(json.dumps(config))
class ScoreFile(object):
    """Data access class to save and load float scalar scores."""

    def __init__(self):
        pass

    @classmethod
    def Load(cls, filepath):
        """Loads a score from file.

        Args:
          filepath: path to the score file.

        Returns:
          A float encoding the score.
        """
        with open(filepath) as f:
            first_line = f.readline()
        return float(first_line.strip())

    @classmethod
    def Save(cls, filepath, score):
        """Saves a score into a file.

        Args:
          filepath: path to the score file.
          score: float encoding the score.
        """
        with open(filepath, 'w') as f:
            f.write('{0:f}\n'.format(score))

View File

@ -1,136 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation module.
"""
import hashlib
import os
from . import signal_processing
class EchoPathSimulator(object):
    """Abstract class for the echo path simulators.

    In general, an echo path simulator is a function of the render signal and
    simulates the propagation of the latter into the microphone (e.g., due to
    mechanical or electrical paths).
    """

    # Name under which a concrete simulator registers itself.
    NAME = None
    # Registry shared by all subclasses, populated via RegisterClass.
    REGISTERED_CLASSES = {}

    def __init__(self):
        pass

    def Simulate(self, output_path):
        """Creates the echo signal and stores it in an audio file (abstract method).

        Args:
          output_path: Path in which any output can be saved.

        Returns:
          Path to the generated audio track file or None if no echo is present.
        """
        raise NotImplementedError()

    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers an EchoPathSimulator implementation.

        Decorator to automatically register the classes that extend
        EchoPathSimulator.
        Example usage:

        @EchoPathSimulator.RegisterClass
        class NoEchoPathSimulator(EchoPathSimulator):
          pass
        """
        cls.REGISTERED_CLASSES.update({class_to_register.NAME:
                                       class_to_register})
        return class_to_register
@EchoPathSimulator.RegisterClass
class NoEchoPathSimulator(EchoPathSimulator):
    """Simulates absence of echo."""

    NAME = 'noecho'

    def __init__(self):
        super(NoEchoPathSimulator, self).__init__()

    def Simulate(self, output_path):
        # No echo, hence no audio track to produce.
        return None
@EchoPathSimulator.RegisterClass
class LinearEchoPathSimulator(EchoPathSimulator):
    """Simulates linear echo path.

    This class applies a given impulse response to the render input and then it
    sums the signal to the capture input signal.
    """

    NAME = 'linear'

    def __init__(self, render_input_filepath, impulse_response):
        """
        Args:
          render_input_filepath: Render audio track file.
          impulse_response: list or numpy vector of float values.
        """
        super(LinearEchoPathSimulator, self).__init__()
        self._render_input_filepath = render_input_filepath
        self._impulse_response = impulse_response

    def Simulate(self, output_path):
        """Simulates linear echo path."""
        # Derive a stable file name from a hash of the impulse response so
        # that the same response always maps to the same cached file.
        response_digest = hashlib.sha256(
            str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
        echo_filepath = os.path.join(
            output_path, 'linear_echo_{}.wav'.format(response_digest))
        if os.path.exists(echo_filepath):
            # Cached echo audio track already available.
            return echo_filepath
        # Convolve the render signal with the impulse response and cache it.
        render = signal_processing.SignalProcessingUtils.LoadWav(
            self._render_input_filepath)
        echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
            render, self._impulse_response)
        signal_processing.SignalProcessingUtils.SaveWav(echo_filepath, echo)
        return echo_filepath
@EchoPathSimulator.RegisterClass
class RecordedEchoPathSimulator(EchoPathSimulator):
    """Uses recorded echo.

    This class uses the clean capture input file name to build the file name of
    the corresponding recording containing echo (a predefined suffix is used).
    Such a file is expected to be already existing.
    """

    NAME = 'recorded'

    _FILE_NAME_SUFFIX = '_echo'

    def __init__(self, render_input_filepath):
        super(RecordedEchoPathSimulator, self).__init__()
        self._render_input_filepath = render_input_filepath

    def Simulate(self, output_path):
        """Uses recorded echo path."""
        # Build "<name>_echo<ext>" next to the render input file.
        directory, file_name_ext = os.path.split(self._render_input_filepath)
        stem, extension = os.path.splitext(file_name_ext)
        echo_filepath = os.path.join(
            directory,
            '{}{}{}'.format(stem, self._FILE_NAME_SUFFIX, extension))
        assert os.path.exists(echo_filepath), (
            'cannot find the echo audio track file {}'.format(echo_filepath))
        return echo_filepath

View File

@ -1,48 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation factory module.
"""
import numpy as np
from . import echo_path_simulation
class EchoPathSimulatorFactory(object):
    """Factory that instantiates registered EchoPathSimulator classes."""

    # TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
    # realistic impulse response.
    _LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])

    def __init__(self):
        pass

    @classmethod
    def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
        """Creates an EchoPathSimulator instance given a class object.

        Args:
          echo_path_simulator_class: EchoPathSimulator class object (not an
                                     instance).
          render_input_filepath: Path to the render audio track file.

        Returns:
          An EchoPathSimulator instance.
        """
        # Every simulator except the no-echo one needs a render input file.
        assert render_input_filepath is not None or (
            echo_path_simulator_class ==
            echo_path_simulation.NoEchoPathSimulator)
        if echo_path_simulator_class == (
                echo_path_simulation.NoEchoPathSimulator):
            return echo_path_simulation.NoEchoPathSimulator()
        if echo_path_simulator_class == (
                echo_path_simulation.LinearEchoPathSimulator):
            return echo_path_simulation.LinearEchoPathSimulator(
                render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
        return echo_path_simulator_class(render_input_filepath)

View File

@ -1,82 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the echo path simulation module.
"""
import shutil
import os
import tempfile
import unittest
import pydub
from . import echo_path_simulation
from . import echo_path_simulation_factory
from . import signal_processing
class TestEchoPathSimulators(unittest.TestCase):
    """Unit tests for the echo path simulation classes."""

    def setUp(self):
        """Creates temporary data."""
        self._tmp_path = tempfile.mkdtemp()
        # Create and save white noise.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        white_noise = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                silence))
        self._audio_track_num_samples = (
            signal_processing.SignalProcessingUtils.CountSamples(white_noise))
        self._audio_track_filepath = os.path.join(self._tmp_path,
                                                  'white_noise.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._audio_track_filepath, white_noise)
        # Copy the white noise audio track file; the copy is consumed by
        # echo_path_simulation.RecordedEchoPathSimulator.
        shutil.copy(self._audio_track_filepath,
                    os.path.join(self._tmp_path, 'white_noise_echo.wav'))

    def tearDown(self):
        """Recursively deletes temporary folders."""
        shutil.rmtree(self._tmp_path)

    def testRegisteredClasses(self):
        # At least one registered echo path simulator is expected.
        registered_classes = (
            echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)
        # Instance factory.
        factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
        # Exercise every registered echo path simulator.
        for simulator_name, simulator_class in registered_classes.items():
            simulator = factory.GetInstance(
                echo_path_simulator_class=simulator_class,
                render_input_filepath=self._audio_track_filepath)
            echo_filepath = simulator.Simulate(self._tmp_path)
            if echo_filepath is None:
                # Only the no-echo simulator may produce no file.
                self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
                                 simulator_name)
                continue
            # The echo audio track file must exist and be at least as long as
            # the render audio track.
            self.assertTrue(os.path.exists(echo_filepath))
            echo = signal_processing.SignalProcessingUtils.LoadWav(
                echo_filepath)
            self.assertGreaterEqual(
                signal_processing.SignalProcessingUtils.CountSamples(echo),
                self._audio_track_num_samples)

View File

@ -1,427 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluation score abstract class and implementations.
"""
from __future__ import division
import logging
import os
import re
import subprocess
import sys
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import data_access
from . import exceptions
from . import signal_processing
class EvaluationScore(object):
    """Abstract evaluation score.

    Subclasses compute a scalar score from a pair of reference/tested audio
    tracks and persist it via data_access.ScoreFile. Implementations register
    themselves through the RegisterClass decorator.
    """

    # Name under which a concrete score registers itself.
    NAME = None
    # Registry shared by all subclasses, populated via RegisterClass.
    REGISTERED_CLASSES = {}

    def __init__(self, score_filename_prefix):
        self._score_filename_prefix = score_filename_prefix
        self._input_signal_metadata = None
        self._reference_signal = None
        self._reference_signal_filepath = None
        self._tested_signal = None
        self._tested_signal_filepath = None
        self._output_filepath = None
        self._score = None
        self._render_signal_filepath = None

    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers an EvaluationScore implementation.

        Decorator to automatically register the classes that extend
        EvaluationScore.
        Example usage:

        @EvaluationScore.RegisterClass
        class AudioLevelScore(EvaluationScore):
          pass
        """
        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
        return class_to_register

    @property
    def output_filepath(self):
        return self._output_filepath

    @property
    def score(self):
        return self._score

    def SetInputSignalMetadata(self, metadata):
        """Sets input signal metadata.

        Args:
          metadata: dict instance.
        """
        self._input_signal_metadata = metadata

    def SetReferenceSignalFilepath(self, filepath):
        """Sets the path to the audio track used as reference signal.

        Args:
          filepath: path to the reference audio track.
        """
        self._reference_signal_filepath = filepath

    def SetTestedSignalFilepath(self, filepath):
        """Sets the path to the audio track used as test signal.

        Args:
          filepath: path to the test audio track.
        """
        self._tested_signal_filepath = filepath

    def SetRenderSignalFilepath(self, filepath):
        """Sets the path to the audio track used as render signal.

        Args:
          filepath: path to the render audio track.
        """
        self._render_signal_filepath = filepath

    def Run(self, output_path):
        """Extracts the score for the set test data pair.

        Args:
          output_path: path to the directory where the output is written.
        """
        self._output_filepath = os.path.join(
            output_path, self._score_filename_prefix + self.NAME + '.txt')
        try:
            # If the score has already been computed, load it. The loaded value
            # must be assigned to self._score, otherwise the `score` property
            # would wrongly return None for cached scores.
            self._score = self._LoadScore()
            logging.debug('score found and loaded')
        except IOError:
            # Compute the score.
            logging.debug('score not found, compute')
            self._Run(output_path)

    def _Run(self, output_path):
        # Abstract method.
        raise NotImplementedError()

    def _LoadReferenceSignal(self):
        assert self._reference_signal_filepath is not None
        self._reference_signal = (
            signal_processing.SignalProcessingUtils.LoadWav(
                self._reference_signal_filepath))

    def _LoadTestedSignal(self):
        assert self._tested_signal_filepath is not None
        self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
            self._tested_signal_filepath)

    def _LoadScore(self):
        return data_access.ScoreFile.Load(self._output_filepath)

    def _SaveScore(self):
        return data_access.ScoreFile.Save(self._output_filepath, self._score)
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
    """Peak audio level score.

    Defined as the difference between the peak audio level of the tested and
    the reference signals.

    Unit: dB
    Ideal: 0 dB
    Worst case: +/-inf dB
    """

    NAME = 'audio_level_peak'

    def __init__(self, score_filename_prefix):
        EvaluationScore.__init__(self, score_filename_prefix)

    def _Run(self, output_path):
        self._LoadReferenceSignal()
        self._LoadTestedSignal()
        # Difference of the peak levels (dBFS) of the two signals.
        tested_level = self._tested_signal.dBFS
        reference_level = self._reference_signal.dBFS
        self._score = tested_level - reference_level
        self._SaveScore()
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
    """Mean audio level score.

    Defined as the difference between the mean audio level of the tested and
    the reference signals.

    Unit: dB
    Ideal: 0 dB
    Worst case: +/-inf dB
    """

    NAME = 'audio_level_mean'

    def __init__(self, score_filename_prefix):
        EvaluationScore.__init__(self, score_filename_prefix)

    def _Run(self, output_path):
        self._LoadReferenceSignal()
        self._LoadTestedSignal()
        # Number of whole seconds in the shortest of the two signals; signal
        # lengths are in milliseconds.
        seconds = min(len(self._tested_signal),
                      len(self._reference_signal)) // 1000
        if seconds == 0:
            # Signals shorter than one second: fall back to the level
            # difference over the whole signals (also avoids a division by
            # zero below).
            self._score = (self._tested_signal.dBFS -
                           self._reference_signal.dBFS)
        else:
            # Average the per-second level differences. Note: the previous
            # implementation stepped by `seconds` ms instead of 1000 ms, which
            # used tiny windows and only covered the start of the signals.
            dbfs_diffs_sum = 0.0
            for t in range(seconds):
                t0 = t * 1000
                t1 = t0 + 1000
                dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
                                   self._reference_signal[t0:t1].dBFS)
            self._score = dbfs_diffs_sum / float(seconds)
        self._SaveScore()
@EvaluationScore.RegisterClass
class EchoMetric(EvaluationScore):
    """Echo score.

    Proportion of detected echo.

    Unit: ratio
    Ideal: 0
    Worst case: 1
    """

    NAME = 'echo_metric'

    def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
        EvaluationScore.__init__(self, score_filename_prefix)
        # Echo detector tool binary file path.
        self._echo_detector_bin_filepath = echo_detector_bin_filepath
        if not os.path.exists(self._echo_detector_bin_filepath):
            logging.error('cannot find EchoMetric tool binary file')
            raise exceptions.FileNotFoundError()
        # Directory of the tool; used as cwd when invoking the subprocess.
        self._echo_detector_bin_path, _ = os.path.split(
            self._echo_detector_bin_filepath)

    def _Run(self, output_path):
        # Runs the external echo detector tool and parses its output file.
        echo_detector_out_filepath = os.path.join(output_path,
                                                  'echo_detector.out')
        # Remove a stale output file so the parsed score cannot come from a
        # previous run.
        if os.path.exists(echo_detector_out_filepath):
            os.unlink(echo_detector_out_filepath)
        logging.debug("Render signal filepath: %s",
                      self._render_signal_filepath)
        if not os.path.exists(self._render_signal_filepath):
            logging.error(
                "Render input required for evaluating the echo metric.")
        args = [
            self._echo_detector_bin_filepath, '--output_file',
            echo_detector_out_filepath, '--', '-i',
            self._tested_signal_filepath, '-ri', self._render_signal_filepath
        ]
        logging.debug(' '.join(args))
        subprocess.call(args, cwd=self._echo_detector_bin_path)
        # Parse Echo detector tool output and extract the score.
        self._score = self._ParseOutputFile(echo_detector_out_filepath)
        self._SaveScore()

    @classmethod
    def _ParseOutputFile(cls, echo_metric_file_path):
        """
        Parses the echo detector tool output (a single float).

        Args:
          echo_metric_file_path: path to the echo detector tool output file.

        Returns:
          The score as a number in [0, 1].
        """
        with open(echo_metric_file_path) as f:
            return float(f.read())
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
    """POLQA score.

    See http://www.polqa.info/.

    Unit: MOS
    Ideal: 4.5
    Worst case: 1.0
    """

    NAME = 'polqa'

    def __init__(self, score_filename_prefix, polqa_bin_filepath):
        EvaluationScore.__init__(self, score_filename_prefix)
        # POLQA binary file path.
        self._polqa_bin_filepath = polqa_bin_filepath
        if not os.path.exists(self._polqa_bin_filepath):
            logging.error('cannot find POLQA tool binary file')
            raise exceptions.FileNotFoundError()
        # Path to the POLQA directory with binary and license files; used as
        # cwd when invoking the subprocess.
        self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)

    def _Run(self, output_path):
        # Runs the external POLQA tool and parses its table-formatted output.
        polqa_out_filepath = os.path.join(output_path, 'polqa.out')
        # Remove a stale output file so the parsed score cannot come from a
        # previous run.
        if os.path.exists(polqa_out_filepath):
            os.unlink(polqa_out_filepath)
        args = [
            self._polqa_bin_filepath,
            '-t',
            '-q',
            '-Overwrite',
            '-Ref',
            self._reference_signal_filepath,
            '-Test',
            self._tested_signal_filepath,
            '-LC',
            'NB',
            '-Out',
            polqa_out_filepath,
        ]
        logging.debug(' '.join(args))
        subprocess.call(args, cwd=self._polqa_tool_path)
        # Parse POLQA tool output and extract the score.
        polqa_output = self._ParseOutputFile(polqa_out_filepath)
        self._score = float(polqa_output['PolqaScore'])
        self._SaveScore()

    @classmethod
    def _ParseOutputFile(cls, polqa_out_filepath):
        """
        Parses the POLQA tool output formatted as a table ('-t' option).

        Args:
          polqa_out_filepath: path to the POLQA tool output file.

        Returns:
          A dict.
        """
        data = []
        with open(polqa_out_filepath) as f:
            for line in f:
                line = line.strip()
                if len(line) == 0 or line.startswith('*'):
                    # Ignore comments.
                    continue
                # Read fields.
                data.append(re.split(r'\t+', line))
        # Two rows expected (header and values).
        assert len(data) == 2, 'Cannot parse POLQA output'
        number_of_fields = len(data[0])
        assert number_of_fields == len(data[1])
        # Build and return a dictionary with field names (header) as keys and the
        # corresponding field values as values.
        return {
            data[0][index]: data[1][index]
            for index in range(number_of_fields)
        }
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
    """Total harmonic distorsion plus noise score.

    Total harmonic distorsion plus noise score.
    See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".

    Unit: -.
    Ideal: 0.
    Worst case: +inf
    """

    NAME = 'thd'

    def __init__(self, score_filename_prefix):
        EvaluationScore.__init__(self, score_filename_prefix)
        # Fundamental frequency (Hz) of the pure-tone input; set by
        # _CheckInputSignal() from the input signal metadata.
        self._input_frequency = None

    def _Run(self, output_path):
        # Computes THD+N of the tested signal against the pure-tone input.
        self._CheckInputSignal()
        self._LoadTestedSignal()
        if self._tested_signal.channels != 1:
            raise exceptions.EvaluationScoreException(
                'unsupported number of channels')
        samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
            self._tested_signal)
        # Init.
        num_samples = len(samples)
        duration = len(self._tested_signal) / 1000.0
        scaling = 2.0 / num_samples
        max_freq = self._tested_signal.frame_rate / 2
        f0_freq = float(self._input_frequency)
        t = np.linspace(0, duration, num_samples)
        # Analyze harmonics: magnitude of each harmonic of f0 below Nyquist,
        # estimated via single-bin DFT projections (sin/cos correlations).
        b_terms = []
        n = 1
        while f0_freq * n < max_freq:
            x_n = np.sum(
                samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
            y_n = np.sum(
                samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
            b_terms.append(np.sqrt(x_n**2 + y_n**2))
            n += 1
        # Subtract the estimated fundamental from the signal; the residual is
        # distortion plus noise. NOTE(review): the np.pi factor below looks
        # unusual for an RMS-style estimate — confirm against the THD+N
        # definition before relying on absolute values.
        output_without_fundamental = samples - b_terms[0] * np.sin(
            2.0 * np.pi * f0_freq * t)
        distortion_and_noise = np.sqrt(
            np.sum(output_without_fundamental**2) * np.pi * scaling)
        # TODO(alessiob): Fix or remove if not needed.
        # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
        # TODO(alessiob): Check the range of `thd_plus_noise` and update the class
        # docstring above if accordingly.
        thd_plus_noise = distortion_and_noise / b_terms[0]
        self._score = thd_plus_noise
        self._SaveScore()

    def _CheckInputSignal(self):
        # Check input signal and get properties.
        try:
            if self._input_signal_metadata['signal'] != 'pure_tone':
                raise exceptions.EvaluationScoreException(
                    'The THD score requires a pure tone as input signal')
            self._input_frequency = self._input_signal_metadata['frequency']
            if self._input_signal_metadata[
                    'test_data_gen_name'] != 'identity' or (
                        self._input_signal_metadata['test_data_gen_config'] !=
                        'default'):
                raise exceptions.EvaluationScoreException(
                    'The THD score cannot be used with any test data generator other '
                    'than "identity"')
        except TypeError:
            # Metadata is None (not set).
            raise exceptions.EvaluationScoreException(
                'The THD score requires an input signal with associated metadata'
            )
        except KeyError:
            # Metadata is missing one of the required keys.
            raise exceptions.EvaluationScoreException(
                'Invalid input signal metadata to compute the THD score')

View File

@ -1,55 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""EvaluationScore factory class.
"""
import logging
from . import exceptions
from . import eval_scores
class EvaluationScoreWorkerFactory(object):
    """Factory class used to instantiate evaluation score workers.

    The ctor gets the parameters that are used to instantiate the evaluation
    score workers.
    """

    def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
        self._score_filename_prefix = None
        self._polqa_tool_bin_path = polqa_tool_bin_path
        self._echo_metric_tool_bin_path = echo_metric_tool_bin_path

    def SetScoreFilenamePrefix(self, prefix):
        self._score_filename_prefix = prefix

    def GetInstance(self, evaluation_score_class):
        """Creates an EvaluationScore instance given a class object.

        Args:
          evaluation_score_class: EvaluationScore class object (not an
                                  instance).

        Returns:
          An EvaluationScore instance.
        """
        if self._score_filename_prefix is None:
            raise exceptions.InitializationException(
                'The score file name prefix for evaluation score workers is not set'
            )
        logging.debug('factory producing a %s evaluation score',
                      evaluation_score_class)
        # Scores backed by an external tool also need the tool binary path.
        if evaluation_score_class == eval_scores.PolqaScore:
            return eval_scores.PolqaScore(self._score_filename_prefix,
                                          self._polqa_tool_bin_path)
        if evaluation_score_class == eval_scores.EchoMetric:
            return eval_scores.EchoMetric(self._score_filename_prefix,
                                          self._echo_metric_tool_bin_path)
        return evaluation_score_class(self._score_filename_prefix)

View File

@ -1,137 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the eval_scores module.
"""
import os
import shutil
import tempfile
import unittest
import pydub
from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing
class TestEvalScores(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
fake_tested_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def testRegisteredClasses(self):
# Evaluation score names to exclude (tested separately).
exceptions = ['thd', 'echo_metric']
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Check that there is at least one registered evaluation score worker.
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None))
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
if eval_score_name in exceptions:
continue
# Instance evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Check output.
score = data_access.ScoreFile.Load(
eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
def testTotalHarmonicDistorsionScore(self):
# Init.
pure_tone_freq = 5000.0
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
eval_score_worker.SetInputSignalMetadata({
'signal':
'pure_tone',
'frequency':
pure_tone_freq,
'test_data_gen_name':
'identity',
'test_data_gen_config':
'default',
})
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
# Create 3 test signals: pure tone, pure tone + white noise, white noise
# only.
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
template, pure_tone_freq)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
template)
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
pure_tone, white_noise)
# Compute scores for increasingly distorted pure tone signals.
scores = [None, None, None]
for index, tested_signal in enumerate(
[pure_tone, noisy_tone, white_noise]):
# Save signal.
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
signal_processing.SignalProcessingUtils.SaveWav(
tmp_filepath, tested_signal)
# Compute score.
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
eval_score_worker.Run(self._output_path)
scores[index] = eval_score_worker.score
# Remove output file to avoid caching.
os.remove(eval_score_worker.output_filepath)
# Validate scores (lowest score with a pure tone).
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))

View File

@ -1,57 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluator of the APM module.
"""
import logging
class ApmModuleEvaluator(object):
    """APM evaluator class.
    """

    def __init__(self):
        pass

    @classmethod
    def Run(cls, evaluation_score_workers, apm_input_metadata,
            apm_output_filepath, reference_input_filepath,
            render_input_filepath, output_path):
        """Runs the evaluation.

        Iterates over the given evaluation score workers.

        Args:
          evaluation_score_workers: list of EvaluationScore instances.
          apm_input_metadata: dictionary with metadata of the APM input.
          apm_output_filepath: path to the audio track file with the APM output.
          reference_input_filepath: path to the reference audio track file.
          render_input_filepath: path to the render audio track file.
          output_path: output path.

        Returns:
          A dict of evaluation score name and score pairs.
        """
        # Init.
        scores = {}
        for evaluation_score_worker in evaluation_score_workers:
            logging.info(' computing <%s> score',
                         evaluation_score_worker.NAME)
            # Feed the worker with every input it may need before running it.
            evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
            evaluation_score_worker.SetReferenceSignalFilepath(
                reference_input_filepath)
            evaluation_score_worker.SetTestedSignalFilepath(
                apm_output_filepath)
            evaluation_score_worker.SetRenderSignalFilepath(
                render_input_filepath)
            evaluation_score_worker.Run(output_path)
            scores[
                evaluation_score_worker.NAME] = evaluation_score_worker.score
        return scores

View File

@ -1,45 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Exception classes.
"""
class FileNotFoundError(Exception):
    """File not found exception.

    NOTE(review): shadows the Python 3 built-in of the same name; renaming
    would break existing callers, so it is kept as-is.
    """
    pass


class SignalProcessingException(Exception):
    """Signal processing exception.
    """
    pass


class InputMixerException(Exception):
    """Input mixer exception.
    """
    pass


class InputSignalCreatorException(Exception):
    """Input signal creator exception.
    """
    pass


class EvaluationScoreException(Exception):
    """Evaluation score exception.
    """
    pass


class InitializationException(Exception):
    """Initialization exception.
    """
    pass

View File

@ -1,426 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import functools
import hashlib
import logging
import os
import re
import sys
try:
import csscompressor
except ImportError:
logging.critical(
'Cannot import the third-party Python package csscompressor')
sys.exit(1)
try:
import jsmin
except ImportError:
logging.critical('Cannot import the third-party Python package jsmin')
sys.exit(1)
class HtmlExport(object):
    """HTML exporter class for APM quality scores.

    Builds a single HTML page (Material Design Lite based) presenting the
    scores in a DataFrame: one tab per evaluation score name, one table per
    tab, and a per-cell "score stats inspector" dialog with the individual
    scores and links to the involved audio files.
    """

    _NEW_LINE = '\n'

    # CSS and JS file paths (siblings of this module).
    _PATH = os.path.dirname(os.path.realpath(__file__))
    _CSS_FILEPATH = os.path.join(_PATH, 'results.css')
    _CSS_MINIFIED = True
    _JS_FILEPATH = os.path.join(_PATH, 'results.js')
    _JS_MINIFIED = True

    def __init__(self, output_filepath):
        self._scores_data_frame = None
        self._output_filepath = output_filepath

    def Export(self, scores_data_frame):
        """Exports scores into an HTML file.

        Args:
          scores_data_frame: DataFrame instance.
        """
        self._scores_data_frame = scores_data_frame
        html = [
            '<html>',
            self._BuildHeader(),
            ('<script type="text/javascript">'
             '(function () {'
             'window.addEventListener(\'load\', function () {'
             'var inspector = new AudioInspector();'
             '});'
             '})();'
             '</script>'), '<body>',
            self._BuildBody(), '</body>', '</html>'
        ]
        self._Save(self._output_filepath, self._NEW_LINE.join(html))

    def _BuildHeader(self):
        """Builds the <head> section of the HTML file.

        The header contains the page title and either embedded or linked CSS
        and JS files.

        Returns:
          A string with <head>...</head> HTML.
        """
        html = ['<head>', '<title>Results</title>']

        # Add Material Design hosted libs.
        html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
                    'css?family=Roboto:300,400,500,700" type="text/css">')
        html.append('<link rel="stylesheet" href="https://fonts.googleapis.com/'
                    'icon?family=Material+Icons">')
        html.append('<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
                    'material.indigo-pink.min.css">')
        html.append('<script defer src="https://code.getmdl.io/1.3.0/'
                    'material.min.js"></script>')

        # Embed custom JavaScript and CSS files.
        html.append('<script>')
        with open(self._JS_FILEPATH) as f:
            html.append(
                jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
                    f.read().rstrip()))
        html.append('</script>')
        html.append('<style>')
        with open(self._CSS_FILEPATH) as f:
            html.append(
                csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
                    f.read().rstrip()))
        html.append('</style>')
        html.append('</head>')

        return self._NEW_LINE.join(html)

    def _BuildBody(self):
        """Builds the content of the <body> section."""
        score_names = self._scores_data_frame[
            'eval_score_name'].drop_duplicates().values.tolist()

        html = [
            ('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
             'mdl-layout--fixed-tabs">'),
            '<header class="mdl-layout__header">',
            '<div class="mdl-layout__header-row">',
            '<span class="mdl-layout-title">APM QA results ({})</span>'.format(
                self._output_filepath),
            '</div>',
        ]

        # Tab selectors: the first tab is the active one.
        html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
        for tab_index, score_name in enumerate(score_names):
            is_active = tab_index == 0
            html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
                        '{}</a>'.format(tab_index,
                                        ' is-active' if is_active else '',
                                        self._FormatName(score_name)))
        html.append('</div>')

        html.append('</header>')
        html.append(
            '<main class="mdl-layout__content" style="overflow-x: auto;">')

        # Tabs content.
        for tab_index, score_name in enumerate(score_names):
            # Fix: recompute `is_active` for each panel. The original reused
            # the stale value left over from the tab-selector loop above, so
            # with more than one score no panel was marked "is-active".
            is_active = tab_index == 0
            html.append('<section class="mdl-layout__tab-panel{}" '
                        'id="score-tab-{}">'.format(
                            ' is-active' if is_active else '', tab_index))
            html.append('<div class="page-content">')
            html.append(
                self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
            html.append('</div>')
            html.append('</section>')

        html.append('</main>')
        html.append('</div>')

        # Add snackbar for notifications.
        html.append(
            '<div id="snackbar" aria-live="assertive" aria-atomic="true"'
            ' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
            '<div class="mdl-snackbar__text"></div>'
            '<button type="button" class="mdl-snackbar__action"></button>'
            '</div>')

        return self._NEW_LINE.join(html)

    def _BuildScoreTab(self, score_name, anchor_data):
        """Builds the content of a tab (one score table plus its dialogs)."""
        # Find unique values.
        scores = self._scores_data_frame[
            self._scores_data_frame.eval_score_name == score_name]
        apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
        test_data_gen_configs = sorted(
            self._FindUniqueTuples(scores,
                                   ['test_data_gen', 'test_data_gen_params']))

        html = [
            '<div class="mdl-grid">',
            '<div class="mdl-layout-spacer"></div>',
            '<div class="mdl-cell mdl-cell--10-col">',
            ('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
             'style="width: 100%;">'),
        ]

        # Header.
        html.append('<thead><tr><th>APM config / Test data generator</th>')
        for test_data_gen_info in test_data_gen_configs:
            html.append('<th>{} {}</th>'.format(
                self._FormatName(test_data_gen_info[0]),
                test_data_gen_info[1]))
        html.append('</tr></thead>')

        # Body: one row per APM config, one cell per test data generator.
        html.append('<tbody>')
        for apm_config in apm_configs:
            html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
            for test_data_gen_info in test_data_gen_configs:
                dialog_id = self._ScoreStatsInspectorDialogId(
                    score_name, apm_config[0], test_data_gen_info[0],
                    test_data_gen_info[1])
                html.append(
                    '<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
                    format(
                        dialog_id,
                        self._BuildScoreTableCell(score_name,
                                                  test_data_gen_info[0],
                                                  test_data_gen_info[1],
                                                  apm_config[0])))
            html.append('</tr>')
        html.append('</tbody>')

        html.append(
            '</table></div><div class="mdl-layout-spacer"></div></div>')

        html.append(
            self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
                                                  test_data_gen_configs,
                                                  anchor_data))

        return self._NEW_LINE.join(html)

    def _BuildScoreTableCell(self, score_name, test_data_gen,
                             test_data_gen_params, apm_config):
        """Builds the content of a table cell for a score table.

        Shows the single score when only one is available, otherwise its
        min/max/mean/std-dev stats, each with a tooltip.
        """
        scores = self._SliceDataForScoreTableCell(score_name, apm_config,
                                                  test_data_gen,
                                                  test_data_gen_params)
        stats = self._ComputeScoreStats(scores)

        html = []
        items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
                           apm_config)
        if stats['count'] == 1:
            # Show the only available score.
            item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
            html.append('<div id="single-value-{0}">{1:f}</div>'.format(
                item_id, scores['score'].mean()))
            html.append(
                '<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
                '</div>'.format(item_id, 'single value'))
        else:
            # Show stats.
            for stat_name in ['min', 'max', 'mean', 'std dev']:
                item_id = hashlib.md5(
                    (items_id_prefix + stat_name).encode('utf-8')).hexdigest()
                html.append('<div id="stats-{0}">{1:f}</div>'.format(
                    item_id, stats[stat_name]))
                html.append(
                    '<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
                    '</div>'.format(item_id, stat_name))

        return self._NEW_LINE.join(html)

    def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
                                         test_data_gen_configs, anchor_data):
        """Builds a set of score stats inspector dialogs."""
        html = []
        for apm_config in apm_configs:
            for test_data_gen_info in test_data_gen_configs:
                dialog_id = self._ScoreStatsInspectorDialogId(
                    score_name, apm_config[0], test_data_gen_info[0],
                    test_data_gen_info[1])

                html.append('<dialog class="mdl-dialog" id="{}" '
                            'style="width: 40%;">'.format(dialog_id))

                # Content.
                html.append('<div class="mdl-dialog__content">')
                html.append(
                    '<h6><strong>APM config preset</strong>: {}<br/>'
                    '<strong>Test data generator</strong>: {} ({})</h6>'.
                    format(self._FormatName(apm_config[0]),
                           self._FormatName(test_data_gen_info[0]),
                           test_data_gen_info[1]))
                html.append(
                    self._BuildScoreStatsInspectorDialog(
                        score_name, apm_config[0], test_data_gen_info[0],
                        test_data_gen_info[1], anchor_data + (dialog_id, )))
                html.append('</div>')

                # Actions.
                html.append('<div class="mdl-dialog__actions">')
                html.append('<button type="button" class="mdl-button" '
                            'onclick="closeScoreStatsInspector()">'
                            'Close</button>')
                html.append('</div>')

                html.append('</dialog>')

        return self._NEW_LINE.join(html)

    def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
                                        test_data_gen, test_data_gen_params,
                                        anchor_data):
        """Builds one score stats inspector dialog.

        Rows are (capture, render) pairs and columns are echo simulators;
        each cell holds exactly one score.
        """
        scores = self._SliceDataForScoreTableCell(score_name, apm_config,
                                                  test_data_gen,
                                                  test_data_gen_params)

        capture_render_pairs = sorted(
            self._FindUniqueTuples(scores, ['capture', 'render']))
        echo_simulators = sorted(
            self._FindUniqueTuples(scores, ['echo_simulator']))

        html = [
            '<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
        ]

        # Header.
        html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
        for echo_simulator in echo_simulators:
            html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
        html.append('</tr></thead>')

        # Body.
        html.append('<tbody>')
        for row, (capture, render) in enumerate(capture_render_pairs):
            html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
                capture, render))
            for col, echo_simulator in enumerate(echo_simulators):
                score_tuple = self._SliceDataForScoreStatsTableCell(
                    scores, capture, render, echo_simulator[0])
                cell_class = 'r{}c{}'.format(row, col)
                html.append('<td class="single-score-cell {}">{}</td>'.format(
                    cell_class,
                    self._BuildScoreStatsInspectorTableCell(
                        score_tuple, anchor_data + (cell_class, ))))
            html.append('</tr>')
        html.append('</tbody>')
        html.append('</table>')

        # Placeholder for the audio inspector.
        html.append('<div class="audio-inspector-placeholder"></div>')

        return self._NEW_LINE.join(html)

    def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
        """Builds the content of a cell of a score stats inspector."""
        anchor = '&'.join(anchor_data)
        html = [('<div class="v">{}</div>'
                 '<button class="mdl-button mdl-js-button mdl-button--icon"'
                 ' data-anchor="{}">'
                 '<i class="material-icons mdl-color-text--blue-grey">link</i>'
                 '</button>').format(score_tuple.score, anchor)]

        # Add all the available file paths as hidden data.
        for field_name in score_tuple.keys():
            if field_name.endswith('_filepath'):
                html.append('<input type="hidden" name="{}" value="{}">'.format(
                    field_name, score_tuple[field_name]))

        return self._NEW_LINE.join(html)

    def _SliceDataForScoreTableCell(self, score_name, apm_config,
                                    test_data_gen, test_data_gen_params):
        """Slices `self._scores_data_frame` to extract the data for a tab."""
        masks = []
        masks.append(self._scores_data_frame.eval_score_name == score_name)
        masks.append(self._scores_data_frame.apm_config == apm_config)
        masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
        masks.append(self._scores_data_frame.test_data_gen_params ==
                     test_data_gen_params)
        mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
        del masks
        return self._scores_data_frame[mask]

    @classmethod
    def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
                                         echo_simulator):
        """Slices `scores` to extract the data for a tab."""
        masks = []
        masks.append(scores.capture == capture)
        masks.append(scores.render == render)
        masks.append(scores.echo_simulator == echo_simulator)
        mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
        del masks

        sliced_data = scores[mask]
        assert len(sliced_data) == 1, 'single score is expected'
        return sliced_data.iloc[0]

    @classmethod
    def _FindUniqueTuples(cls, data_frame, fields):
        """Slices `data_frame` to a list of fields and finds unique tuples."""
        return data_frame[fields].drop_duplicates().values.tolist()

    @classmethod
    def _ComputeScoreStats(cls, data_frame):
        """Computes count/min/max/mean/std-dev over the `score` column."""
        scores = data_frame['score']
        return {
            'count': scores.count(),
            'min': scores.min(),
            'max': scores.max(),
            'mean': scores.mean(),
            'std dev': scores.std(),
        }

    @classmethod
    def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
                                     test_data_gen, test_data_gen_params):
        """Assigns a unique name to a dialog."""
        return 'score-stats-dialog-' + hashlib.md5(
            'score-stats-inspector-{}-{}-{}-{}'.format(
                score_name, apm_config, test_data_gen,
                test_data_gen_params).encode('utf-8')).hexdigest()

    @classmethod
    def _Save(cls, output_filepath, html):
        """Writes the HTML file.

        Args:
          output_filepath: output file path.
          html: string with the HTML content.
        """
        with open(output_filepath, 'w') as f:
            f.write(html)

    @classmethod
    def _FormatName(cls, name):
        """Formats a name.

        Args:
          name: a string.

        Returns:
          A copy of name in which underscores and dashes are replaced with a
          space.
        """
        return re.sub(r'[_\-]', ' ', name)

View File

@ -1,86 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the export module.
"""
import logging
import os
import shutil
import tempfile
import unittest
import pyquery as pq
from . import audioproc_wrapper
from . import collect_data
from . import eval_scores_factory
from . import evaluation
from . import export
from . import simulation
from . import test_data_generation_factory
class TestExport(unittest.TestCase):
    """Unit tests for the export module.
    """

    # Set to False to keep the temporary output for manual inspection.
    _CLEAN_TMP_OUTPUT = True

    def setUp(self):
        """Creates temporary data to export.

        Runs a fake end-to-end experiment (simulation + score collection) so
        that testCreateHtmlReport has real scores to export.
        """
        self._tmp_path = tempfile.mkdtemp()

        # Run a fake experiment to produce data to export.
        simulator = simulation.ApmModuleSimulator(
            test_data_generator_factory=(
                test_data_generation_factory.TestDataGeneratorFactory(
                    aechen_ir_database_path='',
                    noise_tracks_path='',
                    copy_with_identity=False)),
            evaluation_score_factory=(
                eval_scores_factory.EvaluationScoreWorkerFactory(
                    # The fake_polqa binary sits next to this module.
                    polqa_tool_bin_path=os.path.join(
                        os.path.dirname(os.path.abspath(__file__)),
                        'fake_polqa'),
                    echo_metric_tool_bin_path=None)),
            ap_wrapper=audioproc_wrapper.AudioProcWrapper(
                audioproc_wrapper.AudioProcWrapper.
                DEFAULT_APM_SIMULATOR_BIN_PATH),
            evaluator=evaluation.ApmModuleEvaluator())
        simulator.Run(
            config_filepaths=['apm_configs/default.json'],
            capture_input_filepaths=[
                os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
                os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
            ],
            test_data_generator_names=['identity', 'white_noise'],
            eval_score_names=['audio_level_peak', 'audio_level_mean'],
            output_dir=self._tmp_path)

        # Export results: collect the scores written by the simulation above.
        p = collect_data.InstanceArgumentsParser()
        args = p.parse_args(['--output_dir', self._tmp_path])
        src_path = collect_data.ConstructSrcPath(args)
        self._data_to_export = collect_data.FindScores(src_path, args)

    def tearDown(self):
        """Recursively deletes temporary folders."""
        if self._CLEAN_TMP_OUTPUT:
            shutil.rmtree(self._tmp_path)
        else:
            logging.warning(self.id() + ' did not clean the temporary path ' +
                            (self._tmp_path))

    def testCreateHtmlReport(self):
        # Export the collected scores and check that a parseable HTML
        # document is produced.
        fn_out = os.path.join(self._tmp_path, 'results.html')
        exporter = export.HtmlExport(fn_out)
        exporter.Export(self._data_to_export)

        document = pq.PyQuery(filename=fn_out)
        self.assertIsInstance(document, pq.PyQuery)
        # TODO(alessiob): Use PyQuery API to check the HTML file.

View File

@ -1,75 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
from __future__ import division
import logging
import os
import subprocess
import shutil
import sys
import tempfile
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import signal_processing
class ExternalVad(object):
    """Wrapper around an external VAD binary.

    Runs the binary on a mono 48 kHz wav file and loads the float32
    probabilities it writes (one value per 10 ms of audio).
    """

    def __init__(self, path_to_binary, name):
        """Args:
          path_to_binary: path to binary that accepts '-i <wav>', '-o
            <float probabilities>'. There must be one float value per
            10ms audio
          name: a name to identify the external VAD. Used for saving
            the output as extvad_output-<name>.
        """
        self._path_to_binary = path_to_binary
        self.name = name
        assert os.path.exists(self._path_to_binary), (self._path_to_binary)
        self._vad_output = None

    def Run(self, wav_file_path):
        """Runs the external VAD on `wav_file_path` (mono, 48 kHz only)."""
        _signal = signal_processing.SignalProcessingUtils.LoadWav(
            wav_file_path)
        if _signal.channels != 1:
            raise NotImplementedError('Multiple-channel'
                                      ' annotations not implemented')
        if _signal.frame_rate != 48000:
            raise NotImplementedError('Frame rates '
                                      'other than 48000 not implemented')

        tmp_path = tempfile.mkdtemp()
        try:
            output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
            subprocess.call([
                self._path_to_binary, '-i', wav_file_path, '-o',
                output_file_path
            ])
            self._vad_output = np.fromfile(output_file_path, np.float32)
        except Exception as e:
            # Fix: Python 3 exceptions have no `.message` attribute; the old
            # `e.message` raised AttributeError and masked the real error.
            logging.error('Error while running the ' + self.name + ' VAD (' +
                          str(e) + ')')
        finally:
            # Always remove the temporary working directory.
            if os.path.exists(tmp_path):
                shutil.rmtree(tmp_path)

    def GetVadOutput(self):
        """Returns the probabilities loaded by Run(); Run() must be called
        first."""
        assert self._vad_output is not None
        return self._vad_output

    @classmethod
    def ConstructVadDict(cls, vad_paths, vad_names):
        """Builds a {name: ExternalVad} dict from parallel path/name lists."""
        external_vads = {}
        for path, name in zip(vad_paths, vad_names):
            external_vads[name] = ExternalVad(path, name)
        return external_vads

View File

@ -1,25 +0,0 @@
#!/usr/bin/python
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import argparse
import numpy as np
def main():
    """Fake external VAD: writes 100 float32 values (0..99) to the file
    given by -o; the -i input is accepted but ignored."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', required=True)
    parser.add_argument('-o', required=True)
    args = parser.parse_args()
    array = np.arange(100, dtype=np.float32)
    # Fix: open in binary mode and close deterministically. The original
    # used open(args.o, 'w') (text mode, handle never closed), which corrupts
    # the raw float32 bytes on platforms with newline translation.
    with open(args.o, 'wb') as f:
        array.tofile(f)


if __name__ == '__main__':
    main()

View File

@ -1,56 +0,0 @@
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <fstream>
#include <iostream>
#include <string>
#include "absl/strings/string_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace test {
namespace {

// Error printed when the mandatory "-Out <path>" flag is missing.
const char* const kErrorMessage = "-Out /path/to/output/file is mandatory";

// Writes fake output intended to be parsed by
// quality_assessment.eval_scores.PolqaScore.
void WriteOutputFile(absl::string_view output_file_path) {
  RTC_CHECK_NE(output_file_path, "");
  std::ofstream out(std::string{output_file_path});
  // Fix: a failed open sets failbit, not badbit, so the old `!out.bad()`
  // check passed even when the file could not be created; good() catches it.
  RTC_CHECK(out.good());
  out << "* Fake Polqa output" << std::endl;
  out << "FakeField1\tPolqaScore\tFakeField2" << std::endl;
  out << "FakeValue1\t3.25\tFakeValue2" << std::endl;
  out.close();
}

}  // namespace

// Finds "-Out" in argv and writes the fake POLQA report to its argument.
int main(int argc, char* argv[]) {
  // Find "-Out" and use its next argument as output file path.
  RTC_CHECK_GE(argc, 3) << kErrorMessage;
  const std::string kSoughtFlagName = "-Out";
  for (int i = 1; i < argc - 1; ++i) {
    if (kSoughtFlagName.compare(argv[i]) == 0) {
      WriteOutputFile(argv[i + 1]);
      return 0;
    }
  }
  RTC_FATAL() << kErrorMessage;
}

}  // namespace test
}  // namespace webrtc

int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}

View File

@ -1,97 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input mixer module.
"""
import logging
import os
from . import exceptions
from . import signal_processing
class ApmInputMixer(object):
    """Class to mix a set of audio segments down to the APM input."""

    _HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'

    def __init__(self):
        pass

    @classmethod
    def HardClippingLogMessage(cls):
        """Returns the log message used when hard clipping is detected in the
        mix.

        This method is mainly intended to be used by the unit tests.
        """
        return cls._HARD_CLIPPING_LOG_MSG

    @classmethod
    def Mix(cls, output_path, capture_input_filepath, echo_filepath):
        """Mixes capture and echo.

        Creates the overall capture input for APM by mixing the "echo-free"
        capture signal with the echo signal (e.g., echo simulated via the
        echo_path_simulation module).

        The echo signal cannot be shorter than the capture signal and the
        generated mix will have the same duration of the capture signal. The
        latter property is enforced in order to let the input of APM and the
        reference signal created by TestDataGenerator have the same length
        (required for the evaluation step).

        Hard-clipping may occur in the mix; a warning is raised when this
        happens.

        If `echo_filepath` is None, nothing is done and
        `capture_input_filepath` is returned.

        Args:
          output_path: path of the directory where the mix file is saved.
          capture_input_filepath: path to the echo-free capture audio track
            file.
          echo_filepath: path to the echo audio track file, or None.

        Returns:
          Path to the mix audio track file.
        """
        if echo_filepath is None:
            return capture_input_filepath

        # Build the mix output file name as a function of the echo file name.
        # This ensures that if the internal parameters of the echo path
        # simulator change, no erroneous cache hit occurs.
        echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
        capture_input_file_name, _ = os.path.splitext(
            os.path.split(capture_input_filepath)[1])
        mix_filepath = os.path.join(
            output_path,
            'mix_capture_{}_{}.wav'.format(capture_input_file_name,
                                           echo_file_name))

        # Create the mix if not done yet.
        mix = None
        if not os.path.exists(mix_filepath):
            echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
                capture_input_filepath)
            echo = signal_processing.SignalProcessingUtils.LoadWav(
                echo_filepath)

            if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
                    signal_processing.SignalProcessingUtils.CountSamples(
                        echo_free_capture)):
                raise exceptions.InputMixerException(
                    'echo cannot be shorter than capture')

            mix = echo_free_capture.overlay(echo)
            signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)

        # Check if hard clipping occurs. Reload from file on a cache hit so
        # the check also runs for previously created mixes.
        if mix is None:
            mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
            logging.warning(cls._HARD_CLIPPING_LOG_MSG)

        return mix_filepath

View File

@ -1,140 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the input mixer module.
"""
import logging
import os
import shutil
import tempfile
import unittest
import mock
from . import exceptions
from . import input_mixer
from . import signal_processing
class TestApmInputMixer(unittest.TestCase):
    """Unit tests for the ApmInputMixer class.
    """

    # Audio track file names created in setUp().
    _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

    # Target peak power level (dBFS) of each audio track file created in
    # setUp(). These values are hand-crafted in order to make saturation
    # happen when capture and echo_2 are mixed and the contrary for capture
    # and echo_1. None means that the power is not changed.
    _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

    # Audio track file durations in milliseconds.
    _DURATIONS = [1000, 1000, 1000, 800, 1200]

    # Sample rate (Hz) shared by all the generated tracks.
    _SAMPLE_RATE = 48000

    def setUp(self):
        """Creates temporary data."""
        self._tmp_path = tempfile.mkdtemp()

        # Create audio track files.
        self._audio_tracks = {}
        for filename, peak_power, duration in zip(self._FILENAMES,
                                                  self._MAX_PEAK_POWER_LEVELS,
                                                  self._DURATIONS):
            audio_track_filepath = os.path.join(self._tmp_path,
                                                '{}.wav'.format(filename))

            # Create a pure tone with the target peak power level.
            template = signal_processing.SignalProcessingUtils.GenerateSilence(
                duration=duration, sample_rate=self._SAMPLE_RATE)
            signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
                template)
            if peak_power is not None:
                signal = signal.apply_gain(-signal.max_dBFS + peak_power)

            signal_processing.SignalProcessingUtils.SaveWav(
                audio_track_filepath, signal)
            self._audio_tracks[filename] = {
                'filepath':
                audio_track_filepath,
                'num_samples':
                signal_processing.SignalProcessingUtils.CountSamples(signal)
            }

    def tearDown(self):
        """Recursively deletes temporary folders."""
        shutil.rmtree(self._tmp_path)

    def testCheckMixSameDuration(self):
        """Checks the duration when mixing capture and echo with same
        duration."""
        mix_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_1']['filepath'])
        self.assertTrue(os.path.exists(mix_filepath))

        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testRejectShorterEcho(self):
        """Rejects echo signals that are shorter than the capture signal."""
        try:
            _ = input_mixer.ApmInputMixer.Mix(
                self._tmp_path, self._audio_tracks['capture']['filepath'],
                self._audio_tracks['shorter']['filepath'])
            self.fail('no exception raised')
        except exceptions.InputMixerException:
            pass

    def testCheckMixDurationWithLongerEcho(self):
        """Checks the duration when mixing an echo longer than the capture."""
        mix_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['longer']['filepath'])
        self.assertTrue(os.path.exists(mix_filepath))

        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testCheckOutputFileNamesConflict(self):
        """Checks that different echo files lead to different output file
        names."""
        mix1_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_1']['filepath'])
        self.assertTrue(os.path.exists(mix1_filepath))

        mix2_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_2']['filepath'])
        self.assertTrue(os.path.exists(mix2_filepath))

        self.assertNotEqual(mix1_filepath, mix2_filepath)

    def testHardClippingLogExpected(self):
        """Checks that hard clipping warning is raised when occurring."""
        # Patch logging.warning to observe the emitted warnings.
        logging.warning = mock.MagicMock(name='warning')
        _ = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_2']['filepath'])
        logging.warning.assert_called_once_with(
            input_mixer.ApmInputMixer.HardClippingLogMessage())

    def testHardClippingLogNotExpected(self):
        """Checks that hard clipping warning is not raised when not
        occurring."""
        logging.warning = mock.MagicMock(name='warning')
        _ = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_1']['filepath'])
        self.assertNotIn(
            mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
            logging.warning.call_args_list)

View File

@ -1,68 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input signal creator module.
"""
from . import exceptions
from . import signal_processing
class InputSignalCreator(object):
    """Input signal creator class.

    Factory for synthetic input signals (currently only 'pure_tone')
    together with the metadata describing them.
    """

    @classmethod
    def Create(cls, name, raw_params):
        """Creates a input signal and its metadata.

        Args:
          name: Input signal creator name.
          raw_params: Tuple of parameters to pass to the specific signal
            creator.

        Returns:
          (AudioSegment, dict) tuple.

        Raises:
          InputSignalCreatorException: on an unknown creator name or invalid
            parameters.
        """
        try:
            signal = {}
            params = {}

            if name == 'pure_tone':
                params['frequency'] = float(raw_params[0])
                params['duration'] = int(raw_params[1])
                signal = cls._CreatePureTone(params['frequency'],
                                             params['duration'])
            else:
                raise exceptions.InputSignalCreatorException(
                    'Invalid input signal creator name')

            # Complete metadata.
            params['signal'] = name

            return signal, params
        except (TypeError, AssertionError, IndexError, ValueError) as e:
            # Fix: also catch IndexError (too few raw_params) and ValueError
            # (non-numeric raw_params), which indexing/float()/int() raise;
            # previously these escaped as raw exceptions instead of the
            # documented InputSignalCreatorException.
            raise exceptions.InputSignalCreatorException(
                'Invalid signal creator parameters: {}'.format(e))

    @classmethod
    def _CreatePureTone(cls, frequency, duration):
        """
        Generates a pure tone at 48000 Hz.

        Args:
          frequency: Float in (0-24000] (Hz).
          duration: Integer (milliseconds).

        Returns:
          AudioSegment instance.
        """
        assert 0 < frequency <= 24000
        assert duration > 0
        template = signal_processing.SignalProcessingUtils.GenerateSilence(
            duration)
        return signal_processing.SignalProcessingUtils.GeneratePureTone(
            template, frequency)

View File

@ -1,32 +0,0 @@
/* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
/* Highlight for the currently selected score cell. */
td.selected-score {
  background-color: #DDD;
}
/* Center the content of every single-score cell. */
td.single-score-cell{
  text-align: center;
}
/* Audio inspector widget shown next to the selected score. */
.audio-inspector {
  text-align: center;
}
/* Collapse vertical spacing inside the audio inspector grid (two levels of
   nesting are used by the MDL grid markup). */
.audio-inspector div{
  margin-bottom: 0;
  padding-bottom: 0;
  padding-top: 0;
}
.audio-inspector div div{
  margin-bottom: 0;
  padding-bottom: 0;
  padding-top: 0;
}

View File

@ -1,376 +0,0 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
/**
 * Opens the score stats inspector dialog.
 * @param {String} dialogId: identifier of the dialog to show.
 * @return {DOMElement} The dialog element that has been opened.
 */
function openScoreStatsInspector(dialogId) {
  var statsDialog = document.getElementById(dialogId);
  statsDialog.showModal();
  return statsDialog;
}
/**
 * Closes the score stats inspector dialog (no-op when none is open).
 */
function closeScoreStatsInspector() {
  var openedDialog = document.querySelector('dialog[open]');
  if (openedDialog != null) {
    openedDialog.close();
  }
}
/**
 * Audio inspector class.
 * @constructor
 */
function AudioInspector() {
  console.debug('Creating an AudioInspector instance.');
  this.audioPlayer_ = new Audio();
  this.metadata_ = {};
  this.currentScore_ = null;
  this.audioInspector_ = null;
  this.snackbarContainer_ = document.querySelector('#snackbar');

  // Get base URL without anchors.
  this.baseUrl_ = window.location.href;
  var index = this.baseUrl_.indexOf('#');
  if (index > 0)
    this.baseUrl_ = this.baseUrl_.substr(0, index)
  // Log the stripped base URL (the previous message logged the full
  // window.location.href, anchor included, which contradicted the text).
  console.info('Base URL set to "' + this.baseUrl_ + '".');

  // `window.event` is only defined while an event is being dispatched; guard
  // to avoid a TypeError when the constructor runs outside an event handler.
  if (window.event)
    window.event.stopPropagation();
  this.createTextAreasForCopy_();
  this.createAudioInspector_();
  this.initializeEventHandlers_();

  // When MDL is ready, parse the anchor (if any) to show the requested
  // experiment.
  var self = this;
  document.querySelectorAll('header a')[0].addEventListener(
      'mdl-componentupgraded', function() {
    if (!self.parseWindowAnchor()) {
      // If no experiment is requested, open the first section.
      console.info('No anchor parsing, opening the first section.');
      document.querySelectorAll('header a > span')[0].click();
    }
  });
}
/**
 * Parse the anchor in the window URL.
 * @return {bool} True if the parsing succeeded.
 */
AudioInspector.prototype.parseWindowAnchor = function() {
  var index = location.href.indexOf('#');
  if (index == -1) {
    console.debug('No # found in the URL.');
    return false;
  }

  // Everything after the '#' character. (Equivalent to the previous
  // `substr(index - location.href.length + 1)` negative-start form, but
  // readable.)
  var anchor = location.href.substr(index + 1);
  console.info('Anchor changed: "' + anchor + '".');

  // Anchor layout: <section index>&<dialog id>&<score cell class>.
  var parts = anchor.split('&');
  if (parts.length != 3) {
    console.info('Ignoring anchor with invalid number of fields.');
    return false;
  }

  var openDialog = document.querySelector('dialog[open]');
  try {
    // Open the requested dialog if not already open.
    if (!openDialog || openDialog.id != parts[1]) {
      !openDialog || openDialog.close();
      document.querySelectorAll('header a > span')[
          parseInt(parts[0].substr(1))].click();
      openDialog = openScoreStatsInspector(parts[1]);
    }

    // Trigger click on cell.
    var cell = openDialog.querySelector('td.' + parts[2]);
    cell.focus();
    cell.click();

    this.showNotification_('Experiment selected.');
    return true;
  } catch (e) {
    this.showNotification_('Cannot select experiment :(');
    console.error('Exception caught while selecting experiment: "' + e + '".');
  }
  return false;
}
/**
 * Set up the inspector for a new score.
 * @param {DOMElement} element: Element linked to the selected score.
 */
AudioInspector.prototype.selectedScoreChange = function(element) {
  // No-op when the already-selected score is clicked again.
  if (this.currentScore_ == element) { return; }

  // Move the 'selected-score' CSS class from the old cell to the new one.
  if (this.currentScore_ != null) {
    this.currentScore_.classList.remove('selected-score');
  }
  this.currentScore_ = element;
  this.currentScore_.classList.add('selected-score');
  this.stopAudio();

  // Read metadata: each hidden input inside the cell carries one
  // name/filepath pair used later by playAudio().
  var matches = element.querySelectorAll('input[type=hidden]');
  this.metadata_ = {};
  for (var index = 0; index < matches.length; ++index) {
    this.metadata_[matches[index].name] = matches[index].value;
  }

  // Show the audio inspector interface.
  // NOTE(review): the 4-level parentNode walk assumes a fixed table markup
  // around the score cell — confirm against the generated report HTML.
  var container = element.parentNode.parentNode.parentNode.parentNode;
  var audioInspectorPlaceholder = container.querySelector(
      '.audio-inspector-placeholder');
  this.moveInspector_(audioInspectorPlaceholder);
};
/**
 * Stop playing audio.
 *
 * Pauses the shared audio player used by playAudio().
 */
AudioInspector.prototype.stopAudio = function() {
  console.info('Pausing audio play out.');
  this.audioPlayer_.pause();
};
/**
 * Show a text message using the snackbar.
 * @param {string} text: Message to display.
 */
AudioInspector.prototype.showNotification_ = function(text) {
  var snackbarData = {message: text, timeout: 2000};
  try {
    this.snackbarContainer_.MaterialSnackbar.showSnackbar(snackbarData);
  } catch (e) {
    // Fallback to an alert.
    alert(text);
    console.warn('Cannot use snackbar: "' + e + '"');
  }
}
/**
 * Move the audio inspector DOM node into the given parent.
 * @param {DOMElement} newParentNode: New parent for the inspector.
 */
AudioInspector.prototype.moveInspector_ = function(newParentNode) {
  // appendChild reparents the node: it is detached from its current parent
  // and attached to `newParentNode` in a single operation.
  newParentNode.appendChild(this.audioInspector_);
};
/**
 * Play audio file from url.
 * @param {string} metadataFieldName: Metadata field name.
 */
AudioInspector.prototype.playAudio = function(metadataFieldName) {
  var audioUrl = this.metadata_[metadataFieldName];

  // Nothing to do for unknown fields; 'None' marks an unused stream.
  if (audioUrl == undefined) { return; }
  if (audioUrl == 'None') {
    alert('The selected stream was not used during the experiment.');
    return;
  }

  this.stopAudio();
  this.audioPlayer_.src = audioUrl;
  console.debug('Audio source URL: "' + this.audioPlayer_.src + '"');
  this.audioPlayer_.play();
  console.info('Playing out audio.');
};
/**
 * Create hidden text areas to copy URLs.
 *
 * For each dialog, one text area is created since it is not possible to select
 * text on a text area outside of the active dialog.
 */
AudioInspector.prototype.createTextAreasForCopy_ = function() {
  document.querySelectorAll('dialog.mdl-dialog').forEach(function(dialog) {
    var urlCopyArea = document.createElement("textarea");
    urlCopyArea.classList.add('url-copy');
    // Keep the text area technically visible (so select/copy work) but
    // unobtrusive: tiny, borderless and pinned to the bottom-left corner.
    urlCopyArea.style.position = 'fixed';
    urlCopyArea.style.bottom = 0;
    urlCopyArea.style.left = 0;
    urlCopyArea.style.width = '2em';
    urlCopyArea.style.height = '2em';
    urlCopyArea.style.border = 'none';
    urlCopyArea.style.outline = 'none';
    urlCopyArea.style.boxShadow = 'none';
    urlCopyArea.style.background = 'transparent';
    urlCopyArea.style.fontSize = '6px';
    dialog.appendChild(urlCopyArea);
  });
}
/**
 * Create audio inspector.
 *
 * Builds the grid of player buttons (one per audio stream) and attaches it to
 * an invisible container until a score is selected.
 */
AudioInspector.prototype.createAudioInspector_ = function() {
  var buttonIndex = 0;

  // Returns the HTML for one player button. `metadataFieldName` is stored in
  // a hidden input so the click handler knows which stream to play
  // ('__stop__' is reserved for the stop button).
  function getButtonHtml(icon, toolTipText, caption, metadataFieldName) {
    var buttonId = 'audioInspectorButton' + buttonIndex++;
    // `var` added: `html` was previously assigned without a declaration,
    // leaking an accidental global variable.
    var html = caption == null ? '' : caption;
    html += '<button class="mdl-button mdl-js-button mdl-button--icon ' +
            'mdl-js-ripple-effect" id="' + buttonId + '">' +
            '<i class="material-icons">' + icon + '</i>' +
            '<div class="mdl-tooltip" data-mdl-for="' + buttonId + '">' +
            toolTipText +
            '</div>';
    if (metadataFieldName != null) {
      html += '<input type="hidden" value="' + metadataFieldName + '">'
    }
    html += '</button>'
    return html;
  }

  // TODO(alessiob): Add timeline and highlight current track by changing icon
  // color.
  this.audioInspector_ = document.createElement('div');
  this.audioInspector_.classList.add('audio-inspector');
  this.audioInspector_.innerHTML =
    '<div class="mdl-grid">' +
      '<div class="mdl-layout-spacer"></div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'Simulated echo', 'E<sub>in</sub>',
                      'echo_filepath') +
      '</div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('stop', 'Stop playing [S]', null, '__stop__') +
      '</div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'Render stream', 'R<sub>in</sub>',
                      'render_filepath') +
      '</div>' +
      '<div class="mdl-layout-spacer"></div>' +
    '</div>' +
    '<div class="mdl-grid">' +
      '<div class="mdl-layout-spacer"></div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'Capture stream (APM input) [1]',
                      'Y\'<sub>in</sub>', 'capture_filepath') +
      '</div>' +
      '<div class="mdl-cell mdl-cell--2-col"><strong>APM</strong></div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'APM output [2]', 'Y<sub>out</sub>',
                      'apm_output_filepath') +
      '</div>' +
      '<div class="mdl-layout-spacer"></div>' +
    '</div>' +
    '<div class="mdl-grid">' +
      '<div class="mdl-layout-spacer"></div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'Echo-free capture stream',
                      'Y<sub>in</sub>', 'echo_free_capture_filepath') +
      '</div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'Clean capture stream',
                      'Y<sub>clean</sub>', 'clean_capture_input_filepath') +
      '</div>' +
      '<div class="mdl-cell mdl-cell--2-col">' +
        getButtonHtml('play_arrow', 'APM reference [3]', 'Y<sub>ref</sub>',
                      'apm_reference_filepath') +
      '</div>' +
      '<div class="mdl-layout-spacer"></div>' +
    '</div>';

  // Add an invisible node as initial container for the audio inspector.
  var parent = document.createElement('div');
  parent.style.display = 'none';
  this.moveInspector_(parent);
  document.body.appendChild(parent);
};
/**
 * Initialize event handlers.
 *
 * Wires score-cell selection, URL copy buttons, audio inspector playback
 * buttons, dialog close handlers, keyboard shortcuts and anchor changes.
 */
AudioInspector.prototype.initializeEventHandlers_ = function() {
  var self = this;

  // Score cells.
  document.querySelectorAll('td.single-score-cell').forEach(function(element) {
    element.onclick = function() {
      self.selectedScoreChange(this);
    }
  });

  // Copy anchor URLs icons.
  if (document.queryCommandSupported('copy')) {
    document.querySelectorAll('td.single-score-cell button').forEach(
        function(element) {
      element.onclick = function() {
        // Find the text area in the dialog.
        var textArea = element.closest('dialog').querySelector(
            'textarea.url-copy');

        // Copy.
        textArea.value = self.baseUrl_ + '#' + element.getAttribute(
            'data-anchor');
        textArea.select();
        try {
          if (!document.execCommand('copy'))
            throw 'Copy returned false';
          self.showNotification_('Experiment URL copied.');
        } catch (e) {
          self.showNotification_('Cannot copy experiment URL :(');
          console.error(e);
        }
      }
    });
  } else {
    self.showNotification_(
        'The copy command is disabled. URL copy is not enabled.');
  }

  // Audio inspector buttons.
  this.audioInspector_.querySelectorAll('button').forEach(function(element) {
    var target = element.querySelector('input[type=hidden]');
    if (target == null) { return; }
    element.onclick = function() {
      if (target.value == '__stop__') {
        self.stopAudio();
      } else {
        self.playAudio(target.value);
      }
    };
  });

  // Dialog close handlers. (The forEach return value was previously stored
  // in an unused `dialogs` variable; it is always undefined.)
  document.querySelectorAll('dialog').forEach(function(element) {
    element.onclose = function() {
      self.stopAudio();
    }
  });

  // Keyboard shortcuts.
  window.onkeyup = function(e) {
    var key = e.keyCode ? e.keyCode : e.which;
    switch (key) {
      case 49:  // 1.
        self.playAudio('capture_filepath');
        break;
      case 50:  // 2.
        self.playAudio('apm_output_filepath');
        break;
      case 51:  // 3.
        self.playAudio('apm_reference_filepath');
        break;
      case 83:  // S.
      case 115:  // s.
        self.stopAudio();
        break;
    }
  };

  // Hash change.
  window.onhashchange = function(e) {
    self.parseWindowAnchor();
  }
};

View File

@ -1,359 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Signal processing utility module.
"""
import array
import logging
import os
import sys
import enum
try:
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
try:
import pydub
import pydub.generators
except ImportError:
logging.critical('Cannot import the third-party Python package pydub')
sys.exit(1)
try:
import scipy.signal
import scipy.fftpack
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import exceptions
class SignalProcessingUtils(object):
    """Collection of signal processing utilities.
    """

    @enum.unique
    class MixPadding(enum.Enum):
        # Padding policy applied by MixSignals() when the noise is shorter
        # than the signal.
        NO_PADDING = 0
        ZERO_PADDING = 1
        LOOP = 2

    def __init__(self):
        pass

    @classmethod
    def LoadWav(cls, filepath, channels=1):
        """Loads wav file.

        Args:
          filepath: path to the wav audio track file to load.
          channels: number of channels (downmixing to mono by default).

        Returns:
          AudioSegment instance.

        Raises:
          exceptions.FileNotFoundError: if `filepath` does not exist.
        """
        if not os.path.exists(filepath):
            logging.error('cannot find the <%s> audio track file', filepath)
            raise exceptions.FileNotFoundError()
        return pydub.AudioSegment.from_file(filepath,
                                            format='wav',
                                            channels=channels)

    @classmethod
    def SaveWav(cls, output_filepath, signal):
        """Saves wav file.

        Args:
          output_filepath: path to the wav audio track file to save.
          signal: AudioSegment instance.
        """
        return signal.export(output_filepath, format='wav')

    @classmethod
    def CountSamples(cls, signal):
        """Number of samples per channel.

        Args:
          signal: AudioSegment instance.

        Returns:
          An integer.
        """
        number_of_samples = len(signal.get_array_of_samples())
        assert signal.channels > 0
        assert number_of_samples % signal.channels == 0
        # Integer division: with `/` this returns a float in Python 3 even
        # though the asserts guarantee a whole number (and the docstring
        # promises an integer).
        return number_of_samples // signal.channels

    @classmethod
    def GenerateSilence(cls, duration=1000, sample_rate=48000):
        """Generates silence.

        This method can also be used to create a template AudioSegment instance.
        A template can then be used with other Generate*() methods accepting an
        AudioSegment instance as argument.

        Args:
          duration: duration in ms.
          sample_rate: sample rate.

        Returns:
          AudioSegment instance.
        """
        return pydub.AudioSegment.silent(duration, sample_rate)

    @classmethod
    def GeneratePureTone(cls, template, frequency=440.0):
        """Generates a pure tone.

        The pure tone is generated with the same duration and in the same format
        of the given template signal.

        Args:
          template: AudioSegment instance.
          frequency: Frequency of the pure tone in Hz.

        Return:
          AudioSegment instance.

        Raises:
          exceptions.SignalProcessingException: if `frequency` is above the
            Nyquist frequency for the template sample rate.
        """
        if frequency > template.frame_rate >> 1:
            raise exceptions.SignalProcessingException('Invalid frequency')

        generator = pydub.generators.Sine(sample_rate=template.frame_rate,
                                          bit_depth=template.sample_width * 8,
                                          freq=frequency)

        return generator.to_audio_segment(duration=len(template), volume=0.0)

    @classmethod
    def GenerateWhiteNoise(cls, template):
        """Generates white noise.

        The white noise is generated with the same duration and in the same
        format of the given template signal.

        Args:
          template: AudioSegment instance.

        Return:
          AudioSegment instance.
        """
        generator = pydub.generators.WhiteNoise(
            sample_rate=template.frame_rate,
            bit_depth=template.sample_width * 8)
        return generator.to_audio_segment(duration=len(template), volume=0.0)

    @classmethod
    def AudioSegmentToRawData(cls, signal):
        """Returns the samples of `signal` as a numpy int16 array.

        Args:
          signal: AudioSegment instance with 16 bit samples.

        Returns:
          numpy.int16 array.

        Raises:
          exceptions.SignalProcessingException: if the samples are not 16 bit.
        """
        samples = signal.get_array_of_samples()
        if samples.typecode != 'h':
            raise exceptions.SignalProcessingException(
                'Unsupported samples type')
        # Reuse `samples` instead of extracting the sample array a second time.
        return np.array(samples, np.int16)

    @classmethod
    def Fft(cls, signal, normalize=True):
        """Computes the DFT of a mono signal.

        Args:
          signal: AudioSegment instance (mono).
          normalize: when True, the samples are scaled into [-1.0, 1.0] before
            the transform is computed.

        Returns:
          The first half of the FFT coefficients (up to the Nyquist frequency).
        """
        if signal.channels != 1:
            raise NotImplementedError('multiple-channel FFT not implemented')
        x = cls.AudioSegmentToRawData(signal).astype(np.float32)
        if normalize:
            x /= max(abs(np.max(x)), 1.0)
        y = scipy.fftpack.fft(x)
        # Integer division: a float slice index raises TypeError in Python 3.
        return y[:len(y) // 2]

    @classmethod
    def DetectHardClipping(cls, signal, threshold=2):
        """Detects hard clipping.

        Hard clipping is simply detected by counting samples that touch either the
        lower or upper bound too many times in a row (according to `threshold`).
        The presence of a single sequence of samples meeting such property is enough
        to label the signal as hard clipped.

        Args:
          signal: AudioSegment instance.
          threshold: minimum number of samples at full-scale in a row.

        Returns:
          True if hard clipping is detect, False otherwise.
        """
        if signal.channels != 1:
            raise NotImplementedError(
                'multiple-channel clipping not implemented')
        if signal.sample_width != 2:  # Note that signal.sample_width is in bytes.
            raise exceptions.SignalProcessingException(
                'hard-clipping detection only supported for 16 bit samples')
        samples = cls.AudioSegmentToRawData(signal)

        # Detect adjacent clipped samples.
        samples_type_info = np.iinfo(samples.dtype)
        mask_min = samples == samples_type_info.min
        mask_max = samples == samples_type_info.max

        def HasLongSequence(vector, min_length=threshold):
            """Returns True if there are one or more long sequences of True flags."""
            seq_length = 0
            for b in vector:
                seq_length = seq_length + 1 if b else 0
                if seq_length >= min_length:
                    return True
            return False

        return HasLongSequence(mask_min) or HasLongSequence(mask_max)

    @classmethod
    def ApplyImpulseResponse(cls, signal, impulse_response):
        """Applies an impulse response to a signal.

        Args:
          signal: AudioSegment instance.
          impulse_response: list or numpy vector of float values.

        Returns:
          AudioSegment instance.
        """
        # Get samples.
        assert signal.channels == 1, (
            'multiple-channel recordings not supported')
        samples = signal.get_array_of_samples()

        # Convolve.
        logging.info(
            'applying %d order impulse response to a signal lasting %d ms',
            len(impulse_response), len(signal))
        convolved_samples = scipy.signal.fftconvolve(in1=samples,
                                                     in2=impulse_response,
                                                     mode='full').astype(
                                                         np.int16)
        logging.info('convolution computed')

        # Cast back to the sample array type expected by pydub.
        convolved_samples = array.array(signal.array_type, convolved_samples)

        # Verify: 'full' mode convolution always lengthens the signal.
        logging.debug('signal length: %d samples', len(samples))
        logging.debug('convolved signal length: %d samples',
                      len(convolved_samples))
        assert len(convolved_samples) > len(samples)

        # Generate convolved signal AudioSegment instance.
        convolved_signal = pydub.AudioSegment(data=convolved_samples,
                                              metadata={
                                                  'sample_width':
                                                  signal.sample_width,
                                                  'frame_rate':
                                                  signal.frame_rate,
                                                  'frame_width':
                                                  signal.frame_width,
                                                  'channels': signal.channels,
                                              })
        assert len(convolved_signal) > len(signal)

        return convolved_signal

    @classmethod
    def Normalize(cls, signal):
        """Normalizes a signal.

        Args:
          signal: AudioSegment instance.

        Returns:
          An AudioSegment instance.
        """
        return signal.apply_gain(-signal.max_dBFS)

    @classmethod
    def Copy(cls, signal):
        """Makes a copy os a signal.

        Args:
          signal: AudioSegment instance.

        Returns:
          An AudioSegment instance.
        """
        return pydub.AudioSegment(data=signal.get_array_of_samples(),
                                  metadata={
                                      'sample_width': signal.sample_width,
                                      'frame_rate': signal.frame_rate,
                                      'frame_width': signal.frame_width,
                                      'channels': signal.channels,
                                  })

    @classmethod
    def MixSignals(cls,
                   signal,
                   noise,
                   target_snr=0.0,
                   pad_noise=MixPadding.NO_PADDING):
        """Mixes `signal` and `noise` with a target SNR.

        Mix `signal` and `noise` with a desired SNR by scaling `noise`.
        If the target SNR is +/- infinite, a copy of signal/noise is returned.
        If `signal` is shorter than `noise`, the length of the mix equals that of
        `signal`. Otherwise, the mix length depends on whether padding is applied.
        When padding is not applied, that is `pad_noise` is set to NO_PADDING
        (default), the mix length equals that of `noise` - i.e., `signal` is
        truncated. Otherwise, `noise` is extended and the resulting mix has the same
        length of `signal`.

        Args:
          signal: AudioSegment instance (signal).
          noise: AudioSegment instance (noise).
          target_snr: float, numpy.Inf or -numpy.Inf (dB).
          pad_noise: SignalProcessingUtils.MixPadding, default: NO_PADDING.

        Returns:
          An AudioSegment instance.
        """
        # Handle infinite target SNR.
        if target_snr == -np.inf:
            # Return a copy of noise.
            logging.warning('SNR = -Inf, returning noise')
            return cls.Copy(noise)
        elif target_snr == np.inf:
            # Return a copy of signal.
            logging.warning('SNR = +Inf, returning signal')
            return cls.Copy(signal)

        # Check signal and noise power.
        signal_power = float(signal.dBFS)
        noise_power = float(noise.dBFS)
        if signal_power == -np.inf:
            logging.error('signal has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')
        if noise_power == -np.inf:
            logging.error('noise has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')

        # Mix.
        gain_db = signal_power - noise_power - target_snr
        signal_duration = len(signal)
        noise_duration = len(noise)
        if signal_duration <= noise_duration:
            # Ignore `pad_noise`, `noise` is truncated if longer that `signal`, the
            # mix will have the same length of `signal`.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.NO_PADDING:
            # `signal` is longer than `noise`, but no padding is applied to `noise`.
            # Truncate `signal`.
            return noise.overlay(signal, gain_during_overlay=gain_db)
        elif pad_noise == cls.MixPadding.ZERO_PADDING:
            # TODO(alessiob): Check that this works as expected.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.LOOP:
            # `signal` is longer than `noise`, extend `noise` by looping.
            return signal.overlay(noise.apply_gain(gain_db), loop=True)
        else:
            raise exceptions.SignalProcessingException('invalid padding type')

View File

@ -1,183 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the signal_processing module.
"""
import unittest
import numpy as np
import pydub
from . import exceptions
from . import signal_processing
class TestSignalProcessing(unittest.TestCase):
    """Unit tests for the signal_processing module.
    """

    def testMixSignals(self):
        """Checks MixSignals() duration and samples for -inf/0/+inf SNR."""
        # Generate a template signal with which white noise can be generated.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Generate two distinct AudioSegment instances with 1 second of white
        # noise.
        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)

        # Extract samples.
        signal_samples = signal.get_array_of_samples()
        noise_samples = noise.get_array_of_samples()

        # Test target SNR -inf (noise expected).
        mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, -np.inf)
        # assertEqual is used below: the previous assertTrue(a, b) calls
        # treated the second argument as a failure message and never compared
        # the two lengths.
        self.assertEqual(len(noise), len(mix_neg_inf))  # Check duration.
        mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))

        # Test target SNR 0.0 (different data expected).
        mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, 0.0)
        self.assertEqual(len(signal), len(mix_0))  # Check duration.
        self.assertEqual(len(noise), len(mix_0))
        mix_0_samples = mix_0.get_array_of_samples()
        self.assertTrue(
            any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
        self.assertTrue(
            any([x != y for x, y in zip(noise_samples, mix_0_samples)]))

        # Test target SNR +inf (signal expected).
        mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, np.inf)
        self.assertEqual(len(signal), len(mix_pos_inf))  # Check duration.
        mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))

    def testMixSignalsMinInfPower(self):
        """Mixing with a -Inf power input must raise."""
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)

        with self.assertRaises(exceptions.SignalProcessingException):
            _ = signal_processing.SignalProcessingUtils.MixSignals(
                signal, silence, 0.0)

        with self.assertRaises(exceptions.SignalProcessingException):
            _ = signal_processing.SignalProcessingUtils.MixSignals(
                silence, signal, 0.0)

    def testMixSignalNoiseDifferentLengths(self):
        """Checks the mix duration for every signal/noise length combination."""
        # Test signals.
        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
        longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=2000, frame_rate=8000))

        # When the signal is shorter than the noise, the mix length always equals
        # that of the signal regardless of whether padding is applied.

        # No noise padding, length of signal less than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=shorter,
            noise=longer,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            NO_PADDING)
        self.assertEqual(len(shorter), len(mix))

        # With noise padding, length of signal less than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=shorter,
            noise=longer,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)
        self.assertEqual(len(shorter), len(mix))

        # When the signal is longer than the noise, the mix length depends on
        # whether padding is applied.

        # No noise padding, length of signal greater than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            NO_PADDING)
        self.assertEqual(len(shorter), len(mix))

        # With noise padding, length of signal greater than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)
        self.assertEqual(len(longer), len(mix))

    def testMixSignalNoisePaddingTypes(self):
        """Compares the tail energy of zero-padded vs looped noise mixes."""
        # Test signals.
        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
        longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)

        # Zero padding: expect pure tone only in 1-2s.
        mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            target_snr=-6,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)

        # Loop: expect pure tone plus noise in 1-2s.
        mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            target_snr=-6,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)

        def Energy(signal):
            samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
                signal).astype(np.float32)
            return np.sum(samples * samples)

        e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
        e_mix_loop = Energy(mix_loop[-1000:])
        self.assertLess(0, e_mix_zero_pad)
        self.assertLess(e_mix_zero_pad, e_mix_loop)

    def testMixSignalSnr(self):
        """Checks that the stronger tone in the mix matches the target SNR sign."""
        # Test signals.
        tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
        tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)

        def ToneAmplitudes(mix):
            """Returns the amplitude of the coefficients #16 and #192, which
            correspond to the tones at 250 and 3k Hz respectively."""
            mix_fft = np.absolute(
                signal_processing.SignalProcessingUtils.Fft(mix))
            return mix_fft[16], mix_fft[192]

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_low, noise=tone_high, target_snr=-6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_low, ampl_high)

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_high, noise=tone_low, target_snr=-6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_high, ampl_low)

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_low, noise=tone_high, target_snr=6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_high, ampl_low)

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_high, noise=tone_low, target_snr=6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_low, ampl_high)

View File

@ -1,446 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""APM module simulator.
"""
import logging
import os
from . import annotations
from . import data_access
from . import echo_path_simulation
from . import echo_path_simulation_factory
from . import eval_scores
from . import exceptions
from . import input_mixer
from . import input_signal_creator
from . import signal_processing
from . import test_data_generation
class ApmModuleSimulator(object):
    """Audio processing module (APM) simulator class.
    """

    # Registered test data generator and evaluation score classes, keyed by
    # name; used to instantiate the workers requested for a simulation run.
    _TEST_DATA_GENERATOR_CLASSES = (
        test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
    _EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES

    # File name prefixes used to tag the artifacts written into the output
    # directory.
    _PREFIX_APM_CONFIG = 'apmcfg-'
    _PREFIX_CAPTURE = 'capture-'
    _PREFIX_RENDER = 'render-'
    _PREFIX_ECHO_SIMULATOR = 'echosim-'
    _PREFIX_TEST_DATA_GEN = 'datagen-'
    _PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
    _PREFIX_SCORE = 'score-'
    def __init__(self,
                 test_data_generator_factory,
                 evaluation_score_factory,
                 ap_wrapper,
                 evaluator,
                 external_vads=None):
        """Initializes the simulator.

        Args:
          test_data_generator_factory: factory used to build the test data
            generator workers (must provide SetOutputDirectoryPrefix and
            GetInstance).
          evaluation_score_factory: factory used to build the evaluation score
            workers (must provide SetScoreFilenamePrefix and GetInstance).
          ap_wrapper: wrapper around the audio processing module under test.
          evaluator: object that runs the evaluation.
          external_vads: optional dict of external VAD instances; defaults to
            an empty dict.
        """
        # Avoid a shared mutable default argument.
        if external_vads is None:
            external_vads = {}
        self._test_data_generator_factory = test_data_generator_factory
        self._evaluation_score_factory = evaluation_score_factory
        self._audioproc_wrapper = ap_wrapper
        self._evaluator = evaluator
        # Annotations are extracted with all three built-in VAD types enabled
        # (bitwise OR of the VadType flags) plus any external VADs.
        self._annotator = annotations.AudioAnnotationsExtractor(
            annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
            | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
            | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
            external_vads)

        # Init.
        self._test_data_generator_factory.SetOutputDirectoryPrefix(
            self._PREFIX_TEST_DATA_GEN_PARAMS)
        self._evaluation_score_factory.SetScoreFilenamePrefix(
            self._PREFIX_SCORE)

        # Properties for each run (set in Run()).
        self._base_output_path = None
        self._output_cache_path = None
        self._test_data_generators = None
        self._evaluation_score_workers = None
        self._config_filepaths = None
        self._capture_input_filepaths = None
        self._render_input_filepaths = None
        self._echo_path_simulator_class = None
    # Public accessors for the file name prefixes used to tag the simulation
    # output artifacts (see the _PREFIX_* class constants above).

    @classmethod
    def GetPrefixApmConfig(cls):
        return cls._PREFIX_APM_CONFIG

    @classmethod
    def GetPrefixCapture(cls):
        return cls._PREFIX_CAPTURE

    @classmethod
    def GetPrefixRender(cls):
        return cls._PREFIX_RENDER

    @classmethod
    def GetPrefixEchoSimulator(cls):
        return cls._PREFIX_ECHO_SIMULATOR

    @classmethod
    def GetPrefixTestDataGenerator(cls):
        return cls._PREFIX_TEST_DATA_GEN

    @classmethod
    def GetPrefixTestDataGeneratorParameters(cls):
        return cls._PREFIX_TEST_DATA_GEN_PARAMS

    @classmethod
    def GetPrefixScore(cls):
        return cls._PREFIX_SCORE
def Run(self,
        config_filepaths,
        capture_input_filepaths,
        test_data_generator_names,
        eval_score_names,
        output_dir,
        render_input_filepaths=None,
        echo_path_simulator_name=(
            echo_path_simulation.NoEchoPathSimulator.NAME)):
    """Runs the APM simulation.

    Initializes paths and required instances, then runs all the simulations.
    The render input can be optionally added. If added, the number of capture
    input audio tracks and the number of render input audio tracks have to be
    equal. The two lists are used to form pairs of capture and render input.

    Args:
      config_filepaths: set of APM configuration files to test.
      capture_input_filepaths: set of capture input audio track files to test.
      test_data_generator_names: set of test data generator names to test.
      eval_score_names: set of evaluation score names to test.
      output_dir: base path to the output directory for wav files and
                  outcomes.
      render_input_filepaths: set of render input audio track files to test.
      echo_path_simulator_name: name of the echo path simulator to use when
                                render input is provided.
    """
    assert render_input_filepaths is None or (
        len(capture_input_filepaths) == len(render_input_filepaths)), (
            'render input set size not matching input set size')
    assert render_input_filepaths is None or echo_path_simulator_name in (
        echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
            'invalid echo path simulator')
    self._base_output_path = os.path.abspath(output_dir)

    # Output path used to cache the data shared across simulations.
    self._output_cache_path = os.path.join(self._base_output_path, '_cache')

    # Instance test data generators, one per requested name.
    self._test_data_generators = [
        self._test_data_generator_factory.GetInstance(
            test_data_generators_class=(
                self._TEST_DATA_GENERATOR_CLASSES[name]))
        for name in (test_data_generator_names)
    ]

    # Instance evaluation score workers, one per requested score name.
    self._evaluation_score_workers = [
        self._evaluation_score_factory.GetInstance(
            evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
        for (name) in eval_score_names
    ]

    # Set APM configuration file paths (name -> absolute path).
    self._config_filepaths = self._CreatePathsCollection(config_filepaths)

    # Set probing signal file paths.
    if render_input_filepaths is None:
        # Capture input only.
        self._capture_input_filepaths = self._CreatePathsCollection(
            capture_input_filepaths)
        self._render_input_filepaths = None
    else:
        # Set both capture and render input signals (paired by position).
        self._SetTestInputSignalFilePaths(capture_input_filepaths,
                                          render_input_filepaths)

    # Set the echo path simulator class (instantiated per render input in
    # _SimulateAll).
    self._echo_path_simulator_class = (
        echo_path_simulation.EchoPathSimulator.
        REGISTERED_CLASSES[echo_path_simulator_name])

    self._SimulateAll()
def _SimulateAll(self):
    """Runs all the simulations.

    Iterates over the combinations of APM configurations, probing signals,
    and test data generators. This method is mainly responsible for the
    creation of the cache and output directories required in order to call
    _Simulate().
    """
    without_render_input = self._render_input_filepaths is None

    # Try different APM config files.
    for config_name in self._config_filepaths:
        config_filepath = self._config_filepaths[config_name]

        # Try different capture-render pairs.
        for capture_input_name in self._capture_input_filepaths:
            # Output path for the capture signal annotations.
            capture_annotations_cache_path = os.path.join(
                self._output_cache_path,
                self._PREFIX_CAPTURE + capture_input_name)
            data_access.MakeDirectory(capture_annotations_cache_path)

            # Capture.
            capture_input_filepath = self._capture_input_filepaths[
                capture_input_name]
            if not os.path.exists(capture_input_filepath):
                # If the input signal file does not exist, try to create
                # using the available input signal creators.
                self._CreateInputSignal(capture_input_filepath)
            assert os.path.exists(capture_input_filepath)
            self._ExtractCaptureAnnotations(
                capture_input_filepath, capture_annotations_cache_path)

            # Render and simulated echo path (optional). Render inputs are
            # keyed by the capture input name (see
            # _SetTestInputSignalFilePaths).
            render_input_filepath = None if without_render_input else (
                self._render_input_filepaths[capture_input_name])
            render_input_name = '(none)' if without_render_input else (
                self._ExtractFileName(render_input_filepath))
            echo_path_simulator = (echo_path_simulation_factory.
                                   EchoPathSimulatorFactory.GetInstance(
                                       self._echo_path_simulator_class,
                                       render_input_filepath))

            # Try different test data generators.
            for test_data_generators in self._test_data_generators:
                logging.info(
                    'APM config preset: <%s>, capture: <%s>, render: <%s>,'
                    'test data generator: <%s>, echo simulator: <%s>',
                    config_name, capture_input_name, render_input_name,
                    test_data_generators.NAME, echo_path_simulator.NAME)

                # Output path for the generated test data.
                test_data_cache_path = os.path.join(
                    capture_annotations_cache_path,
                    self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
                data_access.MakeDirectory(test_data_cache_path)
                logging.debug('test data cache path: <%s>',
                              test_data_cache_path)

                # Output path for the echo simulator and APM input mixer
                # output.
                echo_test_data_cache_path = os.path.join(
                    test_data_cache_path,
                    'echosim-{}'.format(echo_path_simulator.NAME))
                data_access.MakeDirectory(echo_test_data_cache_path)
                logging.debug('echo test data cache path: <%s>',
                              echo_test_data_cache_path)

                # Full output path combining every simulation dimension.
                output_path = os.path.join(
                    self._base_output_path,
                    self._PREFIX_APM_CONFIG + config_name,
                    self._PREFIX_CAPTURE + capture_input_name,
                    self._PREFIX_RENDER + render_input_name,
                    self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
                    self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
                data_access.MakeDirectory(output_path)
                logging.debug('output path: <%s>', output_path)

                self._Simulate(test_data_generators,
                               capture_input_filepath,
                               render_input_filepath, test_data_cache_path,
                               echo_test_data_cache_path, output_path,
                               config_filepath, echo_path_simulator)
@staticmethod
def _CreateInputSignal(input_signal_filepath):
    """Creates a missing input signal file.

    The file name is parsed to extract the input signal creator name and its
    parameters. If a creator is matched and the parameters are valid, a new
    signal is generated and written to `input_signal_filepath` together with
    its metadata.

    Args:
      input_signal_filepath: Path to the input signal audio file to write.

    Raises:
      InputSignalCreatorException: if the file name cannot be parsed.
    """
    stem = os.path.splitext(os.path.basename(input_signal_filepath))[0]
    fields = stem.split('-')
    if len(fields) < 2:
        raise exceptions.InputSignalCreatorException(
            'Cannot parse input signal file name')

    # First field selects the creator; second carries its '_'-separated
    # parameters.
    signal, metadata = input_signal_creator.InputSignalCreator.Create(
        fields[0], fields[1].split('_'))

    signal_processing.SignalProcessingUtils.SaveWav(input_signal_filepath,
                                                    signal)
    data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
def _ExtractCaptureAnnotations(self,
                               input_filepath,
                               output_path,
                               annotation_name=""):
    # Runs the configured annotator on `input_filepath` and saves the
    # resulting annotations into `output_path`; `annotation_name` is
    # forwarded to Save() (callers use it as a file-name prefix — see
    # _Simulate).
    self._annotator.Extract(input_filepath)
    self._annotator.Save(output_path, annotation_name)
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
              render_input_filepath, test_data_cache_path,
              echo_test_data_cache_path, output_path, config_filepath,
              echo_path_simulator):
    """Runs a single set of simulation.

    Simulates a given combination of APM configuration, probing signal, and
    test data generator. It iterates over the test data generator internal
    configurations.

    Args:
      test_data_generators: TestDataGenerator instance.
      clean_capture_input_filepath: capture input audio track file to be
                                    processed by a test data generator and
                                    not affected by echo.
      render_input_filepath: render input audio track file to test.
      test_data_cache_path: path for the generated test audio track files.
      echo_test_data_cache_path: path for the echo simulator.
      output_path: base output path for the test data generator.
      config_filepath: APM configuration file to test.
      echo_path_simulator: EchoPathSimulator instance.
    """
    # Generate pairs of noisy input and reference signal files.
    test_data_generators.Generate(
        input_signal_filepath=clean_capture_input_filepath,
        test_data_cache_path=test_data_cache_path,
        base_output_path=output_path)

    # Extract metadata linked to the clean input file (if any); fall back to
    # an empty dict when no metadata file exists. (The exception instance was
    # previously bound but never used.)
    try:
        apm_input_metadata = data_access.Metadata.LoadFileMetadata(
            clean_capture_input_filepath)
    except IOError:
        apm_input_metadata = {}
    apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
    apm_input_metadata['test_data_gen_config'] = None

    # For each test data pair, simulate a call and evaluate.
    for config_name in test_data_generators.config_names:
        logging.info(' - test data generator config: <%s>', config_name)
        apm_input_metadata['test_data_gen_config'] = config_name

        # Paths to the test data generator output.
        # Note that the reference signal does not depend on the render input
        # which is optional.
        noisy_capture_input_filepath = (
            test_data_generators.noisy_signal_filepaths[config_name])
        reference_signal_filepath = (
            test_data_generators.reference_signal_filepaths[config_name])

        # Output path for the evaluation (e.g., APM output file).
        evaluation_output_path = test_data_generators.apm_output_paths[
            config_name]

        # Paths to the APM input signals: simulated echo mixed into the
        # noisy capture.
        echo_path_filepath = echo_path_simulator.Simulate(
            echo_test_data_cache_path)
        apm_input_filepath = input_mixer.ApmInputMixer.Mix(
            echo_test_data_cache_path, noisy_capture_input_filepath,
            echo_path_filepath)

        # Extract annotations for the APM input mix; the file stem is used
        # as annotation name prefix.
        apm_input_basepath, apm_input_filename = os.path.split(
            apm_input_filepath)
        self._ExtractCaptureAnnotations(
            apm_input_filepath, apm_input_basepath,
            os.path.splitext(apm_input_filename)[0] + '-')

        # Simulate a call using APM.
        self._audioproc_wrapper.Run(
            config_filepath=config_filepath,
            capture_input_filepath=apm_input_filepath,
            render_input_filepath=render_input_filepath,
            output_path=evaluation_output_path)

        try:
            # Evaluate.
            self._evaluator.Run(
                evaluation_score_workers=self._evaluation_score_workers,
                apm_input_metadata=apm_input_metadata,
                apm_output_filepath=self._audioproc_wrapper.output_filepath,
                reference_input_filepath=reference_signal_filepath,
                render_input_filepath=render_input_filepath,
                output_path=evaluation_output_path,
            )

            # Save simulation metadata.
            data_access.Metadata.SaveAudioTestDataPaths(
                output_path=evaluation_output_path,
                clean_capture_input_filepath=clean_capture_input_filepath,
                echo_free_capture_filepath=noisy_capture_input_filepath,
                echo_filepath=echo_path_filepath,
                render_filepath=render_input_filepath,
                capture_filepath=apm_input_filepath,
                apm_output_filepath=self._audioproc_wrapper.output_filepath,
                apm_reference_filepath=reference_signal_filepath,
                apm_config_filepath=config_filepath,
            )
        except exceptions.EvaluationScoreException as e:
            # A failed score does not abort the simulation; log and move to
            # the next test data generator configuration.
            logging.warning('the evaluation failed: %s', e.message)
            continue
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
"""Sets input and render input file paths collections.
Pairs the input and render input files by storing the file paths into two
collections. The key is the file name of the input file.
Args:
capture_input_filepaths: list of file paths.
render_input_filepaths: list of file paths.
"""
self._capture_input_filepaths = {}
self._render_input_filepaths = {}
assert len(capture_input_filepaths) == len(render_input_filepaths)
for capture_input_filepath, render_input_filepath in zip(
capture_input_filepaths, render_input_filepaths):
name = self._ExtractFileName(capture_input_filepath)
self._capture_input_filepaths[name] = os.path.abspath(
capture_input_filepath)
self._render_input_filepaths[name] = os.path.abspath(
render_input_filepath)
@classmethod
def _CreatePathsCollection(cls, filepaths):
"""Creates a collection of file paths.
Given a list of file paths, makes a collection with one item for each file
path. The value is absolute path, the key is the file name without
extenstion.
Args:
filepaths: list of file paths.
Returns:
A dict.
"""
filepaths_collection = {}
for filepath in filepaths:
name = cls._ExtractFileName(filepath)
filepaths_collection[name] = os.path.abspath(filepath)
return filepaths_collection
@classmethod
def _ExtractFileName(cls, filepath):
return os.path.splitext(os.path.split(filepath)[-1])[0]

View File

@ -1,203 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""
import logging
import os
import shutil
import tempfile
import unittest
import mock
import pydub
from . import audioproc_wrapper
from . import eval_scores_factory
from . import evaluation
from . import external_vad
from . import signal_processing
from . import simulation
from . import test_data_generation_factory
class TestApmModuleSimulator(unittest.TestCase):
    """Unit tests for the ApmModuleSimulator class."""

    def setUp(self):
        """Create temporary folders and fake audio track."""
        self._output_path = tempfile.mkdtemp()
        self._tmp_path = tempfile.mkdtemp()

        # A 1 s, 48 kHz white-noise wav file is used as the fake capture
        # input.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        self._fake_audio_track_path = os.path.join(self._output_path,
                                                   'fake.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_audio_track_path, fake_signal)

    def tearDown(self):
        """Recursively delete temporary folders."""
        shutil.rmtree(self._output_path)
        shutil.rmtree(self._tmp_path)

    def testSimulation(self):
        """Checks that Run() invokes APM and the evaluator enough times."""
        # Instance dependencies to mock and inject.
        ap_wrapper = audioproc_wrapper.AudioProcWrapper(
            audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
        evaluator = evaluation.ApmModuleEvaluator()
        ap_wrapper.Run = mock.MagicMock(name='Run')
        evaluator.Run = mock.MagicMock(name='Run')

        # Instance non-mocked dependencies.
        test_data_generator_factory = (
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=False))
        evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
            polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
                                             'fake_polqa'),
            echo_metric_tool_bin_path=None)

        # Instance simulator.
        simulator = simulation.ApmModuleSimulator(
            test_data_generator_factory=test_data_generator_factory,
            evaluation_score_factory=evaluation_score_factory,
            ap_wrapper=ap_wrapper,
            evaluator=evaluator,
            external_vads={
                'fake':
                external_vad.ExternalVad(
                    os.path.join(os.path.dirname(__file__),
                                 'fake_external_vad.py'), 'fake')
            })

        # What to simulate.
        config_files = ['apm_configs/default.json']
        input_files = [self._fake_audio_track_path]
        test_data_generators = ['identity', 'white_noise']
        eval_scores = ['audio_level_mean', 'polqa']

        # Run all simulations.
        simulator.Run(config_filepaths=config_files,
                      capture_input_filepaths=input_files,
                      test_data_generator_names=test_data_generators,
                      eval_score_names=eval_scores,
                      output_dir=self._output_path)

        # Check.
        # TODO(alessiob): Once the TestDataGenerator classes can be configured
        # by the client code (e.g., number of SNR pairs for the white noise
        # test data generator), the exact number of calls to ap_wrapper.Run
        # and evaluator.Run is known; use that with assertEqual.
        min_number_of_simulations = len(config_files) * len(input_files) * len(
            test_data_generators)
        self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                                min_number_of_simulations)
        self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                                min_number_of_simulations)

    def testInputSignalCreation(self):
        """Checks that missing input files are silently generated."""
        # Instance simulator.
        simulator = simulation.ApmModuleSimulator(
            test_data_generator_factory=(
                test_data_generation_factory.TestDataGeneratorFactory(
                    aechen_ir_database_path='',
                    noise_tracks_path='',
                    copy_with_identity=False)),
            evaluation_score_factory=(
                eval_scores_factory.EvaluationScoreWorkerFactory(
                    polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
                                                     'fake_polqa'),
                    echo_metric_tool_bin_path=None)),
            ap_wrapper=audioproc_wrapper.AudioProcWrapper(
                audioproc_wrapper.AudioProcWrapper.
                DEFAULT_APM_SIMULATOR_BIN_PATH),
            evaluator=evaluation.ApmModuleEvaluator())

        # Inexistent input files to be silently created.
        input_files = [
            os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
            os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
        ]
        self.assertFalse(
            any([os.path.exists(input_file) for input_file in (input_files)]))

        # The input files are created during the simulation.
        simulator.Run(config_filepaths=['apm_configs/default.json'],
                      capture_input_filepaths=input_files,
                      test_data_generator_names=['identity'],
                      eval_score_names=['audio_level_peak'],
                      output_dir=self._output_path)
        self.assertTrue(
            all([os.path.exists(input_file) for input_file in (input_files)]))

    def testPureToneGenerationWithTotalHarmonicDistorsion(self):
        """Checks that THD is only allowed with the identity generator."""
        logging.warning = mock.MagicMock(name='warning')

        # Instance simulator.
        simulator = simulation.ApmModuleSimulator(
            test_data_generator_factory=(
                test_data_generation_factory.TestDataGeneratorFactory(
                    aechen_ir_database_path='',
                    noise_tracks_path='',
                    copy_with_identity=False)),
            evaluation_score_factory=(
                eval_scores_factory.EvaluationScoreWorkerFactory(
                    polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
                                                     'fake_polqa'),
                    echo_metric_tool_bin_path=None)),
            ap_wrapper=audioproc_wrapper.AudioProcWrapper(
                audioproc_wrapper.AudioProcWrapper.
                DEFAULT_APM_SIMULATOR_BIN_PATH),
            evaluator=evaluation.ApmModuleEvaluator())

        # What to simulate.
        config_files = ['apm_configs/default.json']
        input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
        eval_scores = ['thd']

        # Should work.
        simulator.Run(config_filepaths=config_files,
                      capture_input_filepaths=input_files,
                      test_data_generator_names=['identity'],
                      eval_score_names=eval_scores,
                      output_dir=self._output_path)
        self.assertFalse(logging.warning.called)

        # Warning expected.
        simulator.Run(
            config_filepaths=config_files,
            capture_input_filepaths=input_files,
            test_data_generator_names=['white_noise'],  # Not allowed with THD.
            eval_score_names=eval_scores,
            output_dir=self._output_path)
        logging.warning.assert_called_with('the evaluation failed: %s', (
            'The THD score cannot be used with any test data generator other '
            'than "identity"'))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))

View File

@ -1,127 +0,0 @@
// Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
#include <algorithm>
#include <array>
#include <cmath>
#include <fstream>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/wav_file.h"
#include "rtc_base/logging.h"
// Command line flags: input wav, config/levels output paths, and the
// attack/decay time constants (ms) with the analysis frame length (ms).
ABSL_FLAG(std::string, i, "", "Input wav file");
ABSL_FLAG(std::string, oc, "", "Config output file");
ABSL_FLAG(std::string, ol, "", "Levels output file");
ABSL_FLAG(float, a, 5.f, "Attack (ms)");
ABSL_FLAG(float, d, 20.f, "Decay (ms)");
ABSL_FLAG(int, f, 10, "Frame length (ms)");

namespace webrtc {
namespace test {
namespace {

// Upper bounds used to validate the input and size the frame buffer.
constexpr int kMaxSampleRate = 48000;
constexpr uint8_t kMaxFrameLenMs = 30;
constexpr size_t kMaxFrameLen = kMaxFrameLenMs * kMaxSampleRate / 1000;

// Amplitude ratio corresponding to a -1 dB level reduction; used to derive
// the per-frame attack/decay smoothing coefficients.
const double kOneDbReduction = DbToRatio(-1.0);
// Measures the frame-by-frame peak level of a mono wav file, applies
// attack/decay temporal smoothing, and writes one smoothed level per frame
// as raw binary floats, plus a small JSON config describing the parameters.
// Returns 0 on success, 1 on invalid parameters or unsupported input.
int main(int argc, char* argv[]) {
  absl::ParseCommandLine(argc, argv);

  // Check parameters.
  if (absl::GetFlag(FLAGS_f) < 1 || absl::GetFlag(FLAGS_f) > kMaxFrameLenMs) {
    RTC_LOG(LS_ERROR) << "Invalid frame length (min: 1, max: " << kMaxFrameLenMs
                      << ")";
    return 1;
  }
  if (absl::GetFlag(FLAGS_a) < 0 || absl::GetFlag(FLAGS_d) < 0) {
    RTC_LOG(LS_ERROR) << "Attack and decay must be non-negative";
    return 1;
  }

  // Open wav input file and check properties.
  const std::string input_file = absl::GetFlag(FLAGS_i);
  const std::string config_output_file = absl::GetFlag(FLAGS_oc);
  const std::string levels_output_file = absl::GetFlag(FLAGS_ol);
  WavReader wav_reader(input_file);
  if (wav_reader.num_channels() != 1) {
    RTC_LOG(LS_ERROR) << "Only mono wav files supported";
    return 1;
  }
  if (wav_reader.sample_rate() > kMaxSampleRate) {
    RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
                      << ")";
    return 1;
  }

  // Map from milliseconds to samples.
  const size_t audio_frame_length = rtc::CheckedDivExact(
      absl::GetFlag(FLAGS_f) * wav_reader.sample_rate(), 1000);
  // Per-frame smoothing coefficient such that a constant input decays by
  // 1 dB every `c` ms.
  auto time_const = [](double c) {
    return std::pow(kOneDbReduction, absl::GetFlag(FLAGS_f) / c);
  };
  const float attack =
      absl::GetFlag(FLAGS_a) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_a));
  const float decay =
      absl::GetFlag(FLAGS_d) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_d));

  // Write config to file.
  std::ofstream out_config(config_output_file);
  out_config << "{"
                "'frame_len_ms': "
             << absl::GetFlag(FLAGS_f)
             << ", "
                "'attack_ms': "
             << absl::GetFlag(FLAGS_a)
             << ", "
                "'decay_ms': "
             << absl::GetFlag(FLAGS_d) << "}\n";
  out_config.close();

  // Measure level frame-by-frame.
  std::ofstream out_levels(levels_output_file, std::ofstream::binary);
  std::array<int16_t, kMaxFrameLen> samples;
  float level_prev = 0.f;
  while (true) {
    // Process frame.
    const auto read_samples =
        wav_reader.ReadSamples(audio_frame_length, samples.data());
    if (read_samples < audio_frame_length)
      break;  // EOF.

    // Frame peak level. Accumulate |s| in an int: the previous in-place
    // std::transform narrowed std::abs(-32768) == 32768 back into int16_t,
    // which overflows and corrupted the peak for full-scale negative
    // samples.
    int peak_level = 0;
    for (size_t i = 0; i < audio_frame_length; ++i) {
      peak_level =
          std::max(peak_level, std::abs(static_cast<int>(samples[i])));
    }
    const float level_curr = static_cast<float>(peak_level) / 32768.f;

    // Temporal smoothing: track rising levels with the attack coefficient
    // and falling levels with the decay coefficient.
    auto smooth = [&level_prev, &level_curr](float c) {
      return (1.0 - c) * level_curr + c * level_prev;
    };
    level_prev = smooth(level_curr > level_prev ? attack : decay);

    // Write output.
    out_levels.write(reinterpret_cast<const char*>(&level_prev), sizeof(float));
  }

  out_levels.close();
  return 0;
}
} // namespace
} // namespace test
} // namespace webrtc
// Global entry point: delegates to the tool implementation in webrtc::test.
int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}

View File

@ -1,526 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Test data generators producing signals pairs intended to be used to
test the APM module. Each pair consists of a noisy input and a reference signal.
The former is used as APM input and it is generated by adding noise to a
clean audio track. The reference is the expected APM output.
Throughout this file, the following naming convention is used:
- input signal: the clean signal (e.g., speech),
- noise signal: the noise to be summed up to the input signal (e.g., white
noise, Gaussian noise),
- noisy signal: input + noise.
The noise signal may or may not be a function of the clean signal. For
instance, white noise is independently generated, whereas reverberation is
obtained by convolving the input signal with an impulse response.
"""
import logging
import os
import shutil
import sys
try:
import scipy.io
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import data_access
from . import exceptions
from . import signal_processing
class TestDataGenerator(object):
    """Abstract class responsible for the generation of noisy signals.

    Given a clean signal, it generates two streams named noisy signal and
    reference. The former is the clean signal deteriorated by the noise
    source, the latter goes through the same deterioration process, but more
    "gently". Noisy signal and reference are produced so that the reference
    is the signal expected at the output of the APM module when the latter
    is fed with the noisy signal.

    A test data generator generates one or more pairs.
    """

    NAME = None
    # Registry of concrete implementations keyed by NAME; filled by
    # RegisterClass().
    REGISTERED_CLASSES = {}

    def __init__(self, output_directory_prefix):
        self._output_directory_prefix = output_directory_prefix
        # Init dictionaries with one entry for each test data generator
        # configuration (e.g., different SNRs).
        # Noisy audio track files (stored separately in a cache folder).
        self._noisy_signal_filepaths = None
        # Path to be used for the APM simulation output files.
        self._apm_output_paths = None
        # Reference audio track files (stored separately in a cache folder).
        self._reference_signal_filepaths = None
        self.Clear()

    @classmethod
    def RegisterClass(cls, class_to_register):
        """Registers a TestDataGenerator implementation.

        Decorator to automatically register the classes that extend
        TestDataGenerator.
        Example usage:

        @TestDataGenerator.RegisterClass
        class IdentityGenerator(TestDataGenerator):
          pass
        """
        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
        return class_to_register

    @property
    def config_names(self):
        # One name per generated noisy-reference pair.
        return self._noisy_signal_filepaths.keys()

    @property
    def noisy_signal_filepaths(self):
        return self._noisy_signal_filepaths

    @property
    def apm_output_paths(self):
        return self._apm_output_paths

    @property
    def reference_signal_filepaths(self):
        return self._reference_signal_filepaths

    def Generate(self, input_signal_filepath, test_data_cache_path,
                 base_output_path):
        """Generates a set of noisy input and reference audiotrack file pairs.

        This method initializes an empty set of pairs and calls the
        _Generate() method implemented in a concrete class.

        Args:
          input_signal_filepath: path to the clean input audio track file.
          test_data_cache_path: path to the cache of the generated audio
                                track files.
          base_output_path: base path where output is written.
        """
        self.Clear()
        self._Generate(input_signal_filepath, test_data_cache_path,
                       base_output_path)

    def Clear(self):
        """Clears the generated output path dictionaries."""
        self._noisy_signal_filepaths = {}
        self._apm_output_paths = {}
        self._reference_signal_filepaths = {}

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Abstract method to be implemented in each concrete class."""
        raise NotImplementedError()

    def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
                          snr_value_pairs):
        """Adds noisy-reference signal pairs.

        Args:
          base_output_path: noisy tracks base output path.
          noisy_mix_filepaths: nested dictionary of noisy signal paths
                               organized by noisy track name and SNR level.
          snr_value_pairs: list of SNR pairs.
        """
        for noise_track_name in noisy_mix_filepaths:
            for snr_noisy, snr_refence in snr_value_pairs:
                config_name = '{0}_{1:d}_{2:d}_SNR'.format(
                    noise_track_name, snr_noisy, snr_refence)
                output_path = self._MakeDir(base_output_path, config_name)
                self._AddNoiseReferenceFilesPair(
                    config_name=config_name,
                    noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
                    [snr_noisy],
                    reference_signal_filepath=noisy_mix_filepaths[
                        noise_track_name][snr_refence],
                    output_path=output_path)

    def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
                                    reference_signal_filepath, output_path):
        """Adds one noisy-reference signal pair.

        Args:
          config_name: name of the APM configuration.
          noisy_signal_filepath: path to noisy audio track file.
          reference_signal_filepath: path to reference audio track file.
          output_path: APM output path.
        """
        # Configuration names must be unique within one Generate() run.
        assert config_name not in self._noisy_signal_filepaths
        self._noisy_signal_filepaths[config_name] = os.path.abspath(
            noisy_signal_filepath)
        self._apm_output_paths[config_name] = os.path.abspath(output_path)
        self._reference_signal_filepaths[config_name] = os.path.abspath(
            reference_signal_filepath)

    def _MakeDir(self, base_output_path, test_data_generator_config_name):
        # Creates (if missing) and returns the prefixed output directory for
        # the given configuration.
        output_path = os.path.join(
            base_output_path,
            self._output_directory_prefix + test_data_generator_config_name)
        data_access.MakeDirectory(output_path)
        return output_path
@TestDataGenerator.RegisterClass
class IdentityTestDataGenerator(TestDataGenerator):
    """Generator that adds no noise.

    Both the noisy and the reference signals are the input signal.
    """

    NAME = 'identity'

    def __init__(self, output_directory_prefix, copy_with_identity):
        """Initializes the generator.

        Args:
          output_directory_prefix: prefix for the output sub-directories.
          copy_with_identity: when True, the input file is first copied into
                              the test data cache and the copy is used.
        """
        TestDataGenerator.__init__(self, output_directory_prefix)
        self._copy_with_identity = copy_with_identity

    @property
    def copy_with_identity(self):
        return self._copy_with_identity

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Registers the input signal as both noisy and reference signal."""
        config_name = 'default'
        output_path = self._MakeDir(base_output_path, config_name)

        if self._copy_with_identity:
            input_signal_filepath_new = os.path.join(
                test_data_cache_path,
                os.path.split(input_signal_filepath)[1])
            # Lazy %-style arguments: the message is only formatted when the
            # INFO level is enabled (was eager string concatenation).
            logging.info('copying %s to %s', input_signal_filepath,
                         input_signal_filepath_new)
            shutil.copy(input_signal_filepath, input_signal_filepath_new)
            input_signal_filepath = input_signal_filepath_new

        self._AddNoiseReferenceFilesPair(
            config_name=config_name,
            noisy_signal_filepath=input_signal_filepath,
            reference_signal_filepath=input_signal_filepath,
            output_path=output_path)
@TestDataGenerator.RegisterClass
class WhiteNoiseTestDataGenerator(TestDataGenerator):
    """Generator that adds white noise."""

    NAME = 'white_noise'

    # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
    # The reference (second value of each pair) always has a lower amount of
    # noise - i.e., the SNR is 10 dB higher.
    _SNR_VALUE_PAIRS = [
        [20, 30],  # Smallest noise.
        [10, 20],
        [5, 15],
        [0, 10],  # Largest noise.
    ]

    _NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'

    def __init__(self, output_directory_prefix):
        TestDataGenerator.__init__(self, output_directory_prefix)

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Generates one noisy-reference pair per SNR pair in
        _SNR_VALUE_PAIRS; mixes are cached per unique SNR value."""
        # Load the input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        # Create the noise track.
        noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            input_signal)

        # Create the noisy mixes (once for each unique SNR value).
        noisy_mix_filepaths = {}
        # Set comprehension (was set([list comprehension])).
        snr_values = {snr for pair in self._SNR_VALUE_PAIRS for snr in pair}
        for snr in snr_values:
            noisy_signal_filepath = os.path.join(
                test_data_cache_path,
                self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))

            # Create and save if not done.
            if not os.path.exists(noisy_signal_filepath):
                # Create noisy signal.
                noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
                    input_signal, noise_signal, snr)

                # Save.
                signal_processing.SignalProcessingUtils.SaveWav(
                    noisy_signal_filepath, noisy_signal)

            # Add file to the collection of mixes.
            noisy_mix_filepaths[snr] = noisy_signal_filepath

        # Add all the noisy-reference signal pairs.
        for snr_noisy, snr_reference in self._SNR_VALUE_PAIRS:
            config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_reference)
            output_path = self._MakeDir(base_output_path, config_name)
            self._AddNoiseReferenceFilesPair(
                config_name=config_name,
                noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
                reference_signal_filepath=noisy_mix_filepaths[snr_reference],
                output_path=output_path)
# Not registered on purpose: this class is a placeholder. Uncomment the
# decorator below once _Generate() is implemented.
# @TestDataGenerator.RegisterClass
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
    """Generator that adds narrow-band noise (not yet implemented)."""

    NAME = 'narrow_band_noise'

    def __init__(self, output_directory_prefix):
        TestDataGenerator.__init__(self, output_directory_prefix)

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        # Placeholder implementation: generates no pairs.
        pass
@TestDataGenerator.RegisterClass
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
    """Generator that adds noise loops.

    This generator uses all the wav files in a given path (default:
    noise_tracks/) and mixes them to the clean speech with different
    target SNRs (hard-coded).
    """

    NAME = 'additive_noise'
    _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'

    DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
                                             os.pardir, 'noise_tracks')

    # TODO(alessiob): Make the list of SNR pairs customizable.
    # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
    # The reference (second value of each pair) always has a lower amount
    # of noise - i.e., the SNR is 10 dB higher.
    _SNR_VALUE_PAIRS = [
        [20, 30],  # Smallest noise.
        [10, 20],
        [5, 15],
        [0, 10],  # Largest noise.
    ]

    def __init__(self, output_directory_prefix, noise_tracks_path):
        TestDataGenerator.__init__(self, output_directory_prefix)
        self._noise_tracks_path = noise_tracks_path
        # Collect every wav file available in the noise tracks directory.
        self._noise_tracks_file_names = [
            name for name in os.listdir(self._noise_tracks_path)
            if name.lower().endswith('.wav')
        ]
        if not self._noise_tracks_file_names:
            raise exceptions.InitializationException(
                'No wav files found in the noise tracks path %s' %
                (self._noise_tracks_path))

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Generates test data pairs using environmental noise.

        For each noise track and pair of SNR values, the following two
        audio tracks are created: the noisy signal and the reference
        signal. The former is obtained by mixing the (clean) input signal
        to the corresponding noise track enforcing the target SNR.
        """
        unique_snrs = set(
            snr for pair in self._SNR_VALUE_PAIRS for snr in pair)
        clean_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        noisy_mix_filepaths = {}
        for noise_track_filename in self._noise_tracks_file_names:
            # Load the noise track.
            noise_track_name = os.path.splitext(noise_track_filename)[0]
            noise_track_filepath = os.path.join(self._noise_tracks_path,
                                                noise_track_filename)
            if not os.path.exists(noise_track_filepath):
                logging.error('cannot find the <%s> noise track',
                              noise_track_filename)
                raise exceptions.FileNotFoundError()
            noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
                noise_track_filepath)

            # Mix the clean input with this track once per unique SNR,
            # caching each mix on disk (noise is looped to cover the whole
            # input signal).
            per_track_mixes = {}
            for snr in unique_snrs:
                mix_filepath = os.path.join(
                    test_data_cache_path,
                    self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
                        noise_track_name, snr))
                if not os.path.exists(mix_filepath):
                    mixed_signal = (
                        signal_processing.SignalProcessingUtils.MixSignals(
                            clean_signal,
                            noise_signal,
                            snr,
                            pad_noise=signal_processing.SignalProcessingUtils.
                            MixPadding.LOOP))
                    signal_processing.SignalProcessingUtils.SaveWav(
                        mix_filepath, mixed_signal)
                per_track_mixes[snr] = mix_filepath
            noisy_mix_filepaths[noise_track_name] = per_track_mixes

        # Register every noise track / SNR pair combination.
        self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
                               self._SNR_VALUE_PAIRS)
@TestDataGenerator.RegisterClass
class ReverberationTestDataGenerator(TestDataGenerator):
    """Generator that adds reverberation noise.

    TODO(alessiob): Make this class more generic since the impulse response
    can be anything (not just reverberation); call it e.g.,
    ConvolutionalNoiseTestDataGenerator.
    """

    NAME = 'reverberation'

    # Impulse response files (MATLAB .mat, 'h_air' variable) looked up in
    # the Aechen IR database path passed to the ctor.
    _IMPULSE_RESPONSES = {
        'lecture': 'air_binaural_lecture_0_0_1.mat',  # Long echo.
        'booth': 'air_binaural_booth_0_0_1.mat',  # Short echo.
    }
    # When not None, impulse responses are truncated to this many samples.
    _MAX_IMPULSE_RESPONSE_LENGTH = None

    # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
    # The reference (second value of each pair) always has a lower amount
    # of noise - i.e., the SNR is 5 dB higher.
    _SNR_VALUE_PAIRS = [
        [3, 8],  # Smallest noise.
        [-3, 2],  # Largest noise.
    ]

    _NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
    _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'

    def __init__(self, output_directory_prefix, aechen_ir_database_path):
        TestDataGenerator.__init__(self, output_directory_prefix)
        self._aechen_ir_database_path = aechen_ir_database_path

    def _Generate(self, input_signal_filepath, test_data_cache_path,
                  base_output_path):
        """Generates test data pairs using reverberation noise.

        For each impulse response, one noise track is created. For each
        impulse response and pair of SNR values, the following 2 audio
        tracks are created: the noisy signal and the reference signal. The
        former is obtained by mixing the (clean) input signal to the
        corresponding noise track enforcing the target SNR.
        """
        # Init.
        snr_values = set(
            [snr for pair in self._SNR_VALUE_PAIRS for snr in pair])

        # Load the input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        noisy_mix_filepaths = {}
        for impulse_response_name in self._IMPULSE_RESPONSES:
            noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
                impulse_response_name)
            noise_track_filepath = os.path.join(test_data_cache_path,
                                                noise_track_filename)
            noise_signal = None
            try:
                # Load noise track (if already cached from a previous run).
                noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
                    noise_track_filepath)
            except exceptions.FileNotFoundError:
                # Generate noise track by applying the impulse response.
                impulse_response_filepath = os.path.join(
                    self._aechen_ir_database_path,
                    self._IMPULSE_RESPONSES[impulse_response_name])
                noise_signal = self._GenerateNoiseTrack(
                    noise_track_filepath, input_signal,
                    impulse_response_filepath)
            assert noise_signal is not None

            # Create the noisy mixes (once for each unique SNR value).
            noisy_mix_filepaths[impulse_response_name] = {}
            for snr in snr_values:
                noisy_signal_filepath = os.path.join(
                    test_data_cache_path,
                    self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
                        impulse_response_name, snr))

                # Create and save if not done.
                if not os.path.exists(noisy_signal_filepath):
                    # Create noisy signal.
                    noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
                        input_signal, noise_signal, snr)

                    # Save.
                    signal_processing.SignalProcessingUtils.SaveWav(
                        noisy_signal_filepath, noisy_signal)

                # Add file to the collection of mixes.
                noisy_mix_filepaths[impulse_response_name][
                    snr] = noisy_signal_filepath

        # Add all the noise-SNR pairs.
        self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
                               self._SNR_VALUE_PAIRS)

    def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
                            impulse_response_filepath):
        """Generates noise track.

        Generate a signal by convolving input_signal with the impulse
        response in impulse_response_filepath; then save to
        noise_track_filepath.

        Args:
          noise_track_filepath: output file path for the noise track.
          input_signal: (clean) input signal samples.
          impulse_response_filepath: impulse response file path.

        Returns:
          AudioSegment instance.
        """
        # Load impulse response ('h_air' variable in the .mat file).
        data = scipy.io.loadmat(impulse_response_filepath)
        impulse_response = data['h_air'].flatten()
        if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
            logging.info('truncating impulse response from %d to %d samples',
                         len(impulse_response),
                         self._MAX_IMPULSE_RESPONSE_LENGTH)
            impulse_response = impulse_response[:self.
                                                _MAX_IMPULSE_RESPONSE_LENGTH]

        # Apply impulse response.
        processed_signal = (
            signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
                input_signal, impulse_response))

        # Save so subsequent runs can reuse the cached track.
        signal_processing.SignalProcessingUtils.SaveWav(
            noise_track_filepath, processed_signal)

        return processed_signal

View File

@ -1,71 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""TestDataGenerator factory class.
"""
import logging
from . import exceptions
from . import test_data_generation
class TestDataGeneratorFactory(object):
    """Factory class used to create test data generators.

    Usage: Create a factory passing parameters to the ctor with which the
    generators will be produced.
    """

    def __init__(self, aechen_ir_database_path, noise_tracks_path,
                 copy_with_identity):
        """Ctor.

        Args:
          aechen_ir_database_path: Path to the Aechen Impulse Response
                                   database.
          noise_tracks_path: Path to the noise tracks to add.
          copy_with_identity: Flag indicating whether the identity generator
                              has to make copies of the clean speech input
                              files.
        """
        self._aechen_ir_database_path = aechen_ir_database_path
        self._noise_tracks_path = noise_tracks_path
        self._copy_with_identity = copy_with_identity
        # Must be set via SetOutputDirectoryPrefix before GetInstance.
        self._output_directory_prefix = None

    def SetOutputDirectoryPrefix(self, prefix):
        self._output_directory_prefix = prefix

    def GetInstance(self, test_data_generators_class):
        """Creates an TestDataGenerator instance given a class object.

        Args:
          test_data_generators_class: TestDataGenerator class object (not an
                                      instance).

        Returns:
          TestDataGenerator instance.
        """
        if self._output_directory_prefix is None:
            raise exceptions.InitializationException(
                'The output directory prefix for test data generators is not set'
            )
        logging.debug('factory producing %s', test_data_generators_class)

        # Some generator classes take extra ctor arguments beyond the output
        # directory prefix; look them up per class (class objects hash by
        # identity, so this matches the requested class exactly).
        extra_ctor_args = {
            test_data_generation.IdentityTestDataGenerator:
            (self._copy_with_identity, ),
            test_data_generation.ReverberationTestDataGenerator:
            (self._aechen_ir_database_path, ),
            test_data_generation.AdditiveNoiseTestDataGenerator:
            (self._noise_tracks_path, ),
        }
        args = extra_ctor_args.get(test_data_generators_class, ())
        return test_data_generators_class(self._output_directory_prefix,
                                          *args)

View File

@ -1,207 +0,0 @@
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the test_data_generation module.
"""
import os
import shutil
import tempfile
import unittest
import numpy as np
import scipy.io
from . import test_data_generation
from . import test_data_generation_factory
from . import signal_processing
class TestTestDataGenerators(unittest.TestCase):
    """Unit tests for the test_data_generation module."""

    def setUp(self):
        """Create temporary folders and a fake AIR database."""
        self._base_output_path = tempfile.mkdtemp()
        self._test_data_cache_path = tempfile.mkdtemp()
        self._fake_air_db_path = tempfile.mkdtemp()

        # Fake AIR DB impulse responses.
        # TODO(alessiob): ReverberationTestDataGenerator will change to allow
        # custom impulse responses. When changed, the coupling below between
        # impulse_response_mat_file_names and
        # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
        impulse_response_mat_file_names = [
            'air_binaural_lecture_0_0_1.mat',
            'air_binaural_booth_0_0_1.mat',
        ]
        for impulse_response_mat_file_name in impulse_response_mat_file_names:
            data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
            scipy.io.savemat(
                os.path.join(self._fake_air_db_path,
                             impulse_response_mat_file_name), data)

    def tearDown(self):
        """Recursively delete temporary folders."""
        shutil.rmtree(self._base_output_path)
        shutil.rmtree(self._test_data_cache_path)
        shutil.rmtree(self._fake_air_db_path)

    def testTestDataGenerators(self):
        """Runs every registered generator and checks its outputs."""
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._test_data_cache_path))

        # Check that there is at least one registered test data generator.
        registered_classes = (
            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance generators factory.
        generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
            aechen_ir_database_path=self._fake_air_db_path,
            noise_tracks_path=test_data_generation. \
                              AdditiveNoiseTestDataGenerator. \
                              DEFAULT_NOISE_TRACKS_PATH,
            copy_with_identity=False)
        generators_factory.SetOutputDirectoryPrefix('datagen-')

        # Use a simple input file as clean input signal.
        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
                                             'tone-880.wav')
        self.assertTrue(os.path.exists(input_signal_filepath))

        # Load input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        # Try each registered test data generator.
        for generator_name in registered_classes:
            # Instance test data generator.
            generator = generators_factory.GetInstance(
                registered_classes[generator_name])

            # Generate the noisy input - reference pairs.
            generator.Generate(input_signal_filepath=input_signal_filepath,
                               test_data_cache_path=self._test_data_cache_path,
                               base_output_path=self._base_output_path)

            # Perform checks.
            self._CheckGeneratedPairsListSizes(generator)
            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
            self._CheckGeneratedPairsOutputPaths(generator)

    def testTestidentityDataGenerator(self):
        """Checks the `copy_with_identity` behavior of the identity generator."""
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._test_data_cache_path))

        # Use a simple input file as clean input signal.
        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
                                             'tone-880.wav')
        self.assertTrue(os.path.exists(input_signal_filepath))

        def GetNoiseReferenceFilePaths(identity_generator):
            noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
            reference_signal_filepaths = identity_generator.reference_signal_filepaths
            assert noisy_signal_filepaths.keys(
            ) == reference_signal_filepaths.keys()
            assert len(noisy_signal_filepaths.keys()) == 1
            # Dict views are not subscriptable in Python 3; fetch the single
            # key via an iterator (was `noisy_signal_filepaths.keys()[0]`,
            # which raises TypeError).
            key = next(iter(noisy_signal_filepaths))
            return noisy_signal_filepaths[key], reference_signal_filepaths[key]

        # Test the `copy_with_identity` flag.
        for copy_with_identity in [False, True]:
            # Instance the generator through the factory.
            factory = test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path='',
                noise_tracks_path='',
                copy_with_identity=copy_with_identity)
            factory.SetOutputDirectoryPrefix('datagen-')
            generator = factory.GetInstance(
                test_data_generation.IdentityTestDataGenerator)
            # Check `copy_with_identity` is set correctly.
            self.assertEqual(copy_with_identity, generator.copy_with_identity)

            # Generate test data and extract the paths to the noise and the
            # reference files.
            generator.Generate(input_signal_filepath=input_signal_filepath,
                               test_data_cache_path=self._test_data_cache_path,
                               base_output_path=self._base_output_path)
            noisy_signal_filepath, reference_signal_filepath = (
                GetNoiseReferenceFilePaths(generator))

            # Check that a copy is made if and only if `copy_with_identity` is
            # True.
            if copy_with_identity:
                self.assertNotEqual(noisy_signal_filepath,
                                    input_signal_filepath)
                self.assertNotEqual(reference_signal_filepath,
                                    input_signal_filepath)
            else:
                self.assertEqual(noisy_signal_filepath, input_signal_filepath)
                self.assertEqual(reference_signal_filepath,
                                 input_signal_filepath)

    def _CheckGeneratedPairsListSizes(self, generator):
        """Checks that each per-config collection has one entry per config."""
        config_names = generator.config_names
        number_of_pairs = len(config_names)
        self.assertEqual(number_of_pairs,
                         len(generator.noisy_signal_filepaths))
        self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
        self.assertEqual(number_of_pairs,
                         len(generator.reference_signal_filepaths))

    def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
        """Checks duration of the generated signals.

        Checks that the noisy input and the reference tracks are audio files
        with duration equal to or greater than that of the input signal.

        Args:
          generator: TestDataGenerator instance.
          input_signal: AudioSegment instance.
        """
        input_signal_length = (
            signal_processing.SignalProcessingUtils.CountSamples(input_signal))

        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            # Load the noisy input file.
            noisy_signal_filepath = generator.noisy_signal_filepaths[
                config_name]
            noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
                noisy_signal_filepath)

            # Check noisy input signal length.
            noisy_signal_length = (signal_processing.SignalProcessingUtils.
                                   CountSamples(noisy_signal))
            self.assertGreaterEqual(noisy_signal_length, input_signal_length)

            # Load the reference file.
            reference_signal_filepath = generator.reference_signal_filepaths[
                config_name]
            reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
                reference_signal_filepath)

            # Check reference signal length.
            reference_signal_length = (signal_processing.SignalProcessingUtils.
                                       CountSamples(reference_signal))
            self.assertGreaterEqual(reference_signal_length,
                                    input_signal_length)

    def _CheckGeneratedPairsOutputPaths(self, generator):
        """Checks that the output path created by the generator exists.

        Args:
          generator: TestDataGenerator instance.
        """
        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            output_path = generator.apm_output_paths[config_name]
            self.assertTrue(os.path.exists(output_path))

View File

@ -1,103 +0,0 @@
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
//
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file in the root of the source
// tree. An additional intellectual property rights grant can be found
// in the file PATENTS. All contributing project authors may
// be found in the AUTHORS file in the root of the source tree.
#include "common_audio/vad/include/vad.h"
#include <array>
#include <fstream>
#include <memory>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/wav_file.h"
#include "rtc_base/logging.h"
ABSL_FLAG(std::string, i, "", "Input wav file");
ABSL_FLAG(std::string, o, "", "VAD output file");
namespace webrtc {
namespace test {
namespace {
// The allowed values are 10, 20 or 30 ms.
constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
// Highest input sample rate this tool accepts.
constexpr int kMaxSampleRate = 48000;
// Largest possible frame size in samples, used to size the sample buffer.
constexpr size_t kMaxFrameLen =
    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
// Number of VAD decisions packed into each output byte (one bit per frame).
constexpr uint8_t kBitmaskBuffSize = 8;
// Runs the WebRTC VAD over a mono wav file (-i) and writes the per-frame
// speech decisions to a binary file (-o).
//
// Output format: one header byte holding the frame length in ms, then one
// bit per frame packed LSB-first into bytes, and a final byte holding the
// number of unused (padding) bits in the last bitmask byte.
//
// Returns 0 on success, 1 on invalid input or output failure.
int main(int argc, char* argv[]) {
  absl::ParseCommandLine(argc, argv);
  const std::string input_file = absl::GetFlag(FLAGS_i);
  const std::string output_file = absl::GetFlag(FLAGS_o);

  // Open wav input file and check properties.
  WavReader wav_reader(input_file);
  if (wav_reader.num_channels() != 1) {
    RTC_LOG(LS_ERROR) << "Only mono wav files supported";
    return 1;
  }
  if (wav_reader.sample_rate() > kMaxSampleRate) {
    RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
                      << ")";
    return 1;
  }
  const size_t audio_frame_length = rtc::CheckedDivExact(
      kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
  if (audio_frame_length > kMaxFrameLen) {
    RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
    return 1;
  }

  // Create output file and write header.
  std::ofstream out_file(output_file, std::ofstream::binary);
  if (!out_file.is_open()) {
    // Previously a failed open was silently ignored and the tool exited
    // with success while producing no output.
    RTC_LOG(LS_ERROR) << "Cannot create the output file " << output_file;
    return 1;
  }
  const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
  out_file.write(&audio_frame_length_ms, 1);  // Header.

  // Run VAD and write decisions.
  std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
  std::array<int16_t, kMaxFrameLen> samples;
  char buff = 0;     // Buffer to write one bit per frame.
  uint8_t next = 0;  // Points to the next bit to write in `buff`.
  while (true) {
    // Process frame.
    const auto read_samples =
        wav_reader.ReadSamples(audio_frame_length, samples.data());
    if (read_samples < audio_frame_length)
      break;  // Incomplete trailing frame: stop.
    const auto is_speech = vad->VoiceActivity(
        samples.data(), audio_frame_length, wav_reader.sample_rate());

    // Write output.
    buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
    if (++next == kBitmaskBuffSize) {
      out_file.write(&buff, 1);  // Flush.
      buff = 0;                  // Reset.
      next = 0;
    }
  }

  // Finalize: flush any partial bitmask byte and record how many of its
  // bits are padding.
  char extra_bits = 0;
  if (next > 0) {
    extra_bits = kBitmaskBuffSize - next;
    out_file.write(&buff, 1);  // Flush.
  }
  out_file.write(&extra_bits, 1);
  out_file.close();
  return 0;
}
} // namespace
} // namespace test
} // namespace webrtc
// Thin entry point: delegates to the implementation in webrtc::test's
// unnamed namespace (reachable via qualified lookup from the enclosing
// namespace).
int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}