Reformat python files checked by pylint (part 1/2).
After recently changing .pylintrc (see [1]) we discovered that the presubmit check always checks all the python files when just one python file gets updated. This CL moves all these files one step closer to what the linter wants. Autogenerated with: # Added all the files under pylint control to ~/Desktop/to-reformat cat ~/Desktop/to-reformat | xargs sed -i '1i\\' git cl format --python --full This is part 1 out of 2. The second part will fix function names and will not be automated. [1] - https://webrtc-review.googlesource.com/c/src/+/186664 No-Presubmit: True Bug: webrtc:12114 Change-Id: Idfec4d759f209a2090440d0af2413a1ddc01b841 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190980 Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org> Reviewed-by: Karl Wiberg <kwiberg@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32530}
This commit is contained in:
parent
d3a3e9ef36
commit
8cc6695652
2065
PRESUBMIT.py
2065
PRESUBMIT.py
File diff suppressed because it is too large
Load Diff
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""
|
||||
This script is the wrapper that runs the low-bandwidth audio test.
|
||||
|
||||
@ -23,315 +22,352 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
|
||||
|
||||
NO_TOOLS_ERROR_MESSAGE = (
|
||||
'Could not find PESQ or POLQA at %s.\n'
|
||||
'\n'
|
||||
'To fix this run:\n'
|
||||
' python %s %s\n'
|
||||
'\n'
|
||||
'Note that these tools are Google-internal due to licensing, so in order to '
|
||||
'use them you will have to get your own license and manually put them in the '
|
||||
'right location.\n'
|
||||
'See https://cs.chromium.org/chromium/src/third_party/webrtc/tools_webrtc/'
|
||||
'download_tools.py?rcl=bbceb76f540159e2dba0701ac03c514f01624130&l=13')
|
||||
'Could not find PESQ or POLQA at %s.\n'
|
||||
'\n'
|
||||
'To fix this run:\n'
|
||||
' python %s %s\n'
|
||||
'\n'
|
||||
'Note that these tools are Google-internal due to licensing, so in order to '
|
||||
'use them you will have to get your own license and manually put them in the '
|
||||
'right location.\n'
|
||||
'See https://cs.chromium.org/chromium/src/third_party/webrtc/tools_webrtc/'
|
||||
'download_tools.py?rcl=bbceb76f540159e2dba0701ac03c514f01624130&l=13')
|
||||
|
||||
|
||||
def _LogCommand(command):
|
||||
logging.info('Running %r', command)
|
||||
return command
|
||||
logging.info('Running %r', command)
|
||||
return command
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
parser = argparse.ArgumentParser(description='Run low-bandwidth audio tests.')
|
||||
parser.add_argument('build_dir',
|
||||
help='Path to the build directory (e.g. out/Release).')
|
||||
parser.add_argument('--remove', action='store_true',
|
||||
help='Remove output audio files after testing.')
|
||||
parser.add_argument('--android', action='store_true',
|
||||
help='Perform the test on a connected Android device instead.')
|
||||
parser.add_argument('--adb-path', help='Path to adb binary.', default='adb')
|
||||
parser.add_argument('--num-retries', default='0',
|
||||
help='Number of times to retry the test on Android.')
|
||||
parser.add_argument('--isolated-script-test-perf-output', default=None,
|
||||
help='Path to store perf results in histogram proto format.')
|
||||
parser.add_argument('--extra-test-args', default=[], action='append',
|
||||
help='Extra args to path to the test binary.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Run low-bandwidth audio tests.')
|
||||
parser.add_argument('build_dir',
|
||||
help='Path to the build directory (e.g. out/Release).')
|
||||
parser.add_argument('--remove',
|
||||
action='store_true',
|
||||
help='Remove output audio files after testing.')
|
||||
parser.add_argument(
|
||||
'--android',
|
||||
action='store_true',
|
||||
help='Perform the test on a connected Android device instead.')
|
||||
parser.add_argument('--adb-path',
|
||||
help='Path to adb binary.',
|
||||
default='adb')
|
||||
parser.add_argument('--num-retries',
|
||||
default='0',
|
||||
help='Number of times to retry the test on Android.')
|
||||
parser.add_argument(
|
||||
'--isolated-script-test-perf-output',
|
||||
default=None,
|
||||
help='Path to store perf results in histogram proto format.')
|
||||
parser.add_argument('--extra-test-args',
|
||||
default=[],
|
||||
action='append',
|
||||
help='Extra args to path to the test binary.')
|
||||
|
||||
# Ignore Chromium-specific flags
|
||||
parser.add_argument('--test-launcher-summary-output',
|
||||
type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
# Ignore Chromium-specific flags
|
||||
parser.add_argument('--test-launcher-summary-output',
|
||||
type=str,
|
||||
default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
return args
|
||||
|
||||
|
||||
def _GetPlatform():
|
||||
if sys.platform == 'win32':
|
||||
return 'win'
|
||||
elif sys.platform == 'darwin':
|
||||
return 'mac'
|
||||
elif sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
if sys.platform == 'win32':
|
||||
return 'win'
|
||||
elif sys.platform == 'darwin':
|
||||
return 'mac'
|
||||
elif sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
|
||||
|
||||
def _GetExtension():
|
||||
return '.exe' if sys.platform == 'win32' else ''
|
||||
return '.exe' if sys.platform == 'win32' else ''
|
||||
|
||||
|
||||
def _GetPathToTools():
|
||||
tools_dir = os.path.join(SRC_DIR, 'tools_webrtc')
|
||||
toolchain_dir = os.path.join(tools_dir, 'audio_quality')
|
||||
tools_dir = os.path.join(SRC_DIR, 'tools_webrtc')
|
||||
toolchain_dir = os.path.join(tools_dir, 'audio_quality')
|
||||
|
||||
platform = _GetPlatform()
|
||||
ext = _GetExtension()
|
||||
platform = _GetPlatform()
|
||||
ext = _GetExtension()
|
||||
|
||||
pesq_path = os.path.join(toolchain_dir, platform, 'pesq' + ext)
|
||||
if not os.path.isfile(pesq_path):
|
||||
pesq_path = None
|
||||
pesq_path = os.path.join(toolchain_dir, platform, 'pesq' + ext)
|
||||
if not os.path.isfile(pesq_path):
|
||||
pesq_path = None
|
||||
|
||||
polqa_path = os.path.join(toolchain_dir, platform, 'PolqaOem64' + ext)
|
||||
if not os.path.isfile(polqa_path):
|
||||
polqa_path = None
|
||||
polqa_path = os.path.join(toolchain_dir, platform, 'PolqaOem64' + ext)
|
||||
if not os.path.isfile(polqa_path):
|
||||
polqa_path = None
|
||||
|
||||
if (platform != 'mac' and not polqa_path) or not pesq_path:
|
||||
logging.error(NO_TOOLS_ERROR_MESSAGE,
|
||||
toolchain_dir,
|
||||
os.path.join(tools_dir, 'download_tools.py'),
|
||||
toolchain_dir)
|
||||
if (platform != 'mac' and not polqa_path) or not pesq_path:
|
||||
logging.error(NO_TOOLS_ERROR_MESSAGE, toolchain_dir,
|
||||
os.path.join(tools_dir, 'download_tools.py'),
|
||||
toolchain_dir)
|
||||
|
||||
return pesq_path, polqa_path
|
||||
return pesq_path, polqa_path
|
||||
|
||||
|
||||
def ExtractTestRuns(lines, echo=False):
|
||||
"""Extracts information about tests from the output of a test runner.
|
||||
"""Extracts information about tests from the output of a test runner.
|
||||
|
||||
Produces tuples
|
||||
(android_device, test_name, reference_file, degraded_file, cur_perf_results).
|
||||
"""
|
||||
for line in lines:
|
||||
if echo:
|
||||
sys.stdout.write(line)
|
||||
for line in lines:
|
||||
if echo:
|
||||
sys.stdout.write(line)
|
||||
|
||||
# Output from Android has a prefix with the device name.
|
||||
android_prefix_re = r'(?:I\b.+\brun_tests_on_device\((.+?)\)\s*)?'
|
||||
test_re = r'^' + android_prefix_re + (r'TEST (\w+) ([^ ]+?) ([^\s]+)'
|
||||
r' ?([^\s]+)?\s*$')
|
||||
# Output from Android has a prefix with the device name.
|
||||
android_prefix_re = r'(?:I\b.+\brun_tests_on_device\((.+?)\)\s*)?'
|
||||
test_re = r'^' + android_prefix_re + (r'TEST (\w+) ([^ ]+?) ([^\s]+)'
|
||||
r' ?([^\s]+)?\s*$')
|
||||
|
||||
match = re.search(test_re, line)
|
||||
if match:
|
||||
yield match.groups()
|
||||
match = re.search(test_re, line)
|
||||
if match:
|
||||
yield match.groups()
|
||||
|
||||
|
||||
def _GetFile(file_path, out_dir, move=False,
|
||||
android=False, adb_prefix=('adb',)):
|
||||
out_file_name = os.path.basename(file_path)
|
||||
out_file_path = os.path.join(out_dir, out_file_name)
|
||||
def _GetFile(file_path,
|
||||
out_dir,
|
||||
move=False,
|
||||
android=False,
|
||||
adb_prefix=('adb', )):
|
||||
out_file_name = os.path.basename(file_path)
|
||||
out_file_path = os.path.join(out_dir, out_file_name)
|
||||
|
||||
if android:
|
||||
# Pull the file from the connected Android device.
|
||||
adb_command = adb_prefix + ('pull', file_path, out_dir)
|
||||
subprocess.check_call(_LogCommand(adb_command))
|
||||
if move:
|
||||
# Remove that file.
|
||||
adb_command = adb_prefix + ('shell', 'rm', file_path)
|
||||
subprocess.check_call(_LogCommand(adb_command))
|
||||
elif os.path.abspath(file_path) != os.path.abspath(out_file_path):
|
||||
if move:
|
||||
shutil.move(file_path, out_file_path)
|
||||
else:
|
||||
shutil.copy(file_path, out_file_path)
|
||||
if android:
|
||||
# Pull the file from the connected Android device.
|
||||
adb_command = adb_prefix + ('pull', file_path, out_dir)
|
||||
subprocess.check_call(_LogCommand(adb_command))
|
||||
if move:
|
||||
# Remove that file.
|
||||
adb_command = adb_prefix + ('shell', 'rm', file_path)
|
||||
subprocess.check_call(_LogCommand(adb_command))
|
||||
elif os.path.abspath(file_path) != os.path.abspath(out_file_path):
|
||||
if move:
|
||||
shutil.move(file_path, out_file_path)
|
||||
else:
|
||||
shutil.copy(file_path, out_file_path)
|
||||
|
||||
return out_file_path
|
||||
return out_file_path
|
||||
|
||||
|
||||
def _RunPesq(executable_path, reference_file, degraded_file,
|
||||
def _RunPesq(executable_path,
|
||||
reference_file,
|
||||
degraded_file,
|
||||
sample_rate_hz=16000):
|
||||
directory = os.path.dirname(reference_file)
|
||||
assert os.path.dirname(degraded_file) == directory
|
||||
directory = os.path.dirname(reference_file)
|
||||
assert os.path.dirname(degraded_file) == directory
|
||||
|
||||
# Analyze audio.
|
||||
command = [executable_path, '+%d' % sample_rate_hz,
|
||||
os.path.basename(reference_file),
|
||||
os.path.basename(degraded_file)]
|
||||
# Need to provide paths in the current directory due to a bug in PESQ:
|
||||
# On Mac, for some 'path/to/file.wav', if 'file.wav' is longer than
|
||||
# 'path/to', PESQ crashes.
|
||||
out = subprocess.check_output(_LogCommand(command),
|
||||
cwd=directory, stderr=subprocess.STDOUT)
|
||||
# Analyze audio.
|
||||
command = [
|
||||
executable_path,
|
||||
'+%d' % sample_rate_hz,
|
||||
os.path.basename(reference_file),
|
||||
os.path.basename(degraded_file)
|
||||
]
|
||||
# Need to provide paths in the current directory due to a bug in PESQ:
|
||||
# On Mac, for some 'path/to/file.wav', if 'file.wav' is longer than
|
||||
# 'path/to', PESQ crashes.
|
||||
out = subprocess.check_output(_LogCommand(command),
|
||||
cwd=directory,
|
||||
stderr=subprocess.STDOUT)
|
||||
|
||||
# Find the scores in stdout of PESQ.
|
||||
match = re.search(
|
||||
r'Prediction \(Raw MOS, MOS-LQO\):\s+=\s+([\d.]+)\s+([\d.]+)', out)
|
||||
if match:
|
||||
raw_mos, _ = match.groups()
|
||||
# Find the scores in stdout of PESQ.
|
||||
match = re.search(
|
||||
r'Prediction \(Raw MOS, MOS-LQO\):\s+=\s+([\d.]+)\s+([\d.]+)', out)
|
||||
if match:
|
||||
raw_mos, _ = match.groups()
|
||||
|
||||
return {'pesq_mos': (raw_mos, 'unitless')}
|
||||
else:
|
||||
logging.error('PESQ: %s', out.splitlines()[-1])
|
||||
return {}
|
||||
return {'pesq_mos': (raw_mos, 'unitless')}
|
||||
else:
|
||||
logging.error('PESQ: %s', out.splitlines()[-1])
|
||||
return {}
|
||||
|
||||
|
||||
def _RunPolqa(executable_path, reference_file, degraded_file):
|
||||
# Analyze audio.
|
||||
command = [executable_path, '-q', '-LC', 'NB',
|
||||
'-Ref', reference_file, '-Test', degraded_file]
|
||||
process = subprocess.Popen(_LogCommand(command),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
out, err = process.communicate()
|
||||
# Analyze audio.
|
||||
command = [
|
||||
executable_path, '-q', '-LC', 'NB', '-Ref', reference_file, '-Test',
|
||||
degraded_file
|
||||
]
|
||||
process = subprocess.Popen(_LogCommand(command),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
out, err = process.communicate()
|
||||
|
||||
# Find the scores in stdout of POLQA.
|
||||
match = re.search(r'\bMOS-LQO:\s+([\d.]+)', out)
|
||||
# Find the scores in stdout of POLQA.
|
||||
match = re.search(r'\bMOS-LQO:\s+([\d.]+)', out)
|
||||
|
||||
if process.returncode != 0 or not match:
|
||||
if process.returncode == 2:
|
||||
logging.warning('%s (2)', err.strip())
|
||||
logging.warning('POLQA license error, skipping test.')
|
||||
else:
|
||||
logging.error('%s (%d)', err.strip(), process.returncode)
|
||||
return {}
|
||||
if process.returncode != 0 or not match:
|
||||
if process.returncode == 2:
|
||||
logging.warning('%s (2)', err.strip())
|
||||
logging.warning('POLQA license error, skipping test.')
|
||||
else:
|
||||
logging.error('%s (%d)', err.strip(), process.returncode)
|
||||
return {}
|
||||
|
||||
mos_lqo, = match.groups()
|
||||
return {'polqa_mos_lqo': (mos_lqo, 'unitless')}
|
||||
mos_lqo, = match.groups()
|
||||
return {'polqa_mos_lqo': (mos_lqo, 'unitless')}
|
||||
|
||||
|
||||
def _MergeInPerfResultsFromCcTests(histograms, run_perf_results_file):
|
||||
from tracing.value import histogram_set
|
||||
from tracing.value import histogram_set
|
||||
|
||||
cc_histograms = histogram_set.HistogramSet()
|
||||
with open(run_perf_results_file, 'rb') as f:
|
||||
contents = f.read()
|
||||
if not contents:
|
||||
return
|
||||
cc_histograms = histogram_set.HistogramSet()
|
||||
with open(run_perf_results_file, 'rb') as f:
|
||||
contents = f.read()
|
||||
if not contents:
|
||||
return
|
||||
|
||||
cc_histograms.ImportProto(contents)
|
||||
cc_histograms.ImportProto(contents)
|
||||
|
||||
histograms.Merge(cc_histograms)
|
||||
histograms.Merge(cc_histograms)
|
||||
|
||||
|
||||
Analyzer = collections.namedtuple('Analyzer', ['name', 'func', 'executable',
|
||||
'sample_rate_hz'])
|
||||
Analyzer = collections.namedtuple(
|
||||
'Analyzer', ['name', 'func', 'executable', 'sample_rate_hz'])
|
||||
|
||||
|
||||
def _ConfigurePythonPath(args):
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkout_root = os.path.abspath(
|
||||
os.path.join(script_dir, os.pardir, os.pardir))
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkout_root = os.path.abspath(
|
||||
os.path.join(script_dir, os.pardir, os.pardir))
|
||||
|
||||
# TODO(https://crbug.com/1029452): Use a copy rule and add these from the out
|
||||
# dir like for the third_party/protobuf code.
|
||||
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'catapult',
|
||||
'tracing'))
|
||||
# TODO(https://crbug.com/1029452): Use a copy rule and add these from the out
|
||||
# dir like for the third_party/protobuf code.
|
||||
sys.path.insert(
|
||||
0, os.path.join(checkout_root, 'third_party', 'catapult', 'tracing'))
|
||||
|
||||
# The low_bandwidth_audio_perf_test gn rule will build the protobuf stub for
|
||||
# python, so put it in the path for this script before we attempt to import
|
||||
# it.
|
||||
histogram_proto_path = os.path.join(
|
||||
os.path.abspath(args.build_dir), 'pyproto', 'tracing', 'tracing', 'proto')
|
||||
sys.path.insert(0, histogram_proto_path)
|
||||
proto_stub_path = os.path.join(os.path.abspath(args.build_dir), 'pyproto')
|
||||
sys.path.insert(0, proto_stub_path)
|
||||
# The low_bandwidth_audio_perf_test gn rule will build the protobuf stub for
|
||||
# python, so put it in the path for this script before we attempt to import
|
||||
# it.
|
||||
histogram_proto_path = os.path.join(os.path.abspath(args.build_dir),
|
||||
'pyproto', 'tracing', 'tracing',
|
||||
'proto')
|
||||
sys.path.insert(0, histogram_proto_path)
|
||||
proto_stub_path = os.path.join(os.path.abspath(args.build_dir), 'pyproto')
|
||||
sys.path.insert(0, proto_stub_path)
|
||||
|
||||
# Fail early in case the proto hasn't been built.
|
||||
try:
|
||||
import histogram_pb2
|
||||
except ImportError as e:
|
||||
logging.exception(e)
|
||||
raise ImportError('Could not import histogram_pb2. You need to build the '
|
||||
'low_bandwidth_audio_perf_test target before invoking '
|
||||
'this script. Expected to find '
|
||||
'histogram_pb2.py in %s.' % histogram_proto_path)
|
||||
# Fail early in case the proto hasn't been built.
|
||||
try:
|
||||
import histogram_pb2
|
||||
except ImportError as e:
|
||||
logging.exception(e)
|
||||
raise ImportError(
|
||||
'Could not import histogram_pb2. You need to build the '
|
||||
'low_bandwidth_audio_perf_test target before invoking '
|
||||
'this script. Expected to find '
|
||||
'histogram_pb2.py in %s.' % histogram_proto_path)
|
||||
|
||||
|
||||
def main():
|
||||
# pylint: disable=W0101
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('Invoked with %s', str(sys.argv))
|
||||
# pylint: disable=W0101
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('Invoked with %s', str(sys.argv))
|
||||
|
||||
args = _ParseArgs()
|
||||
args = _ParseArgs()
|
||||
|
||||
_ConfigurePythonPath(args)
|
||||
_ConfigurePythonPath(args)
|
||||
|
||||
# Import catapult modules here after configuring the pythonpath.
|
||||
from tracing.value import histogram_set
|
||||
from tracing.value.diagnostics import reserved_infos
|
||||
from tracing.value.diagnostics import generic_set
|
||||
# Import catapult modules here after configuring the pythonpath.
|
||||
from tracing.value import histogram_set
|
||||
from tracing.value.diagnostics import reserved_infos
|
||||
from tracing.value.diagnostics import generic_set
|
||||
|
||||
pesq_path, polqa_path = _GetPathToTools()
|
||||
if pesq_path is None:
|
||||
return 1
|
||||
pesq_path, polqa_path = _GetPathToTools()
|
||||
if pesq_path is None:
|
||||
return 1
|
||||
|
||||
out_dir = os.path.join(args.build_dir, '..')
|
||||
if args.android:
|
||||
test_command = [os.path.join(args.build_dir, 'bin',
|
||||
'run_low_bandwidth_audio_test'),
|
||||
'-v', '--num-retries', args.num_retries]
|
||||
else:
|
||||
test_command = [os.path.join(args.build_dir, 'low_bandwidth_audio_test')]
|
||||
out_dir = os.path.join(args.build_dir, '..')
|
||||
if args.android:
|
||||
test_command = [
|
||||
os.path.join(args.build_dir, 'bin',
|
||||
'run_low_bandwidth_audio_test'), '-v',
|
||||
'--num-retries', args.num_retries
|
||||
]
|
||||
else:
|
||||
test_command = [
|
||||
os.path.join(args.build_dir, 'low_bandwidth_audio_test')
|
||||
]
|
||||
|
||||
analyzers = [Analyzer('pesq', _RunPesq, pesq_path, 16000)]
|
||||
# Check if POLQA can run at all, or skip the 48 kHz tests entirely.
|
||||
example_path = os.path.join(SRC_DIR, 'resources',
|
||||
'voice_engine', 'audio_tiny48.wav')
|
||||
if polqa_path and _RunPolqa(polqa_path, example_path, example_path):
|
||||
analyzers.append(Analyzer('polqa', _RunPolqa, polqa_path, 48000))
|
||||
analyzers = [Analyzer('pesq', _RunPesq, pesq_path, 16000)]
|
||||
# Check if POLQA can run at all, or skip the 48 kHz tests entirely.
|
||||
example_path = os.path.join(SRC_DIR, 'resources', 'voice_engine',
|
||||
'audio_tiny48.wav')
|
||||
if polqa_path and _RunPolqa(polqa_path, example_path, example_path):
|
||||
analyzers.append(Analyzer('polqa', _RunPolqa, polqa_path, 48000))
|
||||
|
||||
histograms = histogram_set.HistogramSet()
|
||||
for analyzer in analyzers:
|
||||
# Start the test executable that produces audio files.
|
||||
test_process = subprocess.Popen(
|
||||
_LogCommand(test_command + [
|
||||
histograms = histogram_set.HistogramSet()
|
||||
for analyzer in analyzers:
|
||||
# Start the test executable that produces audio files.
|
||||
test_process = subprocess.Popen(_LogCommand(test_command + [
|
||||
'--sample_rate_hz=%d' % analyzer.sample_rate_hz,
|
||||
'--test_case_prefix=%s' % analyzer.name,
|
||||
] + args.extra_test_args),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
perf_results_file = None
|
||||
try:
|
||||
lines = iter(test_process.stdout.readline, '')
|
||||
for result in ExtractTestRuns(lines, echo=True):
|
||||
(android_device, test_name, reference_file, degraded_file,
|
||||
perf_results_file) = result
|
||||
] + args.extra_test_args),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT)
|
||||
perf_results_file = None
|
||||
try:
|
||||
lines = iter(test_process.stdout.readline, '')
|
||||
for result in ExtractTestRuns(lines, echo=True):
|
||||
(android_device, test_name, reference_file, degraded_file,
|
||||
perf_results_file) = result
|
||||
|
||||
adb_prefix = (args.adb_path,)
|
||||
if android_device:
|
||||
adb_prefix += ('-s', android_device)
|
||||
adb_prefix = (args.adb_path, )
|
||||
if android_device:
|
||||
adb_prefix += ('-s', android_device)
|
||||
|
||||
reference_file = _GetFile(reference_file, out_dir,
|
||||
android=args.android, adb_prefix=adb_prefix)
|
||||
degraded_file = _GetFile(degraded_file, out_dir, move=True,
|
||||
android=args.android, adb_prefix=adb_prefix)
|
||||
reference_file = _GetFile(reference_file,
|
||||
out_dir,
|
||||
android=args.android,
|
||||
adb_prefix=adb_prefix)
|
||||
degraded_file = _GetFile(degraded_file,
|
||||
out_dir,
|
||||
move=True,
|
||||
android=args.android,
|
||||
adb_prefix=adb_prefix)
|
||||
|
||||
analyzer_results = analyzer.func(analyzer.executable,
|
||||
reference_file, degraded_file)
|
||||
for metric, (value, units) in analyzer_results.items():
|
||||
hist = histograms.CreateHistogram(metric, units, [value])
|
||||
user_story = generic_set.GenericSet([test_name])
|
||||
hist.diagnostics[reserved_infos.STORIES.name] = user_story
|
||||
analyzer_results = analyzer.func(analyzer.executable,
|
||||
reference_file, degraded_file)
|
||||
for metric, (value, units) in analyzer_results.items():
|
||||
hist = histograms.CreateHistogram(metric, units, [value])
|
||||
user_story = generic_set.GenericSet([test_name])
|
||||
hist.diagnostics[reserved_infos.STORIES.name] = user_story
|
||||
|
||||
# Output human readable results.
|
||||
print 'RESULT %s: %s= %s %s' % (metric, test_name, value, units)
|
||||
# Output human readable results.
|
||||
print 'RESULT %s: %s= %s %s' % (metric, test_name, value,
|
||||
units)
|
||||
|
||||
if args.remove:
|
||||
os.remove(reference_file)
|
||||
os.remove(degraded_file)
|
||||
finally:
|
||||
test_process.terminate()
|
||||
if perf_results_file:
|
||||
perf_results_file = _GetFile(perf_results_file, out_dir, move=True,
|
||||
android=args.android, adb_prefix=adb_prefix)
|
||||
_MergeInPerfResultsFromCcTests(histograms, perf_results_file)
|
||||
if args.remove:
|
||||
os.remove(perf_results_file)
|
||||
if args.remove:
|
||||
os.remove(reference_file)
|
||||
os.remove(degraded_file)
|
||||
finally:
|
||||
test_process.terminate()
|
||||
if perf_results_file:
|
||||
perf_results_file = _GetFile(perf_results_file,
|
||||
out_dir,
|
||||
move=True,
|
||||
android=args.android,
|
||||
adb_prefix=adb_prefix)
|
||||
_MergeInPerfResultsFromCcTests(histograms, perf_results_file)
|
||||
if args.remove:
|
||||
os.remove(perf_results_file)
|
||||
|
||||
if args.isolated_script_test_perf_output:
|
||||
with open(args.isolated_script_test_perf_output, 'wb') as f:
|
||||
f.write(histograms.AsProto().SerializeToString())
|
||||
if args.isolated_script_test_perf_output:
|
||||
with open(args.isolated_script_test_perf_output, 'wb') as f:
|
||||
f.write(histograms.AsProto().SerializeToString())
|
||||
|
||||
return test_process.wait()
|
||||
return test_process.wait()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -11,7 +11,6 @@ import os
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
PARENT_DIR = os.path.join(SCRIPT_DIR, os.pardir)
|
||||
sys.path.append(PARENT_DIR)
|
||||
@ -19,46 +18,51 @@ import low_bandwidth_audio_test
|
||||
|
||||
|
||||
class TestExtractTestRuns(unittest.TestCase):
|
||||
def _TestLog(self, log, *expected):
|
||||
self.assertEqual(
|
||||
tuple(low_bandwidth_audio_test.ExtractTestRuns(log.splitlines(True))),
|
||||
expected)
|
||||
def _TestLog(self, log, *expected):
|
||||
self.assertEqual(
|
||||
tuple(
|
||||
low_bandwidth_audio_test.ExtractTestRuns(
|
||||
log.splitlines(True))), expected)
|
||||
|
||||
def testLinux(self):
|
||||
self._TestLog(LINUX_LOG,
|
||||
(None, 'GoodNetworkHighBitrate',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/LowBandwidth_GoodNetworkHighBitrate.wav', None),
|
||||
(None, 'Mobile2GNetwork',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/LowBandwidth_Mobile2GNetwork.wav', None),
|
||||
(None, 'PCGoodNetworkHighBitrate',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_PCGoodNetworkHighBitrate.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_perf_48.json'),
|
||||
(None, 'PCMobile2GNetwork',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_PCMobile2GNetwork.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_perf_48.json'))
|
||||
def testLinux(self):
|
||||
self._TestLog(
|
||||
LINUX_LOG,
|
||||
(None, 'GoodNetworkHighBitrate',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/LowBandwidth_GoodNetworkHighBitrate.wav', None),
|
||||
(None, 'Mobile2GNetwork',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/LowBandwidth_Mobile2GNetwork.wav', None),
|
||||
(None, 'PCGoodNetworkHighBitrate',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_PCGoodNetworkHighBitrate.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_perf_48.json'),
|
||||
(None, 'PCMobile2GNetwork',
|
||||
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_PCMobile2GNetwork.wav',
|
||||
'/webrtc/src/out/PCLowBandwidth_perf_48.json'))
|
||||
|
||||
def testAndroid(self):
|
||||
self._TestLog(ANDROID_LOG,
|
||||
('ddfa6149', 'Mobile2GNetwork',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/LowBandwidth_Mobile2GNetwork.wav', None),
|
||||
('TA99205CNO', 'GoodNetworkHighBitrate',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/LowBandwidth_GoodNetworkHighBitrate.wav',
|
||||
None),
|
||||
('ddfa6149', 'PCMobile2GNetwork',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_PCMobile2GNetwork.wav',
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'),
|
||||
('TA99205CNO', 'PCGoodNetworkHighBitrate',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
('/sdcard/chromium_tests_root/'
|
||||
'PCLowBandwidth_PCGoodNetworkHighBitrate.wav'),
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'))
|
||||
def testAndroid(self):
|
||||
self._TestLog(ANDROID_LOG, (
|
||||
'ddfa6149', 'Mobile2GNetwork',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/LowBandwidth_Mobile2GNetwork.wav',
|
||||
None
|
||||
), (
|
||||
'TA99205CNO', 'GoodNetworkHighBitrate',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/LowBandwidth_GoodNetworkHighBitrate.wav',
|
||||
None
|
||||
), (
|
||||
'ddfa6149', 'PCMobile2GNetwork',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_PCMobile2GNetwork.wav',
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'
|
||||
), ('TA99205CNO', 'PCGoodNetworkHighBitrate',
|
||||
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
|
||||
('/sdcard/chromium_tests_root/'
|
||||
'PCLowBandwidth_PCGoodNetworkHighBitrate.wav'),
|
||||
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'))
|
||||
|
||||
|
||||
LINUX_LOG = r'''\
|
||||
@ -233,6 +237,5 @@ I 16.608s tear_down_device(ddfa6149) Wrote device cache: /webrtc/src/out/debu
|
||||
I 16.608s tear_down_device(TA99205CNO) Wrote device cache: /webrtc/src/out/debug-android/device_cache_TA99305CMO.json
|
||||
'''
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -15,110 +15,113 @@ import time
|
||||
|
||||
from com.android.monkeyrunner import MonkeyRunner, MonkeyDevice
|
||||
|
||||
|
||||
def main():
|
||||
parser = OptionParser()
|
||||
parser = OptionParser()
|
||||
|
||||
parser.add_option('--devname', dest='devname', help='The device id')
|
||||
parser.add_option('--devname', dest='devname', help='The device id')
|
||||
|
||||
parser.add_option(
|
||||
'--videooutsave',
|
||||
dest='videooutsave',
|
||||
help='The path where to save the video out file on local computer')
|
||||
parser.add_option(
|
||||
'--videooutsave',
|
||||
dest='videooutsave',
|
||||
help='The path where to save the video out file on local computer')
|
||||
|
||||
parser.add_option(
|
||||
'--videoout',
|
||||
dest='videoout',
|
||||
help='The path where to put the video out file')
|
||||
parser.add_option('--videoout',
|
||||
dest='videoout',
|
||||
help='The path where to put the video out file')
|
||||
|
||||
parser.add_option(
|
||||
'--videoout_width',
|
||||
dest='videoout_width',
|
||||
type='int',
|
||||
help='The width for the video out file')
|
||||
parser.add_option('--videoout_width',
|
||||
dest='videoout_width',
|
||||
type='int',
|
||||
help='The width for the video out file')
|
||||
|
||||
parser.add_option(
|
||||
'--videoout_height',
|
||||
dest='videoout_height',
|
||||
type='int',
|
||||
help='The height for the video out file')
|
||||
parser.add_option('--videoout_height',
|
||||
dest='videoout_height',
|
||||
type='int',
|
||||
help='The height for the video out file')
|
||||
|
||||
parser.add_option(
|
||||
'--videoin',
|
||||
dest='videoin',
|
||||
help='The path where to read input file instead of camera')
|
||||
parser.add_option(
|
||||
'--videoin',
|
||||
dest='videoin',
|
||||
help='The path where to read input file instead of camera')
|
||||
|
||||
parser.add_option(
|
||||
'--call_length',
|
||||
dest='call_length',
|
||||
type='int',
|
||||
help='The length of the call')
|
||||
parser.add_option('--call_length',
|
||||
dest='call_length',
|
||||
type='int',
|
||||
help='The length of the call')
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
print (options, args)
|
||||
print(options, args)
|
||||
|
||||
devname = options.devname
|
||||
devname = options.devname
|
||||
|
||||
videoin = options.videoin
|
||||
videoin = options.videoin
|
||||
|
||||
videoout = options.videoout
|
||||
videoout_width = options.videoout_width
|
||||
videoout_height = options.videoout_height
|
||||
videoout = options.videoout
|
||||
videoout_width = options.videoout_width
|
||||
videoout_height = options.videoout_height
|
||||
|
||||
videooutsave = options.videooutsave
|
||||
videooutsave = options.videooutsave
|
||||
|
||||
call_length = options.call_length or 10
|
||||
call_length = options.call_length or 10
|
||||
|
||||
room = ''.join(random.choice(string.ascii_letters + string.digits)
|
||||
for _ in range(8))
|
||||
room = ''.join(
|
||||
random.choice(string.ascii_letters + string.digits) for _ in range(8))
|
||||
|
||||
# Delete output video file.
|
||||
if videoout:
|
||||
subprocess.check_call(['adb', '-s', devname, 'shell', 'rm', '-f',
|
||||
videoout])
|
||||
# Delete output video file.
|
||||
if videoout:
|
||||
subprocess.check_call(
|
||||
['adb', '-s', devname, 'shell', 'rm', '-f', videoout])
|
||||
|
||||
device = MonkeyRunner.waitForConnection(2, devname)
|
||||
device = MonkeyRunner.waitForConnection(2, devname)
|
||||
|
||||
extras = {
|
||||
'org.appspot.apprtc.USE_VALUES_FROM_INTENT': True,
|
||||
'org.appspot.apprtc.AUDIOCODEC': 'OPUS',
|
||||
'org.appspot.apprtc.LOOPBACK': True,
|
||||
'org.appspot.apprtc.VIDEOCODEC': 'VP8',
|
||||
'org.appspot.apprtc.CAPTURETOTEXTURE': False,
|
||||
'org.appspot.apprtc.CAMERA2': False,
|
||||
'org.appspot.apprtc.ROOMID': room}
|
||||
extras = {
|
||||
'org.appspot.apprtc.USE_VALUES_FROM_INTENT': True,
|
||||
'org.appspot.apprtc.AUDIOCODEC': 'OPUS',
|
||||
'org.appspot.apprtc.LOOPBACK': True,
|
||||
'org.appspot.apprtc.VIDEOCODEC': 'VP8',
|
||||
'org.appspot.apprtc.CAPTURETOTEXTURE': False,
|
||||
'org.appspot.apprtc.CAMERA2': False,
|
||||
'org.appspot.apprtc.ROOMID': room
|
||||
}
|
||||
|
||||
if videoin:
|
||||
extras.update({'org.appspot.apprtc.VIDEO_FILE_AS_CAMERA': videoin})
|
||||
if videoin:
|
||||
extras.update({'org.appspot.apprtc.VIDEO_FILE_AS_CAMERA': videoin})
|
||||
|
||||
if videoout:
|
||||
extras.update({
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE': videoout,
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_WIDTH': videoout_width,
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_HEIGHT': videoout_height})
|
||||
if videoout:
|
||||
extras.update({
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE':
|
||||
videoout,
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_WIDTH':
|
||||
videoout_width,
|
||||
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_HEIGHT':
|
||||
videoout_height
|
||||
})
|
||||
|
||||
print extras
|
||||
print extras
|
||||
|
||||
device.startActivity(data='https://appr.tc',
|
||||
action='android.intent.action.VIEW',
|
||||
component='org.appspot.apprtc/.ConnectActivity', extras=extras)
|
||||
device.startActivity(data='https://appr.tc',
|
||||
action='android.intent.action.VIEW',
|
||||
component='org.appspot.apprtc/.ConnectActivity',
|
||||
extras=extras)
|
||||
|
||||
print 'Running a call for %d seconds' % call_length
|
||||
for _ in xrange(call_length):
|
||||
sys.stdout.write('.')
|
||||
sys.stdout.flush()
|
||||
time.sleep(1)
|
||||
print '\nEnding call.'
|
||||
print 'Running a call for %d seconds' % call_length
|
||||
for _ in xrange(call_length):
|
||||
sys.stdout.write('.')
|
||||
sys.stdout.flush()
|
||||
time.sleep(1)
|
||||
print '\nEnding call.'
|
||||
|
||||
# Press back to end the call. Will end on both sides.
|
||||
device.press('KEYCODE_BACK', MonkeyDevice.DOWN_AND_UP)
|
||||
# Press back to end the call. Will end on both sides.
|
||||
device.press('KEYCODE_BACK', MonkeyDevice.DOWN_AND_UP)
|
||||
|
||||
if videooutsave:
|
||||
time.sleep(2)
|
||||
if videooutsave:
|
||||
time.sleep(2)
|
||||
|
||||
subprocess.check_call(
|
||||
['adb', '-s', devname, 'pull', videoout, videooutsave])
|
||||
|
||||
subprocess.check_call(['adb', '-s', devname, 'pull',
|
||||
videoout, videooutsave])
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""
|
||||
This scripts tests creating an Android Studio project using the
|
||||
generate_gradle.py script and making a debug build using it.
|
||||
@ -23,58 +22,59 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
|
||||
GENERATE_GRADLE_SCRIPT = os.path.join(SRC_DIR,
|
||||
'build/android/gradle/generate_gradle.py')
|
||||
GENERATE_GRADLE_SCRIPT = os.path.join(
|
||||
SRC_DIR, 'build/android/gradle/generate_gradle.py')
|
||||
GRADLEW_BIN = os.path.join(SCRIPT_DIR, 'third_party/gradle/gradlew')
|
||||
|
||||
|
||||
def _RunCommand(argv, cwd=SRC_DIR, **kwargs):
|
||||
logging.info('Running %r', argv)
|
||||
subprocess.check_call(argv, cwd=cwd, **kwargs)
|
||||
logging.info('Running %r', argv)
|
||||
subprocess.check_call(argv, cwd=cwd, **kwargs)
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Test generating Android gradle project.')
|
||||
parser.add_argument('build_dir_android',
|
||||
help='The path to the build directory for Android.')
|
||||
parser.add_argument('--project_dir',
|
||||
help='A temporary directory to put the output.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Test generating Android gradle project.')
|
||||
parser.add_argument('build_dir_android',
|
||||
help='The path to the build directory for Android.')
|
||||
parser.add_argument('--project_dir',
|
||||
help='A temporary directory to put the output.')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
args = _ParseArgs()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
args = _ParseArgs()
|
||||
|
||||
project_dir = args.project_dir
|
||||
if not project_dir:
|
||||
project_dir = tempfile.mkdtemp()
|
||||
project_dir = args.project_dir
|
||||
if not project_dir:
|
||||
project_dir = tempfile.mkdtemp()
|
||||
|
||||
output_dir = os.path.abspath(args.build_dir_android)
|
||||
project_dir = os.path.abspath(project_dir)
|
||||
output_dir = os.path.abspath(args.build_dir_android)
|
||||
project_dir = os.path.abspath(project_dir)
|
||||
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PATH'] = os.pathsep.join([
|
||||
os.path.join(SRC_DIR, 'third_party', 'depot_tools'), env.get('PATH', '')
|
||||
])
|
||||
_RunCommand([GENERATE_GRADLE_SCRIPT, '--output-directory', output_dir,
|
||||
'--target', '//examples:AppRTCMobile',
|
||||
'--project-dir', project_dir,
|
||||
'--use-gradle-process-resources', '--split-projects'],
|
||||
env=env)
|
||||
_RunCommand([GRADLEW_BIN, 'assembleDebug'], project_dir)
|
||||
finally:
|
||||
# Do not delete temporary directory if user specified it manually.
|
||||
if not args.project_dir:
|
||||
shutil.rmtree(project_dir, True)
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PATH'] = os.pathsep.join([
|
||||
os.path.join(SRC_DIR, 'third_party', 'depot_tools'),
|
||||
env.get('PATH', '')
|
||||
])
|
||||
_RunCommand([
|
||||
GENERATE_GRADLE_SCRIPT, '--output-directory', output_dir,
|
||||
'--target', '//examples:AppRTCMobile', '--project-dir',
|
||||
project_dir, '--use-gradle-process-resources', '--split-projects'
|
||||
],
|
||||
env=env)
|
||||
_RunCommand([GRADLEW_BIN, 'assembleDebug'], project_dir)
|
||||
finally:
|
||||
# Do not delete temporary directory if user specified it manually.
|
||||
if not args.project_dir:
|
||||
shutil.rmtree(project_dir, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -24,124 +24,126 @@ import debug_dump_pb2
|
||||
|
||||
|
||||
def GetNextMessageSize(file_to_parse):
|
||||
data = file_to_parse.read(4)
|
||||
if data == '':
|
||||
return 0
|
||||
return struct.unpack('<I', data)[0]
|
||||
data = file_to_parse.read(4)
|
||||
if data == '':
|
||||
return 0
|
||||
return struct.unpack('<I', data)[0]
|
||||
|
||||
|
||||
def GetNextMessageFromFile(file_to_parse):
|
||||
message_size = GetNextMessageSize(file_to_parse)
|
||||
if message_size == 0:
|
||||
return None
|
||||
try:
|
||||
event = debug_dump_pb2.Event()
|
||||
event.ParseFromString(file_to_parse.read(message_size))
|
||||
except IOError:
|
||||
print 'Invalid message in file'
|
||||
return None
|
||||
return event
|
||||
message_size = GetNextMessageSize(file_to_parse)
|
||||
if message_size == 0:
|
||||
return None
|
||||
try:
|
||||
event = debug_dump_pb2.Event()
|
||||
event.ParseFromString(file_to_parse.read(message_size))
|
||||
except IOError:
|
||||
print 'Invalid message in file'
|
||||
return None
|
||||
return event
|
||||
|
||||
|
||||
def InitMetrics():
|
||||
metrics = {}
|
||||
event = debug_dump_pb2.Event()
|
||||
for metric in event.network_metrics.DESCRIPTOR.fields:
|
||||
metrics[metric.name] = {'time': [], 'value': []}
|
||||
return metrics
|
||||
metrics = {}
|
||||
event = debug_dump_pb2.Event()
|
||||
for metric in event.network_metrics.DESCRIPTOR.fields:
|
||||
metrics[metric.name] = {'time': [], 'value': []}
|
||||
return metrics
|
||||
|
||||
|
||||
def InitDecisions():
|
||||
decisions = {}
|
||||
event = debug_dump_pb2.Event()
|
||||
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
|
||||
decisions[decision.name] = {'time': [], 'value': []}
|
||||
return decisions
|
||||
decisions = {}
|
||||
event = debug_dump_pb2.Event()
|
||||
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
|
||||
decisions[decision.name] = {'time': [], 'value': []}
|
||||
return decisions
|
||||
|
||||
|
||||
def ParseAnaDump(dump_file_to_parse):
|
||||
with open(dump_file_to_parse, 'rb') as file_to_parse:
|
||||
metrics = InitMetrics()
|
||||
decisions = InitDecisions()
|
||||
first_time_stamp = None
|
||||
while True:
|
||||
event = GetNextMessageFromFile(file_to_parse)
|
||||
if event is None:
|
||||
break
|
||||
if first_time_stamp is None:
|
||||
first_time_stamp = event.timestamp
|
||||
if event.type == debug_dump_pb2.Event.ENCODER_RUNTIME_CONFIG:
|
||||
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
|
||||
if event.encoder_runtime_config.HasField(decision.name):
|
||||
decisions[decision.name]['time'].append(event.timestamp -
|
||||
first_time_stamp)
|
||||
decisions[decision.name]['value'].append(
|
||||
getattr(event.encoder_runtime_config, decision.name))
|
||||
if event.type == debug_dump_pb2.Event.NETWORK_METRICS:
|
||||
for metric in event.network_metrics.DESCRIPTOR.fields:
|
||||
if event.network_metrics.HasField(metric.name):
|
||||
metrics[metric.name]['time'].append(event.timestamp -
|
||||
first_time_stamp)
|
||||
metrics[metric.name]['value'].append(
|
||||
getattr(event.network_metrics, metric.name))
|
||||
return (metrics, decisions)
|
||||
with open(dump_file_to_parse, 'rb') as file_to_parse:
|
||||
metrics = InitMetrics()
|
||||
decisions = InitDecisions()
|
||||
first_time_stamp = None
|
||||
while True:
|
||||
event = GetNextMessageFromFile(file_to_parse)
|
||||
if event is None:
|
||||
break
|
||||
if first_time_stamp is None:
|
||||
first_time_stamp = event.timestamp
|
||||
if event.type == debug_dump_pb2.Event.ENCODER_RUNTIME_CONFIG:
|
||||
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
|
||||
if event.encoder_runtime_config.HasField(decision.name):
|
||||
decisions[decision.name]['time'].append(
|
||||
event.timestamp - first_time_stamp)
|
||||
decisions[decision.name]['value'].append(
|
||||
getattr(event.encoder_runtime_config,
|
||||
decision.name))
|
||||
if event.type == debug_dump_pb2.Event.NETWORK_METRICS:
|
||||
for metric in event.network_metrics.DESCRIPTOR.fields:
|
||||
if event.network_metrics.HasField(metric.name):
|
||||
metrics[metric.name]['time'].append(event.timestamp -
|
||||
first_time_stamp)
|
||||
metrics[metric.name]['value'].append(
|
||||
getattr(event.network_metrics, metric.name))
|
||||
return (metrics, decisions)
|
||||
|
||||
|
||||
def main():
|
||||
parser = OptionParser()
|
||||
parser.add_option(
|
||||
"-f", "--dump_file", dest="dump_file_to_parse", help="dump file to parse")
|
||||
parser.add_option(
|
||||
'-m',
|
||||
'--metric_plot',
|
||||
default=[],
|
||||
type=str,
|
||||
help='metric key (name of the metric) to plot',
|
||||
dest='metric_keys',
|
||||
action='append')
|
||||
parser = OptionParser()
|
||||
parser.add_option("-f",
|
||||
"--dump_file",
|
||||
dest="dump_file_to_parse",
|
||||
help="dump file to parse")
|
||||
parser.add_option('-m',
|
||||
'--metric_plot',
|
||||
default=[],
|
||||
type=str,
|
||||
help='metric key (name of the metric) to plot',
|
||||
dest='metric_keys',
|
||||
action='append')
|
||||
|
||||
parser.add_option(
|
||||
'-d',
|
||||
'--decision_plot',
|
||||
default=[],
|
||||
type=str,
|
||||
help='decision key (name of the decision) to plot',
|
||||
dest='decision_keys',
|
||||
action='append')
|
||||
parser.add_option('-d',
|
||||
'--decision_plot',
|
||||
default=[],
|
||||
type=str,
|
||||
help='decision key (name of the decision) to plot',
|
||||
dest='decision_keys',
|
||||
action='append')
|
||||
|
||||
options = parser.parse_args()[0]
|
||||
if options.dump_file_to_parse is None:
|
||||
print "No dump file to parse is set.\n"
|
||||
parser.print_help()
|
||||
exit()
|
||||
(metrics, decisions) = ParseAnaDump(options.dump_file_to_parse)
|
||||
metric_keys = options.metric_keys
|
||||
decision_keys = options.decision_keys
|
||||
plot_count = len(metric_keys) + len(decision_keys)
|
||||
if plot_count == 0:
|
||||
print "You have to set at least one metric or decision to plot.\n"
|
||||
parser.print_help()
|
||||
exit()
|
||||
plots = []
|
||||
if plot_count == 1:
|
||||
f, mp_plot = plt.subplots()
|
||||
plots.append(mp_plot)
|
||||
else:
|
||||
f, mp_plots = plt.subplots(plot_count, sharex=True)
|
||||
plots.extend(mp_plots.tolist())
|
||||
options = parser.parse_args()[0]
|
||||
if options.dump_file_to_parse is None:
|
||||
print "No dump file to parse is set.\n"
|
||||
parser.print_help()
|
||||
exit()
|
||||
(metrics, decisions) = ParseAnaDump(options.dump_file_to_parse)
|
||||
metric_keys = options.metric_keys
|
||||
decision_keys = options.decision_keys
|
||||
plot_count = len(metric_keys) + len(decision_keys)
|
||||
if plot_count == 0:
|
||||
print "You have to set at least one metric or decision to plot.\n"
|
||||
parser.print_help()
|
||||
exit()
|
||||
plots = []
|
||||
if plot_count == 1:
|
||||
f, mp_plot = plt.subplots()
|
||||
plots.append(mp_plot)
|
||||
else:
|
||||
f, mp_plots = plt.subplots(plot_count, sharex=True)
|
||||
plots.extend(mp_plots.tolist())
|
||||
|
||||
for key in metric_keys:
|
||||
plot = plots.pop()
|
||||
plot.grid(True)
|
||||
plot.set_title(key + " (metric)")
|
||||
plot.plot(metrics[key]['time'], metrics[key]['value'])
|
||||
for key in decision_keys:
|
||||
plot = plots.pop()
|
||||
plot.grid(True)
|
||||
plot.set_title(key + " (decision)")
|
||||
plot.plot(decisions[key]['time'], decisions[key]['value'])
|
||||
f.subplots_adjust(hspace=0.3)
|
||||
plt.show()
|
||||
|
||||
for key in metric_keys:
|
||||
plot = plots.pop()
|
||||
plot.grid(True)
|
||||
plot.set_title(key + " (metric)")
|
||||
plot.plot(metrics[key]['time'], metrics[key]['value'])
|
||||
for key in decision_keys:
|
||||
plot = plots.pop()
|
||||
plot.grid(True)
|
||||
plot.set_title(key + " (decision)")
|
||||
plot.plot(decisions[key]['time'], decisions[key]['value'])
|
||||
f.subplots_adjust(hspace=0.3)
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Perform APM module quality assessment on one or more input files using one or
|
||||
more APM simulator configuration files and one or more test data generators.
|
||||
|
||||
@ -47,139 +46,172 @@ _POLQA_BIN_NAME = 'PolqaOem64'
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Perform APM module quality assessment on one or more input files using '
|
||||
'one or more APM simulator configuration files and one or more '
|
||||
'test data generators.'))
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Perform APM module quality assessment on one or more input files using '
|
||||
'one or more APM simulator configuration files and one or more '
|
||||
'test data generators.'))
|
||||
|
||||
parser.add_argument('-c', '--config_files', nargs='+', required=False,
|
||||
help=('path to the configuration files defining the '
|
||||
'arguments with which the APM simulator tool is '
|
||||
'called'),
|
||||
default=[_DEFAULT_CONFIG_FILE])
|
||||
parser.add_argument('-c',
|
||||
'--config_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the configuration files defining the '
|
||||
'arguments with which the APM simulator tool is '
|
||||
'called'),
|
||||
default=[_DEFAULT_CONFIG_FILE])
|
||||
|
||||
parser.add_argument('-i', '--capture_input_files', nargs='+', required=True,
|
||||
help='path to the capture input wav files (one or more)')
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_input_files',
|
||||
nargs='+',
|
||||
required=True,
|
||||
help='path to the capture input wav files (one or more)')
|
||||
|
||||
parser.add_argument('-r', '--render_input_files', nargs='+', required=False,
|
||||
help=('path to the render input wav files; either '
|
||||
'omitted or one file for each file in '
|
||||
'--capture_input_files (files will be paired by '
|
||||
'index)'), default=None)
|
||||
parser.add_argument('-r',
|
||||
'--render_input_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the render input wav files; either '
|
||||
'omitted or one file for each file in '
|
||||
'--capture_input_files (files will be paired by '
|
||||
'index)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument('-p', '--echo_path_simulator', required=False,
|
||||
help=('custom echo path simulator name; required if '
|
||||
'--render_input_files is specified'),
|
||||
choices=_ECHO_PATH_SIMULATOR_NAMES,
|
||||
default=echo_path_simulation.NoEchoPathSimulator.NAME)
|
||||
parser.add_argument('-p',
|
||||
'--echo_path_simulator',
|
||||
required=False,
|
||||
help=('custom echo path simulator name; required if '
|
||||
'--render_input_files is specified'),
|
||||
choices=_ECHO_PATH_SIMULATOR_NAMES,
|
||||
default=echo_path_simulation.NoEchoPathSimulator.NAME)
|
||||
|
||||
parser.add_argument('-t', '--test_data_generators', nargs='+', required=False,
|
||||
help='custom list of test data generators to use',
|
||||
choices=_TEST_DATA_GENERATORS_NAMES,
|
||||
default=_TEST_DATA_GENERATORS_NAMES)
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of test data generators to use',
|
||||
choices=_TEST_DATA_GENERATORS_NAMES,
|
||||
default=_TEST_DATA_GENERATORS_NAMES)
|
||||
|
||||
parser.add_argument('--additive_noise_tracks_path', required=False,
|
||||
help='path to the wav files for the additive',
|
||||
default=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH)
|
||||
parser.add_argument('--additive_noise_tracks_path', required=False,
|
||||
help='path to the wav files for the additive',
|
||||
default=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH)
|
||||
|
||||
parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
|
||||
help='custom list of evaluation scores to use',
|
||||
choices=_EVAL_SCORE_WORKER_NAMES,
|
||||
default=_EVAL_SCORE_WORKER_NAMES)
|
||||
parser.add_argument('-e',
|
||||
'--eval_scores',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of evaluation scores to use',
|
||||
choices=_EVAL_SCORE_WORKER_NAMES,
|
||||
default=_EVAL_SCORE_WORKER_NAMES)
|
||||
|
||||
parser.add_argument('-o', '--output_dir', required=False,
|
||||
help=('base path to the output directory in which the '
|
||||
'output wav files and the evaluation outcomes '
|
||||
'are saved'),
|
||||
default='output')
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=False,
|
||||
help=('base path to the output directory in which the '
|
||||
'output wav files and the evaluation outcomes '
|
||||
'are saved'),
|
||||
default='output')
|
||||
|
||||
parser.add_argument('--polqa_path', required=True,
|
||||
help='path to the POLQA tool')
|
||||
parser.add_argument('--polqa_path',
|
||||
required=True,
|
||||
help='path to the POLQA tool')
|
||||
|
||||
parser.add_argument('--air_db_path', required=True,
|
||||
help='path to the Aechen IR database')
|
||||
parser.add_argument('--air_db_path',
|
||||
required=True,
|
||||
help='path to the Aechen IR database')
|
||||
|
||||
parser.add_argument('--apm_sim_path', required=False,
|
||||
help='path to the APM simulator tool',
|
||||
default=audioproc_wrapper. \
|
||||
AudioProcWrapper. \
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
parser.add_argument('--apm_sim_path', required=False,
|
||||
help='path to the APM simulator tool',
|
||||
default=audioproc_wrapper. \
|
||||
AudioProcWrapper. \
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
|
||||
parser.add_argument('--echo_metric_tool_bin_path', required=False,
|
||||
help=('path to the echo metric binary '
|
||||
'(required for the echo eval score)'),
|
||||
default=None)
|
||||
parser.add_argument('--echo_metric_tool_bin_path',
|
||||
required=False,
|
||||
help=('path to the echo metric binary '
|
||||
'(required for the echo eval score)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument('--copy_with_identity_generator', required=False,
|
||||
help=('If true, the identity test data generator makes a '
|
||||
'copy of the clean speech input file.'),
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--copy_with_identity_generator',
|
||||
required=False,
|
||||
help=('If true, the identity test data generator makes a '
|
||||
'copy of the clean speech input file.'),
|
||||
default=False)
|
||||
|
||||
parser.add_argument('--external_vad_paths', nargs='+', required=False,
|
||||
help=('Paths to external VAD programs. Each must take'
|
||||
'\'-i <wav file> -o <output>\' inputs'), default=[])
|
||||
parser.add_argument('--external_vad_paths',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Paths to external VAD programs. Each must take'
|
||||
'\'-i <wav file> -o <output>\' inputs'),
|
||||
default=[])
|
||||
|
||||
parser.add_argument('--external_vad_names', nargs='+', required=False,
|
||||
help=('Keys to the vad paths. Must be different and '
|
||||
'as many as the paths.'), default=[])
|
||||
parser.add_argument('--external_vad_names',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Keys to the vad paths. Must be different and '
|
||||
'as many as the paths.'),
|
||||
default=[])
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _ValidateArguments(args, parser):
|
||||
if args.capture_input_files and args.render_input_files and (
|
||||
len(args.capture_input_files) != len(args.render_input_files)):
|
||||
parser.error('--render_input_files and --capture_input_files must be lists '
|
||||
'having the same length')
|
||||
sys.exit(1)
|
||||
if args.capture_input_files and args.render_input_files and (len(
|
||||
args.capture_input_files) != len(args.render_input_files)):
|
||||
parser.error(
|
||||
'--render_input_files and --capture_input_files must be lists '
|
||||
'having the same length')
|
||||
sys.exit(1)
|
||||
|
||||
if args.render_input_files and not args.echo_path_simulator:
|
||||
parser.error('when --render_input_files is set, --echo_path_simulator is '
|
||||
'also required')
|
||||
sys.exit(1)
|
||||
if args.render_input_files and not args.echo_path_simulator:
|
||||
parser.error(
|
||||
'when --render_input_files is set, --echo_path_simulator is '
|
||||
'also required')
|
||||
sys.exit(1)
|
||||
|
||||
if len(args.external_vad_names) != len(args.external_vad_paths):
|
||||
parser.error('If provided, --external_vad_paths and '
|
||||
'--external_vad_names must '
|
||||
'have the same number of arguments.')
|
||||
sys.exit(1)
|
||||
if len(args.external_vad_names) != len(args.external_vad_paths):
|
||||
parser.error('If provided, --external_vad_paths and '
|
||||
'--external_vad_names must '
|
||||
'have the same number of arguments.')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
# TODO(alessiob): level = logging.INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
_ValidateArguments(args, parser)
|
||||
# TODO(alessiob): level = logging.INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
_ValidateArguments(args, parser)
|
||||
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
noise_tracks_path=args.additive_noise_tracks_path,
|
||||
copy_with_identity=args.copy_with_identity_generator)),
|
||||
evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path
|
||||
),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
|
||||
evaluator=evaluation.ApmModuleEvaluator(),
|
||||
external_vads=external_vad.ExternalVad.ConstructVadDict(
|
||||
args.external_vad_paths, args.external_vad_names))
|
||||
simulator.Run(
|
||||
config_filepaths=args.config_files,
|
||||
capture_input_filepaths=args.capture_input_files,
|
||||
render_input_filepaths=args.render_input_files,
|
||||
echo_path_simulator_name=args.echo_path_simulator,
|
||||
test_data_generator_names=args.test_data_generators,
|
||||
eval_score_names=args.eval_scores,
|
||||
output_dir=args.output_dir)
|
||||
sys.exit(0)
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
noise_tracks_path=args.additive_noise_tracks_path,
|
||||
copy_with_identity=args.copy_with_identity_generator)),
|
||||
evaluation_score_factory=eval_scores_factory.
|
||||
EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
|
||||
evaluator=evaluation.ApmModuleEvaluator(),
|
||||
external_vads=external_vad.ExternalVad.ConstructVadDict(
|
||||
args.external_vad_paths, args.external_vad_names))
|
||||
simulator.Run(config_filepaths=args.config_files,
|
||||
capture_input_filepaths=args.capture_input_files,
|
||||
render_input_filepaths=args.render_input_files,
|
||||
echo_path_simulator_name=args.echo_path_simulator,
|
||||
test_data_generator_names=args.test_data_generators,
|
||||
eval_score_names=args.eval_scores,
|
||||
output_dir=args.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Shows boxplots of given score for different values of selected
|
||||
parameters. Can be used to compare scores by audioproc_f flag.
|
||||
|
||||
@ -30,29 +29,37 @@ import quality_assessment.collect_data as collect_data
|
||||
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Shows boxplot of given score for different values of selected'
|
||||
'parameters. Can be used to compare scores by audioproc_f flag')
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Shows boxplot of given score for different values of selected'
|
||||
'parameters. Can be used to compare scores by audioproc_f flag')
|
||||
|
||||
parser.add_argument('-v', '--eval_score', required=True,
|
||||
help=('Score name for constructing boxplots'))
|
||||
parser.add_argument('-v',
|
||||
'--eval_score',
|
||||
required=True,
|
||||
help=('Score name for constructing boxplots'))
|
||||
|
||||
parser.add_argument('-n', '--config_dir', required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--config_dir',
|
||||
required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
|
||||
parser.add_argument('-z', '--params_to_plot', required=True,
|
||||
nargs='+', help=('audioproc_f parameter values'
|
||||
'by which to group scores (no leading dash)'))
|
||||
parser.add_argument('-z',
|
||||
'--params_to_plot',
|
||||
required=True,
|
||||
nargs='+',
|
||||
help=('audioproc_f parameter values'
|
||||
'by which to group scores (no leading dash)'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
|
||||
"""Filters data on the values of one or more parameters.
|
||||
"""Filters data on the values of one or more parameters.
|
||||
|
||||
Args:
|
||||
data_frame: pandas.DataFrame of all used input data.
|
||||
@ -71,34 +78,36 @@ def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
|
||||
Returns: dictionary, key is a param value, result is all scores for
|
||||
that param value (see `filter_params` for explanation).
|
||||
"""
|
||||
results = collections.defaultdict(dict)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
results = collections.defaultdict(dict)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + '.json'))
|
||||
data_with_config = data_frame[data_frame.apm_config == config_name]
|
||||
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
|
||||
score_name]
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + '.json'))
|
||||
data_with_config = data_frame[data_frame.apm_config == config_name]
|
||||
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
|
||||
score_name]
|
||||
|
||||
# Exactly one of |params_to_plot| must match:
|
||||
(matching_param, ) = [x for x in filter_params if '-' + x in config_json]
|
||||
# Exactly one of |params_to_plot| must match:
|
||||
(matching_param, ) = [
|
||||
x for x in filter_params if '-' + x in config_json
|
||||
]
|
||||
|
||||
# Add scores for every track to the result.
|
||||
for capture_name in data_cell_scores.capture:
|
||||
result_score = float(data_cell_scores[data_cell_scores.capture ==
|
||||
capture_name].score)
|
||||
config_dict = results[config_json['-' + matching_param]]
|
||||
if capture_name not in config_dict:
|
||||
config_dict[capture_name] = {}
|
||||
# Add scores for every track to the result.
|
||||
for capture_name in data_cell_scores.capture:
|
||||
result_score = float(data_cell_scores[data_cell_scores.capture ==
|
||||
capture_name].score)
|
||||
config_dict = results[config_json['-' + matching_param]]
|
||||
if capture_name not in config_dict:
|
||||
config_dict[capture_name] = {}
|
||||
|
||||
config_dict[capture_name][matching_param] = result_score
|
||||
config_dict[capture_name][matching_param] = result_score
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
|
||||
def _FlattenToScoresList(config_param_score_dict):
|
||||
"""Extracts a list of scores from input data structure.
|
||||
"""Extracts a list of scores from input data structure.
|
||||
|
||||
Args:
|
||||
config_param_score_dict: of the form {'capture_name':
|
||||
@ -107,40 +116,39 @@ def _FlattenToScoresList(config_param_score_dict):
|
||||
Returns: Plain list of all score value present in input data
|
||||
structure
|
||||
"""
|
||||
result = []
|
||||
for capture_name in config_param_score_dict:
|
||||
result += list(config_param_score_dict[capture_name].values())
|
||||
return result
|
||||
result = []
|
||||
for capture_name in config_param_score_dict:
|
||||
result += list(config_param_score_dict[capture_name].values())
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
|
||||
# Filter the data by `args.params_to_plot`
|
||||
scores_filtered = FilterScoresByParams(scores_data_frame,
|
||||
args.params_to_plot,
|
||||
args.eval_score,
|
||||
args.config_dir)
|
||||
# Filter the data by `args.params_to_plot`
|
||||
scores_filtered = FilterScoresByParams(scores_data_frame,
|
||||
args.params_to_plot,
|
||||
args.eval_score, args.config_dir)
|
||||
|
||||
data_list = sorted(scores_filtered.items())
|
||||
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
|
||||
data_labels = [x for (x, _) in data_list]
|
||||
data_list = sorted(scores_filtered.items())
|
||||
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
|
||||
data_labels = [x for (x, _) in data_list]
|
||||
|
||||
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
|
||||
axes.boxplot(data_values, labels=data_labels)
|
||||
axes.set_ylabel(args.eval_score)
|
||||
axes.set_xlabel('/'.join(args.params_to_plot))
|
||||
plt.show()
|
||||
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
|
||||
axes.boxplot(data_values, labels=data_labels)
|
||||
axes.set_ylabel(args.eval_score)
|
||||
axes.set_xlabel('/'.join(args.params_to_plot))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Export the scores computed by the apm_quality_assessment.py script into an
|
||||
HTML file.
|
||||
"""
|
||||
@ -20,7 +19,7 @@ import quality_assessment.export as export
|
||||
|
||||
|
||||
def _BuildOutputFilename(filename_suffix):
|
||||
"""Builds the filename for the exported file.
|
||||
"""Builds the filename for the exported file.
|
||||
|
||||
Args:
|
||||
filename_suffix: suffix for the output file name.
|
||||
@ -28,34 +27,37 @@ def _BuildOutputFilename(filename_suffix):
|
||||
Returns:
|
||||
A string.
|
||||
"""
|
||||
if filename_suffix is None:
|
||||
return 'results.html'
|
||||
return 'results-{}.html'.format(filename_suffix)
|
||||
if filename_suffix is None:
|
||||
return 'results.html'
|
||||
return 'results-{}.html'.format(filename_suffix)
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
logging.basicConfig(level=logging.DEBUG) # TODO(alessio): INFO once debugged.
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.add_argument('-f', '--filename_suffix',
|
||||
help=('suffix of the exported file'))
|
||||
parser.description = ('Exports pre-computed APM module quality assessment '
|
||||
'results into HTML tables')
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG) # TODO(alessio): INFO once debugged.
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.add_argument('-f',
|
||||
'--filename_suffix',
|
||||
help=('suffix of the exported file'))
|
||||
parser.description = ('Exports pre-computed APM module quality assessment '
|
||||
'results into HTML tables')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug(src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
|
||||
# Export.
|
||||
output_filepath = os.path.join(args.output_dir, _BuildOutputFilename(
|
||||
args.filename_suffix))
|
||||
exporter = export.HtmlExport(output_filepath)
|
||||
exporter.Export(scores_data_frame)
|
||||
# Export.
|
||||
output_filepath = os.path.join(args.output_dir,
|
||||
_BuildOutputFilename(args.filename_suffix))
|
||||
exporter = export.HtmlExport(output_filepath)
|
||||
exporter.Export(scores_data_frame)
|
||||
|
||||
logging.info('output file successfully written in %s', output_filepath)
|
||||
sys.exit(0)
|
||||
logging.info('output file successfully written in %s', output_filepath)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Generate .json files with which the APM module can be tested using the
|
||||
apm_quality_assessment.py script and audioproc_f as APM simulator.
|
||||
"""
|
||||
@ -20,7 +19,7 @@ OUTPUT_PATH = os.path.abspath('apm_configs')
|
||||
|
||||
|
||||
def _GenerateDefaultOverridden(config_override):
|
||||
"""Generates one or more APM overriden configurations.
|
||||
"""Generates one or more APM overriden configurations.
|
||||
|
||||
For each item in config_override, it overrides the default configuration and
|
||||
writes a new APM configuration file.
|
||||
@ -45,54 +44,85 @@ def _GenerateDefaultOverridden(config_override):
|
||||
config_override: dict of APM configuration file names as keys; the values
|
||||
are dict instances encoding the audioproc_f flags.
|
||||
"""
|
||||
for config_filename in config_override:
|
||||
config = config_override[config_filename]
|
||||
config['-all_default'] = None
|
||||
for config_filename in config_override:
|
||||
config = config_override[config_filename]
|
||||
config['-all_default'] = None
|
||||
|
||||
config_filepath = os.path.join(OUTPUT_PATH, 'default-{}.json'.format(
|
||||
config_filename))
|
||||
logging.debug('config file <%s> | %s', config_filepath, config)
|
||||
config_filepath = os.path.join(
|
||||
OUTPUT_PATH, 'default-{}.json'.format(config_filename))
|
||||
logging.debug('config file <%s> | %s', config_filepath, config)
|
||||
|
||||
data_access.AudioProcConfigFile.Save(config_filepath, config)
|
||||
logging.info('config file created: <%s>', config_filepath)
|
||||
data_access.AudioProcConfigFile.Save(config_filepath, config)
|
||||
logging.info('config file created: <%s>', config_filepath)
|
||||
|
||||
|
||||
def _GenerateAllDefaultButOne():
|
||||
"""Disables the flags enabled by default one-by-one.
|
||||
"""Disables the flags enabled by default one-by-one.
|
||||
"""
|
||||
config_sets = {
|
||||
'no_AEC': {'-aec': 0,},
|
||||
'no_AGC': {'-agc': 0,},
|
||||
'no_HP_filter': {'-hpf': 0,},
|
||||
'no_level_estimator': {'-le': 0,},
|
||||
'no_noise_suppressor': {'-ns': 0,},
|
||||
'no_transient_suppressor': {'-ts': 0,},
|
||||
'no_vad': {'-vad': 0,},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
config_sets = {
|
||||
'no_AEC': {
|
||||
'-aec': 0,
|
||||
},
|
||||
'no_AGC': {
|
||||
'-agc': 0,
|
||||
},
|
||||
'no_HP_filter': {
|
||||
'-hpf': 0,
|
||||
},
|
||||
'no_level_estimator': {
|
||||
'-le': 0,
|
||||
},
|
||||
'no_noise_suppressor': {
|
||||
'-ns': 0,
|
||||
},
|
||||
'no_transient_suppressor': {
|
||||
'-ts': 0,
|
||||
},
|
||||
'no_vad': {
|
||||
'-vad': 0,
|
||||
},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
|
||||
|
||||
def _GenerateAllDefaultPlusOne():
|
||||
"""Enables the flags disabled by default one-by-one.
|
||||
"""Enables the flags disabled by default one-by-one.
|
||||
"""
|
||||
config_sets = {
|
||||
'with_AECM': {'-aec': 0, '-aecm': 1,}, # AEC and AECM are exclusive.
|
||||
'with_AGC_limiter': {'-agc_limiter': 1,},
|
||||
'with_AEC_delay_agnostic': {'-delay_agnostic': 1,},
|
||||
'with_drift_compensation': {'-drift_compensation': 1,},
|
||||
'with_residual_echo_detector': {'-ed': 1,},
|
||||
'with_AEC_extended_filter': {'-extended_filter': 1,},
|
||||
'with_LC': {'-lc': 1,},
|
||||
'with_refined_adaptive_filter': {'-refined_adaptive_filter': 1,},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
config_sets = {
|
||||
'with_AECM': {
|
||||
'-aec': 0,
|
||||
'-aecm': 1,
|
||||
}, # AEC and AECM are exclusive.
|
||||
'with_AGC_limiter': {
|
||||
'-agc_limiter': 1,
|
||||
},
|
||||
'with_AEC_delay_agnostic': {
|
||||
'-delay_agnostic': 1,
|
||||
},
|
||||
'with_drift_compensation': {
|
||||
'-drift_compensation': 1,
|
||||
},
|
||||
'with_residual_echo_detector': {
|
||||
'-ed': 1,
|
||||
},
|
||||
'with_AEC_extended_filter': {
|
||||
'-extended_filter': 1,
|
||||
},
|
||||
'with_LC': {
|
||||
'-lc': 1,
|
||||
},
|
||||
'with_refined_adaptive_filter': {
|
||||
'-refined_adaptive_filter': 1,
|
||||
},
|
||||
}
|
||||
_GenerateDefaultOverridden(config_sets)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_GenerateAllDefaultPlusOne()
|
||||
_GenerateAllDefaultButOne()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_GenerateAllDefaultPlusOne()
|
||||
_GenerateAllDefaultButOne()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Finds the APM configuration that maximizes a provided metric by
|
||||
parsing the output generated apm_quality_assessment.py.
|
||||
"""
|
||||
@ -20,33 +19,44 @@ import os
|
||||
import quality_assessment.data_access as data_access
|
||||
import quality_assessment.collect_data as collect_data
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory. Extends the arguments from 'collect_data'
|
||||
"""Arguments parser factory. Extends the arguments from 'collect_data'
|
||||
with a few extra for selecting what parameters to optimize for.
|
||||
"""
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Rudimentary optimization of a function over different parameter'
|
||||
'combinations.')
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Rudimentary optimization of a function over different parameter'
|
||||
'combinations.')
|
||||
|
||||
parser.add_argument('-n', '--config_dir', required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--config_dir',
|
||||
required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
|
||||
parser.add_argument('-p', '--params', required=True, nargs='+',
|
||||
help=('parameters to parse from the config files in'
|
||||
'config_dir'))
|
||||
parser.add_argument('-p',
|
||||
'--params',
|
||||
required=True,
|
||||
nargs='+',
|
||||
help=('parameters to parse from the config files in'
|
||||
'config_dir'))
|
||||
|
||||
parser.add_argument('-z', '--params_not_to_optimize', required=False,
|
||||
nargs='+', default=[],
|
||||
help=('parameters from `params` not to be optimized for'))
|
||||
parser.add_argument(
|
||||
'-z',
|
||||
'--params_not_to_optimize',
|
||||
required=False,
|
||||
nargs='+',
|
||||
default=[],
|
||||
help=('parameters from `params` not to be optimized for'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _ConfigurationAndScores(data_frame, params,
|
||||
params_not_to_optimize, config_dir):
|
||||
"""Returns a list of all configurations and scores.
|
||||
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
|
||||
config_dir):
|
||||
"""Returns a list of all configurations and scores.
|
||||
|
||||
Args:
|
||||
data_frame: A pandas data frame with the scores and config name
|
||||
@ -72,47 +82,47 @@ def _ConfigurationAndScores(data_frame, params,
|
||||
param combinations for params in `params_not_to_optimize` and
|
||||
their scores.
|
||||
"""
|
||||
results = collections.defaultdict(list)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
score_names = data_frame['eval_score_name'].drop_duplicates().values.tolist()
|
||||
results = collections.defaultdict(list)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
score_names = data_frame['eval_score_name'].drop_duplicates(
|
||||
).values.tolist()
|
||||
|
||||
# Normalize the scores
|
||||
normalization_constants = {}
|
||||
for score_name in score_names:
|
||||
scores = data_frame[data_frame.eval_score_name == score_name].score
|
||||
normalization_constants[score_name] = max(scores)
|
||||
|
||||
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
|
||||
param_combination = collections.namedtuple("ParamCombination",
|
||||
params_to_optimize)
|
||||
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + ".json"))
|
||||
scores = {}
|
||||
data_cell = data_frame[data_frame.apm_config == config_name]
|
||||
# Normalize the scores
|
||||
normalization_constants = {}
|
||||
for score_name in score_names:
|
||||
data_cell_scores = data_cell[data_cell.eval_score_name ==
|
||||
score_name].score
|
||||
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
|
||||
scores[score_name] /= normalization_constants[score_name]
|
||||
scores = data_frame[data_frame.eval_score_name == score_name].score
|
||||
normalization_constants[score_name] = max(scores)
|
||||
|
||||
result = {'scores': scores, 'params': {}}
|
||||
config_optimize_params = {}
|
||||
for param in params:
|
||||
if param in params_to_optimize:
|
||||
config_optimize_params[param] = config_json['-' + param]
|
||||
else:
|
||||
result['params'][param] = config_json['-' + param]
|
||||
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
|
||||
param_combination = collections.namedtuple("ParamCombination",
|
||||
params_to_optimize)
|
||||
|
||||
current_param_combination = param_combination(
|
||||
**config_optimize_params)
|
||||
results[current_param_combination].append(result)
|
||||
return results
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + ".json"))
|
||||
scores = {}
|
||||
data_cell = data_frame[data_frame.apm_config == config_name]
|
||||
for score_name in score_names:
|
||||
data_cell_scores = data_cell[data_cell.eval_score_name ==
|
||||
score_name].score
|
||||
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
|
||||
scores[score_name] /= normalization_constants[score_name]
|
||||
|
||||
result = {'scores': scores, 'params': {}}
|
||||
config_optimize_params = {}
|
||||
for param in params:
|
||||
if param in params_to_optimize:
|
||||
config_optimize_params[param] = config_json['-' + param]
|
||||
else:
|
||||
result['params'][param] = config_json['-' + param]
|
||||
|
||||
current_param_combination = param_combination(**config_optimize_params)
|
||||
results[current_param_combination].append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _FindOptimalParameter(configs_and_scores, score_weighting):
|
||||
"""Finds the config producing the maximal score.
|
||||
"""Finds the config producing the maximal score.
|
||||
|
||||
Args:
|
||||
configs_and_scores: structure of the form returned by
|
||||
@ -127,53 +137,53 @@ def _FindOptimalParameter(configs_and_scores, score_weighting):
|
||||
to its scores.
|
||||
"""
|
||||
|
||||
min_score = float('+inf')
|
||||
best_params = None
|
||||
for config in configs_and_scores:
|
||||
scores_and_params = configs_and_scores[config]
|
||||
current_score = score_weighting(scores_and_params)
|
||||
if current_score < min_score:
|
||||
min_score = current_score
|
||||
best_params = config
|
||||
logging.debug("Score: %f", current_score)
|
||||
logging.debug("Config: %s", str(config))
|
||||
return best_params
|
||||
min_score = float('+inf')
|
||||
best_params = None
|
||||
for config in configs_and_scores:
|
||||
scores_and_params = configs_and_scores[config]
|
||||
current_score = score_weighting(scores_and_params)
|
||||
if current_score < min_score:
|
||||
min_score = current_score
|
||||
best_params = config
|
||||
logging.debug("Score: %f", current_score)
|
||||
logging.debug("Config: %s", str(config))
|
||||
return best_params
|
||||
|
||||
|
||||
def _ExampleWeighting(scores_and_configs):
|
||||
"""Example argument to `_FindOptimalParameter`
|
||||
"""Example argument to `_FindOptimalParameter`
|
||||
Args:
|
||||
scores_and_configs: a list of configs and scores, in the form
|
||||
described in _FindOptimalParameter
|
||||
Returns:
|
||||
numeric value, the sum of all scores
|
||||
"""
|
||||
res = 0
|
||||
for score_config in scores_and_configs:
|
||||
res += sum(score_config['scores'].values())
|
||||
return res
|
||||
res = 0
|
||||
for score_config in scores_and_configs:
|
||||
res += sum(score_config['scores'].values())
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
# Init.
|
||||
# TODO(alessiob): INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug('Src path <%s>', src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
all_scores = _ConfigurationAndScores(scores_data_frame,
|
||||
args.params,
|
||||
args.params_not_to_optimize,
|
||||
args.config_dir)
|
||||
# Get the scores.
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
logging.debug('Src path <%s>', src_path)
|
||||
scores_data_frame = collect_data.FindScores(src_path, args)
|
||||
all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
|
||||
args.params_not_to_optimize,
|
||||
args.config_dir)
|
||||
|
||||
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
|
||||
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
|
||||
|
||||
logging.info('Optimal parameter combination: <%s>', opt_param)
|
||||
logging.info('It\'s score values: <%s>', all_scores[opt_param])
|
||||
|
||||
logging.info('Optimal parameter combination: <%s>', opt_param)
|
||||
logging.info('It\'s score values: <%s>', all_scores[opt_param])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""
|
||||
|
||||
@ -16,13 +15,14 @@ import mock
|
||||
|
||||
import apm_quality_assessment
|
||||
|
||||
|
||||
class TestSimulationScript(unittest.TestCase):
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""
|
||||
|
||||
def testMain(self):
|
||||
# Exit with error code if no arguments are passed.
|
||||
with self.assertRaises(SystemExit) as cm, mock.patch.object(
|
||||
sys, 'argv', ['apm_quality_assessment.py']):
|
||||
apm_quality_assessment.main()
|
||||
self.assertGreater(cm.exception.code, 0)
|
||||
def testMain(self):
|
||||
# Exit with error code if no arguments are passed.
|
||||
with self.assertRaises(SystemExit) as cm, mock.patch.object(
|
||||
sys, 'argv', ['apm_quality_assessment.py']):
|
||||
apm_quality_assessment.main()
|
||||
self.assertGreater(cm.exception.code, 0)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Extraction of annotations from audio files.
|
||||
"""
|
||||
|
||||
@ -19,10 +18,10 @@ import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import external_vad
|
||||
from . import exceptions
|
||||
@ -30,262 +29,268 @@ from . import signal_processing
|
||||
|
||||
|
||||
class AudioAnnotationsExtractor(object):
|
||||
"""Extracts annotations from audio files.
|
||||
"""Extracts annotations from audio files.
|
||||
"""
|
||||
|
||||
class VadType(object):
|
||||
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
|
||||
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
|
||||
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
|
||||
class VadType(object):
|
||||
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
|
||||
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
|
||||
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
|
||||
|
||||
def __init__(self, value):
|
||||
if (not isinstance(value, int)) or not 0 <= value <= 7:
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + value)
|
||||
self._value = value
|
||||
def __init__(self, value):
|
||||
if (not isinstance(value, int)) or not 0 <= value <= 7:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
value)
|
||||
self._value = value
|
||||
|
||||
def Contains(self, vad_type):
|
||||
return self._value | vad_type == self._value
|
||||
def Contains(self, vad_type):
|
||||
return self._value | vad_type == self._value
|
||||
|
||||
def __str__(self):
|
||||
vads = []
|
||||
if self.Contains(self.ENERGY_THRESHOLD):
|
||||
vads.append("energy")
|
||||
if self.Contains(self.WEBRTC_COMMON_AUDIO):
|
||||
vads.append("common_audio")
|
||||
if self.Contains(self.WEBRTC_APM):
|
||||
vads.append("apm")
|
||||
return "VadType({})".format(", ".join(vads))
|
||||
def __str__(self):
|
||||
vads = []
|
||||
if self.Contains(self.ENERGY_THRESHOLD):
|
||||
vads.append("energy")
|
||||
if self.Contains(self.WEBRTC_COMMON_AUDIO):
|
||||
vads.append("common_audio")
|
||||
if self.Contains(self.WEBRTC_APM):
|
||||
vads.append("apm")
|
||||
return "VadType({})".format(", ".join(vads))
|
||||
|
||||
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
|
||||
|
||||
# Level estimation params.
|
||||
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
|
||||
_LEVEL_FRAME_SIZE_MS = 1.0
|
||||
# The time constants in ms indicate the time it takes for the level estimate
|
||||
# to go down/up by 1 db if the signal is zero.
|
||||
_LEVEL_ATTACK_MS = 5.0
|
||||
_LEVEL_DECAY_MS = 20.0
|
||||
|
||||
# VAD params.
|
||||
_VAD_THRESHOLD = 1
|
||||
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)), os.pardir, os.pardir)
|
||||
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
|
||||
|
||||
_VAD_WEBRTC_APM_PATH = os.path.join(
|
||||
_VAD_WEBRTC_PATH, 'apm_vad')
|
||||
|
||||
def __init__(self, vad_type, external_vads=None):
|
||||
self._signal = None
|
||||
self._level = None
|
||||
self._level_frame_size = None
|
||||
self._common_audio_vad = None
|
||||
self._energy_vad = None
|
||||
self._apm_vad_probs = None
|
||||
self._apm_vad_rms = None
|
||||
self._vad_frame_size = None
|
||||
self._vad_frame_size_ms = None
|
||||
self._c_attack = None
|
||||
self._c_decay = None
|
||||
|
||||
self._vad_type = self.VadType(vad_type)
|
||||
logging.info('VADs used for annotations: ' + str(self._vad_type))
|
||||
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._external_vads = external_vads
|
||||
|
||||
assert len(self._external_vads) == len(external_vads), (
|
||||
'The external VAD names must be unique.')
|
||||
for vad in external_vads.values():
|
||||
if not isinstance(vad, external_vad.ExternalVad):
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + str(type(vad)))
|
||||
logging.info('External VAD used for annotation: ' +
|
||||
str(vad.name))
|
||||
|
||||
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH
|
||||
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
|
||||
self._VAD_WEBRTC_APM_PATH
|
||||
|
||||
@classmethod
|
||||
def GetOutputFileNameTemplate(cls):
|
||||
return cls._OUTPUT_FILENAME_TEMPLATE
|
||||
|
||||
def GetLevel(self):
|
||||
return self._level
|
||||
|
||||
def GetLevelFrameSize(self):
|
||||
return self._level_frame_size
|
||||
|
||||
@classmethod
|
||||
def GetLevelFrameSizeMs(cls):
|
||||
return cls._LEVEL_FRAME_SIZE_MS
|
||||
|
||||
def GetVadOutput(self, vad_type):
|
||||
if vad_type == self.VadType.ENERGY_THRESHOLD:
|
||||
return self._energy_vad
|
||||
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
|
||||
return self._common_audio_vad
|
||||
elif vad_type == self.VadType.WEBRTC_APM:
|
||||
return (self._apm_vad_probs, self._apm_vad_rms)
|
||||
else:
|
||||
raise exceptions.InitializationException(
|
||||
'Invalid vad type: ' + vad_type)
|
||||
|
||||
def GetVadFrameSize(self):
|
||||
return self._vad_frame_size
|
||||
|
||||
def GetVadFrameSizeMs(self):
|
||||
return self._vad_frame_size_ms
|
||||
|
||||
def Extract(self, filepath):
|
||||
# Load signal.
|
||||
self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
|
||||
if self._signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel annotations not implemented')
|
||||
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
|
||||
|
||||
# Level estimation params.
|
||||
self._level_frame_size = int(self._signal.frame_rate / 1000 * (
|
||||
self._LEVEL_FRAME_SIZE_MS))
|
||||
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION ** (
|
||||
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS))
|
||||
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION ** (
|
||||
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS))
|
||||
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
|
||||
_LEVEL_FRAME_SIZE_MS = 1.0
|
||||
# The time constants in ms indicate the time it takes for the level estimate
|
||||
# to go down/up by 1 db if the signal is zero.
|
||||
_LEVEL_ATTACK_MS = 5.0
|
||||
_LEVEL_DECAY_MS = 20.0
|
||||
|
||||
# Compute level.
|
||||
self._LevelEstimation()
|
||||
# VAD params.
|
||||
_VAD_THRESHOLD = 1
|
||||
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
os.pardir, os.pardir)
|
||||
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
|
||||
|
||||
# Ideal VAD output, it requires clean speech with high SNR as input.
|
||||
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
|
||||
# Naive VAD based on level thresholding.
|
||||
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
|
||||
self._energy_vad = np.uint8(self._level > vad_threshold)
|
||||
self._vad_frame_size = self._level_frame_size
|
||||
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
|
||||
# WebRTC common_audio/ VAD.
|
||||
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
|
||||
# WebRTC modules/audio_processing/ VAD.
|
||||
self._RunWebRtcApmVad(filepath)
|
||||
for extvad_name in self._external_vads:
|
||||
self._external_vads[extvad_name].Run(filepath)
|
||||
_VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
|
||||
|
||||
def Save(self, output_path, annotation_name=""):
|
||||
ext_kwargs = {'extvad_conf-' + ext_vad:
|
||||
self._external_vads[ext_vad].GetVadOutput()
|
||||
for ext_vad in self._external_vads}
|
||||
np.savez_compressed(
|
||||
file=os.path.join(
|
||||
def __init__(self, vad_type, external_vads=None):
|
||||
self._signal = None
|
||||
self._level = None
|
||||
self._level_frame_size = None
|
||||
self._common_audio_vad = None
|
||||
self._energy_vad = None
|
||||
self._apm_vad_probs = None
|
||||
self._apm_vad_rms = None
|
||||
self._vad_frame_size = None
|
||||
self._vad_frame_size_ms = None
|
||||
self._c_attack = None
|
||||
self._c_decay = None
|
||||
|
||||
self._vad_type = self.VadType(vad_type)
|
||||
logging.info('VADs used for annotations: ' + str(self._vad_type))
|
||||
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._external_vads = external_vads
|
||||
|
||||
assert len(self._external_vads) == len(external_vads), (
|
||||
'The external VAD names must be unique.')
|
||||
for vad in external_vads.values():
|
||||
if not isinstance(vad, external_vad.ExternalVad):
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
str(type(vad)))
|
||||
logging.info('External VAD used for annotation: ' + str(vad.name))
|
||||
|
||||
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH
|
||||
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
|
||||
self._VAD_WEBRTC_APM_PATH
|
||||
|
||||
@classmethod
|
||||
def GetOutputFileNameTemplate(cls):
|
||||
return cls._OUTPUT_FILENAME_TEMPLATE
|
||||
|
||||
def GetLevel(self):
|
||||
return self._level
|
||||
|
||||
def GetLevelFrameSize(self):
|
||||
return self._level_frame_size
|
||||
|
||||
@classmethod
|
||||
def GetLevelFrameSizeMs(cls):
|
||||
return cls._LEVEL_FRAME_SIZE_MS
|
||||
|
||||
def GetVadOutput(self, vad_type):
|
||||
if vad_type == self.VadType.ENERGY_THRESHOLD:
|
||||
return self._energy_vad
|
||||
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
|
||||
return self._common_audio_vad
|
||||
elif vad_type == self.VadType.WEBRTC_APM:
|
||||
return (self._apm_vad_probs, self._apm_vad_rms)
|
||||
else:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
vad_type)
|
||||
|
||||
def GetVadFrameSize(self):
|
||||
return self._vad_frame_size
|
||||
|
||||
def GetVadFrameSizeMs(self):
|
||||
return self._vad_frame_size_ms
|
||||
|
||||
def Extract(self, filepath):
|
||||
# Load signal.
|
||||
self._signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
filepath)
|
||||
if self._signal.channels != 1:
|
||||
raise NotImplementedError(
|
||||
'Multiple-channel annotations not implemented')
|
||||
|
||||
# Level estimation params.
|
||||
self._level_frame_size = int(self._signal.frame_rate / 1000 *
|
||||
(self._LEVEL_FRAME_SIZE_MS))
|
||||
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_ATTACK_MS))
|
||||
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_DECAY_MS))
|
||||
|
||||
# Compute level.
|
||||
self._LevelEstimation()
|
||||
|
||||
# Ideal VAD output, it requires clean speech with high SNR as input.
|
||||
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
|
||||
# Naive VAD based on level thresholding.
|
||||
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
|
||||
self._energy_vad = np.uint8(self._level > vad_threshold)
|
||||
self._vad_frame_size = self._level_frame_size
|
||||
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
|
||||
# WebRTC common_audio/ VAD.
|
||||
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
|
||||
# WebRTC modules/audio_processing/ VAD.
|
||||
self._RunWebRtcApmVad(filepath)
|
||||
for extvad_name in self._external_vads:
|
||||
self._external_vads[extvad_name].Run(filepath)
|
||||
|
||||
def Save(self, output_path, annotation_name=""):
|
||||
ext_kwargs = {
|
||||
'extvad_conf-' + ext_vad:
|
||||
self._external_vads[ext_vad].GetVadOutput()
|
||||
for ext_vad in self._external_vads
|
||||
}
|
||||
np.savez_compressed(file=os.path.join(
|
||||
output_path,
|
||||
self.GetOutputFileNameTemplate().format(annotation_name)),
|
||||
level=self._level,
|
||||
level_frame_size=self._level_frame_size,
|
||||
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
|
||||
vad_output=self._common_audio_vad,
|
||||
vad_energy_output=self._energy_vad,
|
||||
vad_frame_size=self._vad_frame_size,
|
||||
vad_frame_size_ms=self._vad_frame_size_ms,
|
||||
vad_probs=self._apm_vad_probs,
|
||||
vad_rms=self._apm_vad_rms,
|
||||
**ext_kwargs
|
||||
)
|
||||
level=self._level,
|
||||
level_frame_size=self._level_frame_size,
|
||||
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
|
||||
vad_output=self._common_audio_vad,
|
||||
vad_energy_output=self._energy_vad,
|
||||
vad_frame_size=self._vad_frame_size,
|
||||
vad_frame_size_ms=self._vad_frame_size_ms,
|
||||
vad_probs=self._apm_vad_probs,
|
||||
vad_rms=self._apm_vad_rms,
|
||||
**ext_kwargs)
|
||||
|
||||
def _LevelEstimation(self):
|
||||
# Read samples.
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._signal).astype(np.float32) / 32768.0
|
||||
num_frames = len(samples) // self._level_frame_size
|
||||
num_samples = num_frames * self._level_frame_size
|
||||
def _LevelEstimation(self):
|
||||
# Read samples.
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._signal).astype(np.float32) / 32768.0
|
||||
num_frames = len(samples) // self._level_frame_size
|
||||
num_samples = num_frames * self._level_frame_size
|
||||
|
||||
# Envelope.
|
||||
self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
|
||||
num_frames, self._level_frame_size)), axis=1)
|
||||
assert len(self._level) == num_frames
|
||||
# Envelope.
|
||||
self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
|
||||
(num_frames, self._level_frame_size)),
|
||||
axis=1)
|
||||
assert len(self._level) == num_frames
|
||||
|
||||
# Envelope smoothing.
|
||||
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
|
||||
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
|
||||
for i in range(1, num_frames):
|
||||
self._level[i] = smooth(
|
||||
self._level[i], self._level[i - 1], self._c_attack if (
|
||||
self._level[i] > self._level[i - 1]) else self._c_decay)
|
||||
# Envelope smoothing.
|
||||
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
|
||||
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
|
||||
for i in range(1, num_frames):
|
||||
self._level[i] = smooth(
|
||||
self._level[i], self._level[i - 1], self._c_attack if
|
||||
(self._level[i] > self._level[i - 1]) else self._c_decay)
|
||||
|
||||
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
|
||||
self._common_audio_vad = None
|
||||
self._vad_frame_size = None
|
||||
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
|
||||
self._common_audio_vad = None
|
||||
self._vad_frame_size = None
|
||||
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path = os.path.join(
|
||||
tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad.tmp')
|
||||
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH,
|
||||
'-i', wav_file_path,
|
||||
'-o', output_file_path
|
||||
], cwd=self._VAD_WEBRTC_PATH)
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
],
|
||||
cwd=self._VAD_WEBRTC_PATH)
|
||||
|
||||
# Read bytes.
|
||||
with open(output_file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
# Read bytes.
|
||||
with open(output_file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
|
||||
# Parse side information.
|
||||
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
|
||||
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
|
||||
assert self._vad_frame_size_ms in [10, 20, 30]
|
||||
extra_bits = struct.unpack('B', raw_data[-1])[0]
|
||||
assert 0 <= extra_bits <= 8
|
||||
# Parse side information.
|
||||
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
|
||||
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
|
||||
assert self._vad_frame_size_ms in [10, 20, 30]
|
||||
extra_bits = struct.unpack('B', raw_data[-1])[0]
|
||||
assert 0 <= extra_bits <= 8
|
||||
|
||||
# Init VAD vector.
|
||||
num_bytes = len(raw_data)
|
||||
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
|
||||
self._common_audio_vad = np.zeros(num_frames, np.uint8)
|
||||
# Init VAD vector.
|
||||
num_bytes = len(raw_data)
|
||||
num_frames = 8 * (num_bytes -
|
||||
2) - extra_bits # 8 frames for each byte.
|
||||
self._common_audio_vad = np.zeros(num_frames, np.uint8)
|
||||
|
||||
# Read VAD decisions.
|
||||
for i, byte in enumerate(raw_data[1:-1]):
|
||||
byte = struct.unpack('B', byte)[0]
|
||||
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
|
||||
self._common_audio_vad[i * 8 + j] = int(byte & 1)
|
||||
byte = byte >> 1
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
# Read VAD decisions.
|
||||
for i, byte in enumerate(raw_data[1:-1]):
|
||||
byte = struct.unpack('B', byte)[0]
|
||||
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
|
||||
self._common_audio_vad[i * 8 + j] = int(byte & 1)
|
||||
byte = byte >> 1
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC VAD (' + e.message +
|
||||
')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
def _RunWebRtcApmVad(self, wav_file_path):
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path_probs = os.path.join(
|
||||
tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
|
||||
output_file_path_rms = os.path.join(
|
||||
tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
|
||||
def _RunWebRtcApmVad(self, wav_file_path):
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path_probs = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
|
||||
output_file_path_rms = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
|
||||
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_APM_PATH,
|
||||
'-i', wav_file_path,
|
||||
'-o_probs', output_file_path_probs,
|
||||
'-o_rms', output_file_path_rms
|
||||
], cwd=self._VAD_WEBRTC_PATH)
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
|
||||
output_file_path_probs, '-o_rms', output_file_path_rms
|
||||
],
|
||||
cwd=self._VAD_WEBRTC_PATH)
|
||||
|
||||
# Parse annotations.
|
||||
self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
|
||||
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
|
||||
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
|
||||
# Parse annotations.
|
||||
self._apm_vad_probs = np.fromfile(output_file_path_probs,
|
||||
np.double)
|
||||
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
|
||||
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
|
||||
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC APM VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC APM VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the annotations module.
|
||||
"""
|
||||
|
||||
@ -25,133 +24,137 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestAnnotationsExtraction(unittest.TestCase):
|
||||
"""Unit tests for the annotations module.
|
||||
"""Unit tests for the annotations module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_DEBUG_PLOT_VAD = False
|
||||
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
|
||||
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
|
||||
_VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
|
||||
_VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_DEBUG_PLOT_VAD = False
|
||||
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
|
||||
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
|
||||
| _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
|
||||
| _VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
|
||||
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
|
||||
'pure_tone', [440, 1000])
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._wav_file_path, pure_tone)
|
||||
self._sample_rate = pure_tone.frame_rate
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
|
||||
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
|
||||
'pure_tone', [440, 1000])
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._wav_file_path, pure_tone)
|
||||
self._sample_rate = pure_tone.frame_rate
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' + (
|
||||
self._tmp_path))
|
||||
def testFrameSizes(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
samples_to_ms = lambda n, sr: 1000 * n // sr
|
||||
self.assertEqual(
|
||||
samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
|
||||
e.GetLevelFrameSizeMs())
|
||||
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
|
||||
e.GetVadFrameSizeMs())
|
||||
|
||||
def testFrameSizes(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
samples_to_ms = lambda n, sr: 1000 * n // sr
|
||||
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
|
||||
e.GetLevelFrameSizeMs())
|
||||
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
|
||||
e.GetVadFrameSizeMs())
|
||||
def testVoiceActivityDetectors(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
|
||||
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
|
||||
e.Extract(self._wav_file_path)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
def testVoiceActivityDetectors(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
|
||||
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
|
||||
e.Extract(self._wav_file_path)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
|
||||
0.95)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
|
||||
0.95)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
(vad_probs,
|
||||
vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
self.assertGreater(len(vad_probs), 0)
|
||||
self.assertGreater(len(vad_rms), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_probs)) / len(vad_probs), 0.5)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_rms)) / len(vad_rms), 20000)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
(vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
self.assertGreater(len(vad_probs), 0)
|
||||
self.assertGreater(len(vad_rms), 0)
|
||||
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
|
||||
0.5)
|
||||
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
|
||||
if self._DEBUG_PLOT_VAD:
|
||||
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
|
||||
num_frames).astype(np.float32) * frame_size_ms / 1000.0
|
||||
level = e.GetLevel()
|
||||
t_level = frame_times_s(num_frames=len(level),
|
||||
frame_size_ms=e.GetLevelFrameSizeMs())
|
||||
t_vad = frame_times_s(num_frames=len(vad_output),
|
||||
frame_size_ms=e.GetVadFrameSizeMs())
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hold(True)
|
||||
plt.plot(t_level, level)
|
||||
plt.plot(t_vad, vad_output * np.max(level), '.')
|
||||
plt.show()
|
||||
|
||||
if self._DEBUG_PLOT_VAD:
|
||||
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
|
||||
num_frames).astype(np.float32) * frame_size_ms / 1000.0
|
||||
level = e.GetLevel()
|
||||
t_level = frame_times_s(
|
||||
num_frames=len(level),
|
||||
frame_size_ms=e.GetLevelFrameSizeMs())
|
||||
t_vad = frame_times_s(
|
||||
num_frames=len(vad_output),
|
||||
frame_size_ms=e.GetVadFrameSizeMs())
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hold(True)
|
||||
plt.plot(t_level, level)
|
||||
plt.plot(t_vad, vad_output * np.max(level), '.')
|
||||
plt.show()
|
||||
def testSaveLoad(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, "fake-annotation")
|
||||
|
||||
def testSaveLoad(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, "fake-annotation")
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
np.testing.assert_array_equal(e.GetLevel(), data['level'])
|
||||
self.assertEqual(np.float32, data['level'].dtype)
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
|
||||
data['vad_energy_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
|
||||
data['vad_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
|
||||
data['vad_probs'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
|
||||
data['vad_rms'])
|
||||
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_probs'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_rms'].dtype)
|
||||
|
||||
data = np.load(os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
np.testing.assert_array_equal(e.GetLevel(), data['level'])
|
||||
self.assertEqual(np.float32, data['level'].dtype)
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
|
||||
data['vad_energy_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
|
||||
data['vad_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
|
||||
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_probs'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_rms'].dtype)
|
||||
def testEmptyExternalShouldNotCrash(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
annotations.AudioAnnotationsExtractor(vad_type_value, {})
|
||||
|
||||
def testEmptyExternalShouldNotCrash(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
annotations.AudioAnnotationsExtractor(vad_type_value, {})
|
||||
def testFakeExternalSaveLoad(self):
|
||||
def FakeExternalFactory():
|
||||
return external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
|
||||
def testFakeExternalSaveLoad(self):
|
||||
def FakeExternalFactory():
|
||||
return external_vad.ExternalVad(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
|
||||
'fake'
|
||||
)
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
|
||||
e = annotations.AudioAnnotationsExtractor(
|
||||
vad_type_value,
|
||||
{'fake': FakeExternalFactory()})
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, annotation_name="fake-annotation")
|
||||
data = np.load(os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
|
||||
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
|
||||
data['extvad_conf-fake'])
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
e = annotations.AudioAnnotationsExtractor(
|
||||
vad_type_value, {'fake': FakeExternalFactory()})
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, annotation_name="fake-annotation")
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
|
||||
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
|
||||
data['extvad_conf-fake'])
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Class implementing a wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
@ -19,33 +18,36 @@ from . import exceptions
|
||||
|
||||
|
||||
class AudioProcWrapper(object):
|
||||
"""Wrapper for APM simulators.
|
||||
"""Wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(os.path.join(
|
||||
os.pardir, 'audioproc_f'))
|
||||
OUTPUT_FILENAME = 'output.wav'
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
|
||||
os.path.join(os.pardir, 'audioproc_f'))
|
||||
OUTPUT_FILENAME = 'output.wav'
|
||||
|
||||
def __init__(self, simulator_bin_path):
|
||||
"""Ctor.
|
||||
def __init__(self, simulator_bin_path):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
simulator_bin_path: path to the APM simulator binary.
|
||||
"""
|
||||
self._simulator_bin_path = simulator_bin_path
|
||||
self._config = None
|
||||
self._output_signal_filepath = None
|
||||
self._simulator_bin_path = simulator_bin_path
|
||||
self._config = None
|
||||
self._output_signal_filepath = None
|
||||
|
||||
# Profiler instance to measure running time.
|
||||
self._profiler = cProfile.Profile()
|
||||
# Profiler instance to measure running time.
|
||||
self._profiler = cProfile.Profile()
|
||||
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_signal_filepath
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_signal_filepath
|
||||
|
||||
def Run(self, config_filepath, capture_input_filepath, output_path,
|
||||
render_input_filepath=None):
|
||||
"""Runs APM simulator.
|
||||
def Run(self,
|
||||
config_filepath,
|
||||
capture_input_filepath,
|
||||
output_path,
|
||||
render_input_filepath=None):
|
||||
"""Runs APM simulator.
|
||||
|
||||
Args:
|
||||
config_filepath: path to the configuration file specifying the arguments
|
||||
@ -56,41 +58,43 @@ class AudioProcWrapper(object):
|
||||
render_input_filepath: path to the render audio track input file (aka
|
||||
reverse or far-end).
|
||||
"""
|
||||
# Init.
|
||||
self._output_signal_filepath = os.path.join(
|
||||
output_path, self.OUTPUT_FILENAME)
|
||||
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
|
||||
# Init.
|
||||
self._output_signal_filepath = os.path.join(output_path,
|
||||
self.OUTPUT_FILENAME)
|
||||
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
|
||||
|
||||
# Skip if the output has already been generated.
|
||||
if os.path.exists(self._output_signal_filepath) and os.path.exists(
|
||||
profiling_stats_filepath):
|
||||
return
|
||||
# Skip if the output has already been generated.
|
||||
if os.path.exists(self._output_signal_filepath) and os.path.exists(
|
||||
profiling_stats_filepath):
|
||||
return
|
||||
|
||||
# Load configuration.
|
||||
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
|
||||
# Load configuration.
|
||||
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
|
||||
|
||||
# Set remaining parameters.
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
raise exceptions.FileNotFoundError('cannot find capture input file')
|
||||
self._config['-i'] = capture_input_filepath
|
||||
self._config['-o'] = self._output_signal_filepath
|
||||
if render_input_filepath is not None:
|
||||
if not os.path.exists(render_input_filepath):
|
||||
raise exceptions.FileNotFoundError('cannot find render input file')
|
||||
self._config['-ri'] = render_input_filepath
|
||||
# Set remaining parameters.
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find capture input file')
|
||||
self._config['-i'] = capture_input_filepath
|
||||
self._config['-o'] = self._output_signal_filepath
|
||||
if render_input_filepath is not None:
|
||||
if not os.path.exists(render_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find render input file')
|
||||
self._config['-ri'] = render_input_filepath
|
||||
|
||||
# Build arguments list.
|
||||
args = [self._simulator_bin_path]
|
||||
for param_name in self._config:
|
||||
args.append(param_name)
|
||||
if self._config[param_name] is not None:
|
||||
args.append(str(self._config[param_name]))
|
||||
logging.debug(' '.join(args))
|
||||
# Build arguments list.
|
||||
args = [self._simulator_bin_path]
|
||||
for param_name in self._config:
|
||||
args.append(param_name)
|
||||
if self._config[param_name] is not None:
|
||||
args.append(str(self._config[param_name]))
|
||||
logging.debug(' '.join(args))
|
||||
|
||||
# Run.
|
||||
self._profiler.enable()
|
||||
subprocess.call(args)
|
||||
self._profiler.disable()
|
||||
# Run.
|
||||
self._profiler.enable()
|
||||
subprocess.call(args)
|
||||
self._profiler.disable()
|
||||
|
||||
# Save profiling stats.
|
||||
self._profiler.dump_stats(profiling_stats_filepath)
|
||||
# Save profiling stats.
|
||||
self._profiler.dump_stats(profiling_stats_filepath)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Imports a filtered subset of the scores and configurations computed
|
||||
by apm_quality_assessment.py into a pandas data frame.
|
||||
"""
|
||||
@ -18,71 +17,88 @@ import re
|
||||
import sys
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package pandas')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package pandas')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access as data_access
|
||||
from . import simulation as sim
|
||||
|
||||
# Compiled regular expressions used to extract score descriptors.
|
||||
RE_CONFIG_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + r'(.+)')
|
||||
RE_CAPTURE_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + r'(.+)')
|
||||
RE_RENDER_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
|
||||
RE_ECHO_SIM_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + r'(.+)')
|
||||
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
|
||||
r'(.+)')
|
||||
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
|
||||
r'(.+)')
|
||||
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
|
||||
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
|
||||
r'(.+)')
|
||||
RE_TEST_DATA_GEN_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
|
||||
RE_TEST_DATA_GEN_PARAMS = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
|
||||
RE_SCORE_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + r'(.+)(\..+)')
|
||||
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
|
||||
r'(.+)(\..+)')
|
||||
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Override this description in a user script by changing'
|
||||
' `parser.description` of the returned parser.'))
|
||||
parser = argparse.ArgumentParser(
|
||||
description=('Override this description in a user script by changing'
|
||||
' `parser.description` of the returned parser.'))
|
||||
|
||||
parser.add_argument('-o', '--output_dir', required=True,
|
||||
help=('the same base path used with the '
|
||||
'apm_quality_assessment tool'))
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=True,
|
||||
help=('the same base path used with the '
|
||||
'apm_quality_assessment tool'))
|
||||
|
||||
parser.add_argument('-c', '--config_names', type=re.compile,
|
||||
help=('regular expression to filter the APM configuration'
|
||||
' names'))
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the APM configuration'
|
||||
' names'))
|
||||
|
||||
parser.add_argument('-i', '--capture_names', type=re.compile,
|
||||
help=('regular expression to filter the capture signal '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the capture signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-r', '--render_names', type=re.compile,
|
||||
help=('regular expression to filter the render signal '
|
||||
'names'))
|
||||
parser.add_argument('-r',
|
||||
'--render_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the render signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-e', '--echo_simulator_names', type=re.compile,
|
||||
help=('regular expression to filter the echo simulator '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-e',
|
||||
'--echo_simulator_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the echo simulator '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-t', '--test_data_generators', type=re.compile,
|
||||
help=('regular expression to filter the test data '
|
||||
'generator names'))
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the test data '
|
||||
'generator names'))
|
||||
|
||||
parser.add_argument('-s', '--eval_scores', type=re.compile,
|
||||
help=('regular expression to filter the evaluation score '
|
||||
'names'))
|
||||
parser.add_argument(
|
||||
'-s',
|
||||
'--eval_scores',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the evaluation score '
|
||||
'names'))
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def _GetScoreDescriptors(score_filepath):
|
||||
"""Extracts a score descriptor from the given score file path.
|
||||
"""Extracts a score descriptor from the given score file path.
|
||||
|
||||
Args:
|
||||
score_filepath: path to the score file.
|
||||
@ -92,23 +108,23 @@ def _GetScoreDescriptors(score_filepath):
|
||||
render audio track name, echo simulator name, test data generator name,
|
||||
test data generator parameters as string, evaluation score name).
|
||||
"""
|
||||
fields = score_filepath.split(os.sep)[-7:]
|
||||
extract_name = lambda index, reg_expr: (
|
||||
reg_expr.match(fields[index]).groups(0)[0])
|
||||
return (
|
||||
extract_name(0, RE_CONFIG_NAME),
|
||||
extract_name(1, RE_CAPTURE_NAME),
|
||||
extract_name(2, RE_RENDER_NAME),
|
||||
extract_name(3, RE_ECHO_SIM_NAME),
|
||||
extract_name(4, RE_TEST_DATA_GEN_NAME),
|
||||
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
|
||||
extract_name(6, RE_SCORE_NAME),
|
||||
)
|
||||
fields = score_filepath.split(os.sep)[-7:]
|
||||
extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]).
|
||||
groups(0)[0])
|
||||
return (
|
||||
extract_name(0, RE_CONFIG_NAME),
|
||||
extract_name(1, RE_CAPTURE_NAME),
|
||||
extract_name(2, RE_RENDER_NAME),
|
||||
extract_name(3, RE_ECHO_SIM_NAME),
|
||||
extract_name(4, RE_TEST_DATA_GEN_NAME),
|
||||
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
|
||||
extract_name(6, RE_SCORE_NAME),
|
||||
)
|
||||
|
||||
|
||||
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name, args):
|
||||
"""Decides whether excluding a score.
|
||||
"""Decides whether excluding a score.
|
||||
|
||||
A set of optional regular expressions in args is used to determine if the
|
||||
score should be excluded (depending on its |*_name| descriptors).
|
||||
@ -125,27 +141,27 @@ def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
|
||||
Returns:
|
||||
A boolean.
|
||||
"""
|
||||
value_regexpr_pairs = [
|
||||
(config_name, args.config_names),
|
||||
(capture_name, args.capture_names),
|
||||
(render_name, args.render_names),
|
||||
(echo_simulator_name, args.echo_simulator_names),
|
||||
(test_data_gen_name, args.test_data_generators),
|
||||
(score_name, args.eval_scores),
|
||||
]
|
||||
value_regexpr_pairs = [
|
||||
(config_name, args.config_names),
|
||||
(capture_name, args.capture_names),
|
||||
(render_name, args.render_names),
|
||||
(echo_simulator_name, args.echo_simulator_names),
|
||||
(test_data_gen_name, args.test_data_generators),
|
||||
(score_name, args.eval_scores),
|
||||
]
|
||||
|
||||
# Score accepted if each value matches the corresponding regular expression.
|
||||
for value, regexpr in value_regexpr_pairs:
|
||||
if regexpr is None:
|
||||
continue
|
||||
if not regexpr.match(value):
|
||||
return True
|
||||
# Score accepted if each value matches the corresponding regular expression.
|
||||
for value, regexpr in value_regexpr_pairs:
|
||||
if regexpr is None:
|
||||
continue
|
||||
if not regexpr.match(value):
|
||||
return True
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def FindScores(src_path, args):
|
||||
"""Given a search path, find scores and return a DataFrame object.
|
||||
"""Given a search path, find scores and return a DataFrame object.
|
||||
|
||||
Args:
|
||||
src_path: Search path pattern.
|
||||
@ -154,89 +170,74 @@ def FindScores(src_path, args):
|
||||
Returns:
|
||||
A DataFrame object.
|
||||
"""
|
||||
# Get scores.
|
||||
scores = []
|
||||
for score_filepath in glob.iglob(src_path):
|
||||
# Extract score descriptor fields from the path.
|
||||
(config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name) = _GetScoreDescriptors(score_filepath)
|
||||
# Get scores.
|
||||
scores = []
|
||||
for score_filepath in glob.iglob(src_path):
|
||||
# Extract score descriptor fields from the path.
|
||||
(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, test_data_gen_params,
|
||||
score_name) = _GetScoreDescriptors(score_filepath)
|
||||
|
||||
# Ignore the score if required.
|
||||
if _ExcludeScore(
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
score_name,
|
||||
args):
|
||||
logging.info(
|
||||
'ignored score: %s %s %s %s %s %s',
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
score_name)
|
||||
continue
|
||||
# Ignore the score if required.
|
||||
if _ExcludeScore(config_name, capture_name, render_name,
|
||||
echo_simulator_name, test_data_gen_name, score_name,
|
||||
args):
|
||||
logging.info('ignored score: %s %s %s %s %s %s', config_name,
|
||||
capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name)
|
||||
continue
|
||||
|
||||
# Read metadata and score.
|
||||
metadata = data_access.Metadata.LoadAudioTestDataPaths(
|
||||
os.path.split(score_filepath)[0])
|
||||
score = data_access.ScoreFile.Load(score_filepath)
|
||||
# Read metadata and score.
|
||||
metadata = data_access.Metadata.LoadAudioTestDataPaths(
|
||||
os.path.split(score_filepath)[0])
|
||||
score = data_access.ScoreFile.Load(score_filepath)
|
||||
|
||||
# Add a score with its descriptor fields.
|
||||
scores.append((
|
||||
metadata['clean_capture_input_filepath'],
|
||||
metadata['echo_free_capture_filepath'],
|
||||
metadata['echo_filepath'],
|
||||
metadata['render_filepath'],
|
||||
metadata['capture_filepath'],
|
||||
metadata['apm_output_filepath'],
|
||||
metadata['apm_reference_filepath'],
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name,
|
||||
score,
|
||||
))
|
||||
# Add a score with its descriptor fields.
|
||||
scores.append((
|
||||
metadata['clean_capture_input_filepath'],
|
||||
metadata['echo_free_capture_filepath'],
|
||||
metadata['echo_filepath'],
|
||||
metadata['render_filepath'],
|
||||
metadata['capture_filepath'],
|
||||
metadata['apm_output_filepath'],
|
||||
metadata['apm_reference_filepath'],
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name,
|
||||
score,
|
||||
))
|
||||
|
||||
return pd.DataFrame(
|
||||
data=scores,
|
||||
columns=(
|
||||
'clean_capture_input_filepath',
|
||||
'echo_free_capture_filepath',
|
||||
'echo_filepath',
|
||||
'render_filepath',
|
||||
'capture_filepath',
|
||||
'apm_output_filepath',
|
||||
'apm_reference_filepath',
|
||||
'apm_config',
|
||||
'capture',
|
||||
'render',
|
||||
'echo_simulator',
|
||||
'test_data_gen',
|
||||
'test_data_gen_params',
|
||||
'eval_score_name',
|
||||
'score',
|
||||
))
|
||||
return pd.DataFrame(data=scores,
|
||||
columns=(
|
||||
'clean_capture_input_filepath',
|
||||
'echo_free_capture_filepath',
|
||||
'echo_filepath',
|
||||
'render_filepath',
|
||||
'capture_filepath',
|
||||
'apm_output_filepath',
|
||||
'apm_reference_filepath',
|
||||
'apm_config',
|
||||
'capture',
|
||||
'render',
|
||||
'echo_simulator',
|
||||
'test_data_gen',
|
||||
'test_data_gen_params',
|
||||
'eval_score_name',
|
||||
'score',
|
||||
))
|
||||
|
||||
|
||||
def ConstructSrcPath(args):
|
||||
return os.path.join(
|
||||
args.output_dir,
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + '*')
|
||||
return os.path.join(
|
||||
args.output_dir,
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + '*')
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Data access utility functions and classes.
|
||||
"""
|
||||
|
||||
@ -14,29 +13,29 @@ import os
|
||||
|
||||
|
||||
def MakeDirectory(path):
|
||||
"""Makes a directory recursively without rising exceptions if existing.
|
||||
"""Makes a directory recursively without rising exceptions if existing.
|
||||
|
||||
Args:
|
||||
path: path to the directory to be created.
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
return
|
||||
os.makedirs(path)
|
||||
if os.path.exists(path):
|
||||
return
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
"""Data access class to save and load metadata.
|
||||
"""Data access class to save and load metadata.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||
|
||||
@classmethod
|
||||
def LoadFileMetadata(cls, filepath):
|
||||
"""Loads generic metadata linked to a file.
|
||||
@classmethod
|
||||
def LoadFileMetadata(cls, filepath):
|
||||
"""Loads generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to read.
|
||||
@ -44,23 +43,23 @@ class Metadata(object):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
|
||||
return json.load(f)
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveFileMetadata(cls, filepath, metadata):
|
||||
"""Saves generic metadata linked to a file.
|
||||
@classmethod
|
||||
def SaveFileMetadata(cls, filepath, metadata):
|
||||
"""Saves generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to write.
|
||||
metadata: a dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
@classmethod
|
||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||
"""Loads the input and the reference audio track paths.
|
||||
@classmethod
|
||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||
"""Loads the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
metadata_path: path to the directory containing the metadata file.
|
||||
@ -68,14 +67,14 @@ class Metadata(object):
|
||||
Returns:
|
||||
Tuple with the paths to the input and output audio tracks.
|
||||
"""
|
||||
metadata_filepath = os.path.join(
|
||||
metadata_path, cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(metadata_filepath) as f:
|
||||
return json.load(f)
|
||||
metadata_filepath = os.path.join(metadata_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(metadata_filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
|
||||
"""Saves the input and the reference audio track paths.
|
||||
@classmethod
|
||||
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
|
||||
"""Saves the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory containing the metadata file.
|
||||
@ -83,23 +82,24 @@ class Metadata(object):
|
||||
Keyword Args:
|
||||
filepaths: collection of audio track file paths to save.
|
||||
"""
|
||||
output_filepath = os.path.join(output_path, cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(output_filepath, 'w') as f:
|
||||
json.dump(filepaths, f)
|
||||
output_filepath = os.path.join(output_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(output_filepath, 'w') as f:
|
||||
json.dump(filepaths, f)
|
||||
|
||||
|
||||
class AudioProcConfigFile(object):
|
||||
"""Data access to load/save APM simulator argument lists.
|
||||
"""Data access to load/save APM simulator argument lists.
|
||||
|
||||
The arguments stored in the config files are used to control the APM flags.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a configuration file for an APM simulator.
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
@ -107,31 +107,31 @@ class AudioProcConfigFile(object):
|
||||
Returns:
|
||||
A dict containing the configuration.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return json.load(f)
|
||||
with open(filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, config):
|
||||
"""Saves a configuration file for an APM simulator.
|
||||
@classmethod
|
||||
def Save(cls, filepath, config):
|
||||
"""Saves a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
config: a dict containing the configuration.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(config, f)
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
class ScoreFile(object):
|
||||
"""Data access class to save and load float scalar scores.
|
||||
"""Data access class to save and load float scalar scores.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a score from file.
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a score from file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
@ -139,16 +139,16 @@ class ScoreFile(object):
|
||||
Returns:
|
||||
A float encoding the score.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return float(f.readline().strip())
|
||||
with open(filepath) as f:
|
||||
return float(f.readline().strip())
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, score):
|
||||
"""Saves a score into a file.
|
||||
@classmethod
|
||||
def Save(cls, filepath, score):
|
||||
"""Saves a score into a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
score: float encoding the score.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('{0:f}\n'.format(score))
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('{0:f}\n'.format(score))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Echo path simulation module.
|
||||
"""
|
||||
|
||||
@ -16,21 +15,21 @@ from . import signal_processing
|
||||
|
||||
|
||||
class EchoPathSimulator(object):
|
||||
"""Abstract class for the echo path simulators.
|
||||
"""Abstract class for the echo path simulators.
|
||||
|
||||
In general, an echo path simulator is a function of the render signal and
|
||||
simulates the propagation of the latter into the microphone (e.g., due to
|
||||
mechanical or electrical paths).
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Creates the echo signal and stores it in an audio file (abstract method).
|
||||
def Simulate(self, output_path):
|
||||
"""Creates the echo signal and stores it in an audio file (abstract method).
|
||||
|
||||
Args:
|
||||
output_path: Path in which any output can be saved.
|
||||
@ -38,11 +37,11 @@ class EchoPathSimulator(object):
|
||||
Returns:
|
||||
Path to the generated audio track file or None if no echo is present.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EchoPathSimulator implementation.
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EchoPathSimulator implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend
|
||||
EchoPathSimulator.
|
||||
@ -52,85 +51,86 @@ class EchoPathSimulator(object):
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates absence of echo."""
|
||||
"""Simulates absence of echo."""
|
||||
|
||||
NAME = 'noecho'
|
||||
NAME = 'noecho'
|
||||
|
||||
def __init__(self):
|
||||
EchoPathSimulator.__init__(self)
|
||||
def __init__(self):
|
||||
EchoPathSimulator.__init__(self)
|
||||
|
||||
def Simulate(self, output_path):
|
||||
return None
|
||||
def Simulate(self, output_path):
|
||||
return None
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class LinearEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates linear echo path.
|
||||
"""Simulates linear echo path.
|
||||
|
||||
This class applies a given impulse response to the render input and then it
|
||||
sums the signal to the capture input signal.
|
||||
"""
|
||||
|
||||
NAME = 'linear'
|
||||
NAME = 'linear'
|
||||
|
||||
def __init__(self, render_input_filepath, impulse_response):
|
||||
"""
|
||||
def __init__(self, render_input_filepath, impulse_response):
|
||||
"""
|
||||
Args:
|
||||
render_input_filepath: Render audio track file.
|
||||
impulse_response: list or numpy vector of float values.
|
||||
"""
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
self._impulse_response = impulse_response
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
self._impulse_response = impulse_response
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Simulates linear echo path."""
|
||||
# Form the file name with a hash of the impulse response.
|
||||
impulse_response_hash = hashlib.sha256(
|
||||
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
|
||||
echo_filepath = os.path.join(output_path, 'linear_echo_{}.wav'.format(
|
||||
impulse_response_hash))
|
||||
def Simulate(self, output_path):
|
||||
"""Simulates linear echo path."""
|
||||
# Form the file name with a hash of the impulse response.
|
||||
impulse_response_hash = hashlib.sha256(
|
||||
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
|
||||
echo_filepath = os.path.join(
|
||||
output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
|
||||
|
||||
# If the simulated echo audio track file does not exists, create it.
|
||||
if not os.path.exists(echo_filepath):
|
||||
render = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._render_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
render, self._impulse_response)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(echo_filepath, echo)
|
||||
# If the simulated echo audio track file does not exists, create it.
|
||||
if not os.path.exists(echo_filepath):
|
||||
render = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._render_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
render, self._impulse_response)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
echo_filepath, echo)
|
||||
|
||||
return echo_filepath
|
||||
return echo_filepath
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class RecordedEchoPathSimulator(EchoPathSimulator):
|
||||
"""Uses recorded echo.
|
||||
"""Uses recorded echo.
|
||||
|
||||
This class uses the clean capture input file name to build the file name of
|
||||
the corresponding recording containing echo (a predefined suffix is used).
|
||||
Such a file is expected to be already existing.
|
||||
"""
|
||||
|
||||
NAME = 'recorded'
|
||||
NAME = 'recorded'
|
||||
|
||||
_FILE_NAME_SUFFIX = '_echo'
|
||||
_FILE_NAME_SUFFIX = '_echo'
|
||||
|
||||
def __init__(self, render_input_filepath):
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
def __init__(self, render_input_filepath):
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Uses recorded echo path."""
|
||||
path, file_name_ext = os.path.split(self._render_input_filepath)
|
||||
file_name, file_ext = os.path.splitext(file_name_ext)
|
||||
echo_filepath = os.path.join(path, '{}{}{}'.format(
|
||||
file_name, self._FILE_NAME_SUFFIX, file_ext))
|
||||
assert os.path.exists(echo_filepath), (
|
||||
'cannot find the echo audio track file {}'.format(echo_filepath))
|
||||
return echo_filepath
|
||||
def Simulate(self, output_path):
|
||||
"""Uses recorded echo path."""
|
||||
path, file_name_ext = os.path.split(self._render_input_filepath)
|
||||
file_name, file_ext = os.path.splitext(file_name_ext)
|
||||
echo_filepath = os.path.join(
|
||||
path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext))
|
||||
assert os.path.exists(echo_filepath), (
|
||||
'cannot find the echo audio track file {}'.format(echo_filepath))
|
||||
return echo_filepath
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Echo path simulation factory module.
|
||||
"""
|
||||
|
||||
@ -16,16 +15,16 @@ from . import echo_path_simulation
|
||||
|
||||
class EchoPathSimulatorFactory(object):
|
||||
|
||||
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
|
||||
# realistic impulse response.
|
||||
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0]*(20 * 48) + [0.15])
|
||||
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
|
||||
# realistic impulse response.
|
||||
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
|
||||
"""Creates an EchoPathSimulator instance given a class object.
|
||||
@classmethod
|
||||
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
|
||||
"""Creates an EchoPathSimulator instance given a class object.
|
||||
|
||||
Args:
|
||||
echo_path_simulator_class: EchoPathSimulator class object (not an
|
||||
@ -35,14 +34,15 @@ class EchoPathSimulatorFactory(object):
|
||||
Returns:
|
||||
An EchoPathSimulator instance.
|
||||
"""
|
||||
assert render_input_filepath is not None or (
|
||||
echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator)
|
||||
assert render_input_filepath is not None or (
|
||||
echo_path_simulator_class ==
|
||||
echo_path_simulation.NoEchoPathSimulator)
|
||||
|
||||
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
|
||||
return echo_path_simulation.NoEchoPathSimulator()
|
||||
elif echo_path_simulator_class == (
|
||||
echo_path_simulation.LinearEchoPathSimulator):
|
||||
return echo_path_simulation.LinearEchoPathSimulator(
|
||||
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
|
||||
else:
|
||||
return echo_path_simulator_class(render_input_filepath)
|
||||
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
|
||||
return echo_path_simulation.NoEchoPathSimulator()
|
||||
elif echo_path_simulator_class == (
|
||||
echo_path_simulation.LinearEchoPathSimulator):
|
||||
return echo_path_simulation.LinearEchoPathSimulator(
|
||||
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
|
||||
else:
|
||||
return echo_path_simulator_class(render_input_filepath)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the echo path simulation module.
|
||||
"""
|
||||
|
||||
@ -22,60 +21,62 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestEchoPathSimulators(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create and save white noise.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._audio_track_num_samples = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
|
||||
self._audio_track_filepath = os.path.join(self._tmp_path, 'white_noise.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._audio_track_filepath, white_noise)
|
||||
# Create and save white noise.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._audio_track_num_samples = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
|
||||
self._audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'white_noise.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._audio_track_filepath, white_noise)
|
||||
|
||||
# Make a copy the white noise audio track file; it will be used by
|
||||
# echo_path_simulation.RecordedEchoPathSimulator.
|
||||
shutil.copy(self._audio_track_filepath, os.path.join(
|
||||
self._tmp_path, 'white_noise_echo.wav'))
|
||||
# Make a copy the white noise audio track file; it will be used by
|
||||
# echo_path_simulation.RecordedEchoPathSimulator.
|
||||
shutil.copy(self._audio_track_filepath,
|
||||
os.path.join(self._tmp_path, 'white_noise_echo.wav'))
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Check that there is at least one registered echo path simulator.
|
||||
registered_classes = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
def testRegisteredClasses(self):
|
||||
# Check that there is at least one registered echo path simulator.
|
||||
registered_classes = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance factory.
|
||||
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
|
||||
# Instance factory.
|
||||
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
|
||||
|
||||
# Try each registered echo path simulator.
|
||||
for echo_path_simulator_name in registered_classes:
|
||||
simulator = factory.GetInstance(
|
||||
echo_path_simulator_class=registered_classes[
|
||||
echo_path_simulator_name],
|
||||
render_input_filepath=self._audio_track_filepath)
|
||||
# Try each registered echo path simulator.
|
||||
for echo_path_simulator_name in registered_classes:
|
||||
simulator = factory.GetInstance(
|
||||
echo_path_simulator_class=registered_classes[
|
||||
echo_path_simulator_name],
|
||||
render_input_filepath=self._audio_track_filepath)
|
||||
|
||||
echo_filepath = simulator.Simulate(self._tmp_path)
|
||||
if echo_filepath is None:
|
||||
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
|
||||
echo_path_simulator_name)
|
||||
# No other tests in this case.
|
||||
continue
|
||||
echo_filepath = simulator.Simulate(self._tmp_path)
|
||||
if echo_filepath is None:
|
||||
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
|
||||
echo_path_simulator_name)
|
||||
# No other tests in this case.
|
||||
continue
|
||||
|
||||
# Check that the echo audio track file exists and its length is greater or
|
||||
# equal to that of the render audio track.
|
||||
self.assertTrue(os.path.exists(echo_filepath))
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
|
||||
self.assertGreaterEqual(
|
||||
signal_processing.SignalProcessingUtils.CountSamples(echo),
|
||||
self._audio_track_num_samples)
|
||||
# Check that the echo audio track file exists and its length is greater or
|
||||
# equal to that of the render audio track.
|
||||
self.assertTrue(os.path.exists(echo_filepath))
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
self.assertGreaterEqual(
|
||||
signal_processing.SignalProcessingUtils.CountSamples(echo),
|
||||
self._audio_track_num_samples)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Evaluation score abstract class and implementations.
|
||||
"""
|
||||
|
||||
@ -17,10 +16,10 @@ import subprocess
|
||||
import sys
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access
|
||||
from . import exceptions
|
||||
@ -29,23 +28,23 @@ from . import signal_processing
|
||||
|
||||
class EvaluationScore(object):
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
self._score_filename_prefix = score_filename_prefix
|
||||
self._input_signal_metadata = None
|
||||
self._reference_signal = None
|
||||
self._reference_signal_filepath = None
|
||||
self._tested_signal = None
|
||||
self._tested_signal_filepath = None
|
||||
self._output_filepath = None
|
||||
self._score = None
|
||||
self._render_signal_filepath = None
|
||||
def __init__(self, score_filename_prefix):
|
||||
self._score_filename_prefix = score_filename_prefix
|
||||
self._input_signal_metadata = None
|
||||
self._reference_signal = None
|
||||
self._reference_signal_filepath = None
|
||||
self._tested_signal = None
|
||||
self._tested_signal_filepath = None
|
||||
self._output_filepath = None
|
||||
self._score = None
|
||||
self._render_signal_filepath = None
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EvaluationScore implementation.
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EvaluationScore implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend EvaluationScore.
|
||||
Example usage:
|
||||
@ -54,91 +53,90 @@ class EvaluationScore(object):
|
||||
class AudioLevelScore(EvaluationScore):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_filepath
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_filepath
|
||||
|
||||
@property
|
||||
def score(self):
|
||||
return self._score
|
||||
@property
|
||||
def score(self):
|
||||
return self._score
|
||||
|
||||
def SetInputSignalMetadata(self, metadata):
|
||||
"""Sets input signal metadata.
|
||||
def SetInputSignalMetadata(self, metadata):
|
||||
"""Sets input signal metadata.
|
||||
|
||||
Args:
|
||||
metadata: dict instance.
|
||||
"""
|
||||
self._input_signal_metadata = metadata
|
||||
self._input_signal_metadata = metadata
|
||||
|
||||
def SetReferenceSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as reference signal.
|
||||
def SetReferenceSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as reference signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the reference audio track.
|
||||
"""
|
||||
self._reference_signal_filepath = filepath
|
||||
self._reference_signal_filepath = filepath
|
||||
|
||||
def SetTestedSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as test signal.
|
||||
def SetTestedSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as test signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._tested_signal_filepath = filepath
|
||||
self._tested_signal_filepath = filepath
|
||||
|
||||
def SetRenderSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as render signal.
|
||||
def SetRenderSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as render signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._render_signal_filepath = filepath
|
||||
self._render_signal_filepath = filepath
|
||||
|
||||
def Run(self, output_path):
|
||||
"""Extracts the score for the set test data pair.
|
||||
def Run(self, output_path):
|
||||
"""Extracts the score for the set test data pair.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory where the output is written.
|
||||
"""
|
||||
self._output_filepath = os.path.join(
|
||||
output_path, self._score_filename_prefix + self.NAME + '.txt')
|
||||
try:
|
||||
# If the score has already been computed, load.
|
||||
self._LoadScore()
|
||||
logging.debug('score found and loaded')
|
||||
except IOError:
|
||||
# Compute the score.
|
||||
logging.debug('score not found, compute')
|
||||
self._Run(output_path)
|
||||
self._output_filepath = os.path.join(
|
||||
output_path, self._score_filename_prefix + self.NAME + '.txt')
|
||||
try:
|
||||
# If the score has already been computed, load.
|
||||
self._LoadScore()
|
||||
logging.debug('score found and loaded')
|
||||
except IOError:
|
||||
# Compute the score.
|
||||
logging.debug('score not found, compute')
|
||||
self._Run(output_path)
|
||||
|
||||
def _Run(self, output_path):
|
||||
# Abstract method.
|
||||
raise NotImplementedError()
|
||||
def _Run(self, output_path):
|
||||
# Abstract method.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _LoadReferenceSignal(self):
|
||||
assert self._reference_signal_filepath is not None
|
||||
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._reference_signal_filepath)
|
||||
def _LoadReferenceSignal(self):
|
||||
assert self._reference_signal_filepath is not None
|
||||
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._reference_signal_filepath)
|
||||
|
||||
def _LoadTestedSignal(self):
|
||||
assert self._tested_signal_filepath is not None
|
||||
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._tested_signal_filepath)
|
||||
def _LoadTestedSignal(self):
|
||||
assert self._tested_signal_filepath is not None
|
||||
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._tested_signal_filepath)
|
||||
|
||||
def _LoadScore(self):
|
||||
return data_access.ScoreFile.Load(self._output_filepath)
|
||||
|
||||
def _LoadScore(self):
|
||||
return data_access.ScoreFile.Load(self._output_filepath)
|
||||
|
||||
def _SaveScore(self):
|
||||
return data_access.ScoreFile.Save(self._output_filepath, self._score)
|
||||
def _SaveScore(self):
|
||||
return data_access.ScoreFile.Save(self._output_filepath, self._score)
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class AudioLevelPeakScore(EvaluationScore):
|
||||
"""Peak audio level score.
|
||||
"""Peak audio level score.
|
||||
|
||||
Defined as the difference between the peak audio level of the tested and
|
||||
the reference signals.
|
||||
@ -148,21 +146,21 @@ class AudioLevelPeakScore(EvaluationScore):
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_peak'
|
||||
NAME = 'audio_level_peak'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
|
||||
self._SaveScore()
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
|
||||
self._SaveScore()
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class MeanAudioLevelScore(EvaluationScore):
|
||||
"""Mean audio level score.
|
||||
"""Mean audio level score.
|
||||
|
||||
Defined as the difference between the mean audio level of the tested and
|
||||
the reference signals.
|
||||
@ -172,29 +170,30 @@ class MeanAudioLevelScore(EvaluationScore):
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_mean'
|
||||
NAME = 'audio_level_mean'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
|
||||
dbfs_diffs_sum = 0.0
|
||||
seconds = min(len(self._tested_signal), len(self._reference_signal)) // 1000
|
||||
for t in range(seconds):
|
||||
t0 = t * seconds
|
||||
t1 = t0 + seconds
|
||||
dbfs_diffs_sum += (
|
||||
self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS)
|
||||
self._score = dbfs_diffs_sum / float(seconds)
|
||||
self._SaveScore()
|
||||
dbfs_diffs_sum = 0.0
|
||||
seconds = min(len(self._tested_signal), len(
|
||||
self._reference_signal)) // 1000
|
||||
for t in range(seconds):
|
||||
t0 = t * seconds
|
||||
t1 = t0 + seconds
|
||||
dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
|
||||
self._reference_signal[t0:t1].dBFS)
|
||||
self._score = dbfs_diffs_sum / float(seconds)
|
||||
self._SaveScore()
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class EchoMetric(EvaluationScore):
|
||||
"""Echo score.
|
||||
"""Echo score.
|
||||
|
||||
Proportion of detected echo.
|
||||
|
||||
@ -203,46 +202,47 @@ class EchoMetric(EvaluationScore):
|
||||
Worst case: 1
|
||||
"""
|
||||
|
||||
NAME = 'echo_metric'
|
||||
NAME = 'echo_metric'
|
||||
|
||||
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._echo_detector_bin_filepath = echo_detector_bin_filepath
|
||||
if not os.path.exists(self._echo_detector_bin_filepath):
|
||||
logging.error('cannot find EchoMetric tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
# POLQA binary file path.
|
||||
self._echo_detector_bin_filepath = echo_detector_bin_filepath
|
||||
if not os.path.exists(self._echo_detector_bin_filepath):
|
||||
logging.error('cannot find EchoMetric tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
self._echo_detector_bin_path, _ = os.path.split(
|
||||
self._echo_detector_bin_filepath)
|
||||
self._echo_detector_bin_path, _ = os.path.split(
|
||||
self._echo_detector_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out')
|
||||
if os.path.exists(echo_detector_out_filepath):
|
||||
os.unlink(echo_detector_out_filepath)
|
||||
def _Run(self, output_path):
|
||||
echo_detector_out_filepath = os.path.join(output_path,
|
||||
'echo_detector.out')
|
||||
if os.path.exists(echo_detector_out_filepath):
|
||||
os.unlink(echo_detector_out_filepath)
|
||||
|
||||
logging.debug("Render signal filepath: %s", self._render_signal_filepath)
|
||||
if not os.path.exists(self._render_signal_filepath):
|
||||
logging.error("Render input required for evaluating the echo metric.")
|
||||
logging.debug("Render signal filepath: %s",
|
||||
self._render_signal_filepath)
|
||||
if not os.path.exists(self._render_signal_filepath):
|
||||
logging.error(
|
||||
"Render input required for evaluating the echo metric.")
|
||||
|
||||
args = [
|
||||
self._echo_detector_bin_filepath,
|
||||
'--output_file', echo_detector_out_filepath,
|
||||
'--',
|
||||
'-i', self._tested_signal_filepath,
|
||||
'-ri', self._render_signal_filepath
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._echo_detector_bin_path)
|
||||
args = [
|
||||
self._echo_detector_bin_filepath, '--output_file',
|
||||
echo_detector_out_filepath, '--', '-i',
|
||||
self._tested_signal_filepath, '-ri', self._render_signal_filepath
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._echo_detector_bin_path)
|
||||
|
||||
# Parse Echo detector tool output and extract the score.
|
||||
self._score = self._ParseOutputFile(echo_detector_out_filepath)
|
||||
self._SaveScore()
|
||||
# Parse Echo detector tool output and extract the score.
|
||||
self._score = self._ParseOutputFile(echo_detector_out_filepath)
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, echo_metric_file_path):
|
||||
"""
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, echo_metric_file_path):
|
||||
"""
|
||||
Parses the POLQA tool output formatted as a table ('-t' option).
|
||||
|
||||
Args:
|
||||
@ -251,12 +251,13 @@ class EchoMetric(EvaluationScore):
|
||||
Returns:
|
||||
The score as a number in [0, 1].
|
||||
"""
|
||||
with open(echo_metric_file_path) as f:
|
||||
return float(f.read())
|
||||
with open(echo_metric_file_path) as f:
|
||||
return float(f.read())
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class PolqaScore(EvaluationScore):
|
||||
"""POLQA score.
|
||||
"""POLQA score.
|
||||
|
||||
See http://www.polqa.info/.
|
||||
|
||||
@ -265,44 +266,51 @@ class PolqaScore(EvaluationScore):
|
||||
Worst case: 1.0
|
||||
"""
|
||||
|
||||
NAME = 'polqa'
|
||||
NAME = 'polqa'
|
||||
|
||||
def __init__(self, score_filename_prefix, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
def __init__(self, score_filename_prefix, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
os.unlink(polqa_out_filepath)
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
os.unlink(polqa_out_filepath)
|
||||
|
||||
args = [
|
||||
self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
|
||||
'-Ref', self._reference_signal_filepath,
|
||||
'-Test', self._tested_signal_filepath,
|
||||
'-LC', 'NB',
|
||||
'-Out', polqa_out_filepath,
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._polqa_tool_path)
|
||||
args = [
|
||||
self._polqa_bin_filepath,
|
||||
'-t',
|
||||
'-q',
|
||||
'-Overwrite',
|
||||
'-Ref',
|
||||
self._reference_signal_filepath,
|
||||
'-Test',
|
||||
self._tested_signal_filepath,
|
||||
'-LC',
|
||||
'NB',
|
||||
'-Out',
|
||||
polqa_out_filepath,
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._polqa_tool_path)
|
||||
|
||||
# Parse POLQA tool output and extract the score.
|
||||
polqa_output = self._ParseOutputFile(polqa_out_filepath)
|
||||
self._score = float(polqa_output['PolqaScore'])
|
||||
# Parse POLQA tool output and extract the score.
|
||||
polqa_output = self._ParseOutputFile(polqa_out_filepath)
|
||||
self._score = float(polqa_output['PolqaScore'])
|
||||
|
||||
self._SaveScore()
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, polqa_out_filepath):
|
||||
"""
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, polqa_out_filepath):
|
||||
"""
|
||||
Parses the POLQA tool output formatted as a table ('-t' option).
|
||||
|
||||
Args:
|
||||
@ -311,29 +319,32 @@ class PolqaScore(EvaluationScore):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
data = []
|
||||
with open(polqa_out_filepath) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if len(line) == 0 or line.startswith('*'):
|
||||
# Ignore comments.
|
||||
continue
|
||||
# Read fields.
|
||||
data.append(re.split(r'\t+', line))
|
||||
data = []
|
||||
with open(polqa_out_filepath) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if len(line) == 0 or line.startswith('*'):
|
||||
# Ignore comments.
|
||||
continue
|
||||
# Read fields.
|
||||
data.append(re.split(r'\t+', line))
|
||||
|
||||
# Two rows expected (header and values).
|
||||
assert len(data) == 2, 'Cannot parse POLQA output'
|
||||
number_of_fields = len(data[0])
|
||||
assert number_of_fields == len(data[1])
|
||||
# Two rows expected (header and values).
|
||||
assert len(data) == 2, 'Cannot parse POLQA output'
|
||||
number_of_fields = len(data[0])
|
||||
assert number_of_fields == len(data[1])
|
||||
|
||||
# Build and return a dictionary with field names (header) as keys and the
|
||||
# corresponding field values as values.
|
||||
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
|
||||
# Build and return a dictionary with field names (header) as keys and the
|
||||
# corresponding field values as values.
|
||||
return {
|
||||
data[0][index]: data[1][index]
|
||||
for index in range(number_of_fields)
|
||||
}
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class TotalHarmonicDistorsionScore(EvaluationScore):
|
||||
"""Total harmonic distorsion plus noise score.
|
||||
"""Total harmonic distorsion plus noise score.
|
||||
|
||||
Total harmonic distorsion plus noise score.
|
||||
See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
|
||||
@ -343,69 +354,74 @@ class TotalHarmonicDistorsionScore(EvaluationScore):
|
||||
Worst case: +inf
|
||||
"""
|
||||
|
||||
NAME = 'thd'
|
||||
NAME = 'thd'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
self._input_frequency = None
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
self._input_frequency = None
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._CheckInputSignal()
|
||||
def _Run(self, output_path):
|
||||
self._CheckInputSignal()
|
||||
|
||||
self._LoadTestedSignal()
|
||||
if self._tested_signal.channels != 1:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'unsupported number of channels')
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._tested_signal)
|
||||
self._LoadTestedSignal()
|
||||
if self._tested_signal.channels != 1:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'unsupported number of channels')
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._tested_signal)
|
||||
|
||||
# Init.
|
||||
num_samples = len(samples)
|
||||
duration = len(self._tested_signal) / 1000.0
|
||||
scaling = 2.0 / num_samples
|
||||
max_freq = self._tested_signal.frame_rate / 2
|
||||
f0_freq = float(self._input_frequency)
|
||||
t = np.linspace(0, duration, num_samples)
|
||||
# Init.
|
||||
num_samples = len(samples)
|
||||
duration = len(self._tested_signal) / 1000.0
|
||||
scaling = 2.0 / num_samples
|
||||
max_freq = self._tested_signal.frame_rate / 2
|
||||
f0_freq = float(self._input_frequency)
|
||||
t = np.linspace(0, duration, num_samples)
|
||||
|
||||
# Analyze harmonics.
|
||||
b_terms = []
|
||||
n = 1
|
||||
while f0_freq * n < max_freq:
|
||||
x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
b_terms.append(np.sqrt(x_n**2 + y_n**2))
|
||||
n += 1
|
||||
# Analyze harmonics.
|
||||
b_terms = []
|
||||
n = 1
|
||||
while f0_freq * n < max_freq:
|
||||
x_n = np.sum(
|
||||
samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
y_n = np.sum(
|
||||
samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
b_terms.append(np.sqrt(x_n**2 + y_n**2))
|
||||
n += 1
|
||||
|
||||
output_without_fundamental = samples - b_terms[0] * np.sin(
|
||||
2.0 * np.pi * f0_freq * t)
|
||||
distortion_and_noise = np.sqrt(np.sum(
|
||||
output_without_fundamental**2) * np.pi * scaling)
|
||||
output_without_fundamental = samples - b_terms[0] * np.sin(
|
||||
2.0 * np.pi * f0_freq * t)
|
||||
distortion_and_noise = np.sqrt(
|
||||
np.sum(output_without_fundamental**2) * np.pi * scaling)
|
||||
|
||||
# TODO(alessiob): Fix or remove if not needed.
|
||||
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
|
||||
# TODO(alessiob): Fix or remove if not needed.
|
||||
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
|
||||
|
||||
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
|
||||
# docstring above if accordingly.
|
||||
thd_plus_noise = distortion_and_noise / b_terms[0]
|
||||
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
|
||||
# docstring above if accordingly.
|
||||
thd_plus_noise = distortion_and_noise / b_terms[0]
|
||||
|
||||
self._score = thd_plus_noise
|
||||
self._SaveScore()
|
||||
self._score = thd_plus_noise
|
||||
self._SaveScore()
|
||||
|
||||
def _CheckInputSignal(self):
|
||||
# Check input signal and get properties.
|
||||
try:
|
||||
if self._input_signal_metadata['signal'] != 'pure_tone':
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires a pure tone as input signal')
|
||||
self._input_frequency = self._input_signal_metadata['frequency']
|
||||
if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
|
||||
self._input_signal_metadata['test_data_gen_config'] != 'default'):
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score cannot be used with any test data generator other '
|
||||
'than "identity"')
|
||||
except TypeError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires an input signal with associated metadata')
|
||||
except KeyError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'Invalid input signal metadata to compute the THD score')
|
||||
def _CheckInputSignal(self):
|
||||
# Check input signal and get properties.
|
||||
try:
|
||||
if self._input_signal_metadata['signal'] != 'pure_tone':
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires a pure tone as input signal')
|
||||
self._input_frequency = self._input_signal_metadata['frequency']
|
||||
if self._input_signal_metadata[
|
||||
'test_data_gen_name'] != 'identity' or (
|
||||
self._input_signal_metadata['test_data_gen_config'] !=
|
||||
'default'):
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score cannot be used with any test data generator other '
|
||||
'than "identity"')
|
||||
except TypeError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires an input signal with associated metadata'
|
||||
)
|
||||
except KeyError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'Invalid input signal metadata to compute the THD score')
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""EvaluationScore factory class.
|
||||
"""
|
||||
|
||||
@ -16,22 +15,22 @@ from . import eval_scores
|
||||
|
||||
|
||||
class EvaluationScoreWorkerFactory(object):
|
||||
"""Factory class used to instantiate evaluation score workers.
|
||||
"""Factory class used to instantiate evaluation score workers.
|
||||
|
||||
The ctor gets the parametrs that are used to instatiate the evaluation score
|
||||
workers.
|
||||
"""
|
||||
|
||||
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
|
||||
self._score_filename_prefix = None
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
|
||||
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
|
||||
self._score_filename_prefix = None
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
|
||||
|
||||
def SetScoreFilenamePrefix(self, prefix):
|
||||
self._score_filename_prefix = prefix
|
||||
def SetScoreFilenamePrefix(self, prefix):
|
||||
self._score_filename_prefix = prefix
|
||||
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
|
||||
Args:
|
||||
evaluation_score_class: EvaluationScore class object (not an instance).
|
||||
@ -39,17 +38,18 @@ class EvaluationScoreWorkerFactory(object):
|
||||
Returns:
|
||||
An EvaluationScore instance.
|
||||
"""
|
||||
if self._score_filename_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The score file name prefix for evaluation score workers is not set')
|
||||
logging.debug(
|
||||
'factory producing a %s evaluation score', evaluation_score_class)
|
||||
if self._score_filename_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The score file name prefix for evaluation score workers is not set'
|
||||
)
|
||||
logging.debug('factory producing a %s evaluation score',
|
||||
evaluation_score_class)
|
||||
|
||||
if evaluation_score_class == eval_scores.PolqaScore:
|
||||
return eval_scores.PolqaScore(
|
||||
self._score_filename_prefix, self._polqa_tool_bin_path)
|
||||
elif evaluation_score_class == eval_scores.EchoMetric:
|
||||
return eval_scores.EchoMetric(
|
||||
self._score_filename_prefix, self._echo_metric_tool_bin_path)
|
||||
else:
|
||||
return evaluation_score_class(self._score_filename_prefix)
|
||||
if evaluation_score_class == eval_scores.PolqaScore:
|
||||
return eval_scores.PolqaScore(self._score_filename_prefix,
|
||||
self._polqa_tool_bin_path)
|
||||
elif evaluation_score_class == eval_scores.EchoMetric:
|
||||
return eval_scores.EchoMetric(self._score_filename_prefix,
|
||||
self._echo_metric_tool_bin_path)
|
||||
else:
|
||||
return evaluation_score_class(self._score_filename_prefix)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
@ -23,111 +22,116 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestEvalScores(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (
|
||||
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Evaluation score names to exclude (tested separately).
|
||||
exceptions = ['thd', 'echo_metric']
|
||||
def testRegisteredClasses(self):
|
||||
# Evaluation score names to exclude (tested separately).
|
||||
exceptions = ['thd', 'echo_metric']
|
||||
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
))
|
||||
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None))
|
||||
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
|
||||
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
if eval_score_name in exceptions:
|
||||
continue
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
if eval_score_name in exceptions:
|
||||
continue
|
||||
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
|
||||
# Set fake input metadata and reference and test file paths, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
# Set fake input metadata and reference and test file paths, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(
|
||||
eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
|
||||
def testTotalHarmonicDistorsionScore(self):
|
||||
# Init.
|
||||
pure_tone_freq = 5000.0
|
||||
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
|
||||
eval_score_worker.SetInputSignalMetadata({
|
||||
'signal': 'pure_tone',
|
||||
'frequency': pure_tone_freq,
|
||||
'test_data_gen_name': 'identity',
|
||||
'test_data_gen_config': 'default',
|
||||
})
|
||||
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
def testTotalHarmonicDistorsionScore(self):
|
||||
# Init.
|
||||
pure_tone_freq = 5000.0
|
||||
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
|
||||
eval_score_worker.SetInputSignalMetadata({
|
||||
'signal':
|
||||
'pure_tone',
|
||||
'frequency':
|
||||
pure_tone_freq,
|
||||
'test_data_gen_name':
|
||||
'identity',
|
||||
'test_data_gen_config':
|
||||
'default',
|
||||
})
|
||||
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
|
||||
# Create 3 test signals: pure tone, pure tone + white noise, white noise
|
||||
# only.
|
||||
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, pure_tone_freq)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
template)
|
||||
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
pure_tone, white_noise)
|
||||
# Create 3 test signals: pure tone, pure tone + white noise, white noise
|
||||
# only.
|
||||
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, pure_tone_freq)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
template)
|
||||
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
pure_tone, white_noise)
|
||||
|
||||
# Compute scores for increasingly distorted pure tone signals.
|
||||
scores = [None, None, None]
|
||||
for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
|
||||
# Save signal.
|
||||
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
tmp_filepath, tested_signal)
|
||||
# Compute scores for increasingly distorted pure tone signals.
|
||||
scores = [None, None, None]
|
||||
for index, tested_signal in enumerate(
|
||||
[pure_tone, noisy_tone, white_noise]):
|
||||
# Save signal.
|
||||
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
tmp_filepath, tested_signal)
|
||||
|
||||
# Compute score.
|
||||
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
scores[index] = eval_score_worker.score
|
||||
# Compute score.
|
||||
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
scores[index] = eval_score_worker.score
|
||||
|
||||
# Remove output file to avoid caching.
|
||||
os.remove(eval_score_worker.output_filepath)
|
||||
# Remove output file to avoid caching.
|
||||
os.remove(eval_score_worker.output_filepath)
|
||||
|
||||
# Validate scores (lowest score with a pure tone).
|
||||
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
|
||||
# Validate scores (lowest score with a pure tone).
|
||||
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Evaluator of the APM module.
|
||||
"""
|
||||
|
||||
@ -13,17 +12,17 @@ import logging
|
||||
|
||||
|
||||
class ApmModuleEvaluator(object):
|
||||
"""APM evaluator class.
|
||||
"""APM evaluator class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||
apm_output_filepath, reference_input_filepath,
|
||||
render_input_filepath, output_path):
|
||||
"""Runs the evaluation.
|
||||
@classmethod
|
||||
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||
apm_output_filepath, reference_input_filepath,
|
||||
render_input_filepath, output_path):
|
||||
"""Runs the evaluation.
|
||||
|
||||
Iterates over the given evaluation score workers.
|
||||
|
||||
@ -37,20 +36,22 @@ class ApmModuleEvaluator(object):
|
||||
Returns:
|
||||
A dict of evaluation score name and score pairs.
|
||||
"""
|
||||
# Init.
|
||||
scores = {}
|
||||
# Init.
|
||||
scores = {}
|
||||
|
||||
for evaluation_score_worker in evaluation_score_workers:
|
||||
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
|
||||
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||
reference_input_filepath)
|
||||
evaluation_score_worker.SetTestedSignalFilepath(
|
||||
apm_output_filepath)
|
||||
evaluation_score_worker.SetRenderSignalFilepath(
|
||||
render_input_filepath)
|
||||
for evaluation_score_worker in evaluation_score_workers:
|
||||
logging.info(' computing <%s> score',
|
||||
evaluation_score_worker.NAME)
|
||||
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||
reference_input_filepath)
|
||||
evaluation_score_worker.SetTestedSignalFilepath(
|
||||
apm_output_filepath)
|
||||
evaluation_score_worker.SetRenderSignalFilepath(
|
||||
render_input_filepath)
|
||||
|
||||
evaluation_score_worker.Run(output_path)
|
||||
scores[evaluation_score_worker.NAME] = evaluation_score_worker.score
|
||||
evaluation_score_worker.Run(output_path)
|
||||
scores[
|
||||
evaluation_score_worker.NAME] = evaluation_score_worker.score
|
||||
|
||||
return scores
|
||||
return scores
|
||||
|
||||
@ -5,42 +5,41 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Exception classes.
|
||||
"""
|
||||
|
||||
|
||||
class FileNotFoundError(Exception):
|
||||
"""File not found exception.
|
||||
"""File not found exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class SignalProcessingException(Exception):
|
||||
"""Signal processing exception.
|
||||
"""Signal processing exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InputMixerException(Exception):
|
||||
"""Input mixer exception.
|
||||
"""Input mixer exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InputSignalCreatorException(Exception):
|
||||
"""Input signal creator exception.
|
||||
"""Input signal creator exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class EvaluationScoreException(Exception):
|
||||
"""Evaluation score exception.
|
||||
"""Evaluation score exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class InitializationException(Exception):
|
||||
"""Initialization exception.
|
||||
"""Initialization exception.
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -14,58 +14,58 @@ import re
|
||||
import sys
|
||||
|
||||
try:
|
||||
import csscompressor
|
||||
import csscompressor
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package csscompressor')
|
||||
sys.exit(1)
|
||||
logging.critical(
|
||||
'Cannot import the third-party Python package csscompressor')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import jsmin
|
||||
import jsmin
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package jsmin')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package jsmin')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class HtmlExport(object):
|
||||
"""HTML exporter class for APM quality scores."""
|
||||
"""HTML exporter class for APM quality scores."""
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
_NEW_LINE = '\n'
|
||||
|
||||
# CSS and JS file paths.
|
||||
_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
|
||||
_CSS_MINIFIED = True
|
||||
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
|
||||
_JS_MINIFIED = True
|
||||
# CSS and JS file paths.
|
||||
_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
|
||||
_CSS_MINIFIED = True
|
||||
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
|
||||
_JS_MINIFIED = True
|
||||
|
||||
def __init__(self, output_filepath):
|
||||
self._scores_data_frame = None
|
||||
self._output_filepath = output_filepath
|
||||
def __init__(self, output_filepath):
|
||||
self._scores_data_frame = None
|
||||
self._output_filepath = output_filepath
|
||||
|
||||
def Export(self, scores_data_frame):
|
||||
"""Exports scores into an HTML file.
|
||||
def Export(self, scores_data_frame):
|
||||
"""Exports scores into an HTML file.
|
||||
|
||||
Args:
|
||||
scores_data_frame: DataFrame instance.
|
||||
"""
|
||||
self._scores_data_frame = scores_data_frame
|
||||
html = ['<html>',
|
||||
self._scores_data_frame = scores_data_frame
|
||||
html = [
|
||||
'<html>',
|
||||
self._BuildHeader(),
|
||||
('<script type="text/javascript">'
|
||||
'(function () {'
|
||||
'window.addEventListener(\'load\', function () {'
|
||||
'var inspector = new AudioInspector();'
|
||||
'});'
|
||||
'(function () {'
|
||||
'window.addEventListener(\'load\', function () {'
|
||||
'var inspector = new AudioInspector();'
|
||||
'});'
|
||||
'})();'
|
||||
'</script>'),
|
||||
'<body>',
|
||||
self._BuildBody(),
|
||||
'</body>',
|
||||
'</html>']
|
||||
self._Save(self._output_filepath, self._NEW_LINE.join(html))
|
||||
'</script>'), '<body>',
|
||||
self._BuildBody(), '</body>', '</html>'
|
||||
]
|
||||
self._Save(self._output_filepath, self._NEW_LINE.join(html))
|
||||
|
||||
def _BuildHeader(self):
|
||||
"""Builds the <head> section of the HTML file.
|
||||
def _BuildHeader(self):
|
||||
"""Builds the <head> section of the HTML file.
|
||||
|
||||
The header contains the page title and either embedded or linked CSS and JS
|
||||
files.
|
||||
@ -73,325 +73,349 @@ class HtmlExport(object):
|
||||
Returns:
|
||||
A string with <head>...</head> HTML.
|
||||
"""
|
||||
html = ['<head>', '<title>Results</title>']
|
||||
html = ['<head>', '<title>Results</title>']
|
||||
|
||||
# Add Material Design hosted libs.
|
||||
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
|
||||
'css?family=Roboto:300,400,500,700" type="text/css">')
|
||||
html.append('<link rel="stylesheet" href="https://fonts.googleapis.com/'
|
||||
'icon?family=Material+Icons">')
|
||||
html.append('<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
|
||||
'material.indigo-pink.min.css">')
|
||||
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
|
||||
'material.min.js"></script>')
|
||||
|
||||
# Embed custom JavaScript and CSS files.
|
||||
html.append('<script>')
|
||||
with open(self._JS_FILEPATH) as f:
|
||||
html.append(jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</script>')
|
||||
html.append('<style>')
|
||||
with open(self._CSS_FILEPATH) as f:
|
||||
html.append(csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</style>')
|
||||
|
||||
html.append('</head>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildBody(self):
|
||||
"""Builds the content of the <body> section."""
|
||||
score_names = self._scores_data_frame['eval_score_name'].drop_duplicates(
|
||||
).values.tolist()
|
||||
|
||||
html = [
|
||||
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
|
||||
'mdl-layout--fixed-tabs">'),
|
||||
'<header class="mdl-layout__header">',
|
||||
'<div class="mdl-layout__header-row">',
|
||||
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
|
||||
self._output_filepath),
|
||||
'</div>',
|
||||
]
|
||||
|
||||
# Tab selectors.
|
||||
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
is_active = tab_index == 0
|
||||
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
|
||||
'{}</a>'.format(tab_index,
|
||||
' is-active' if is_active else '',
|
||||
self._FormatName(score_name)))
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</header>')
|
||||
html.append('<main class="mdl-layout__content" style="overflow-x: auto;">')
|
||||
|
||||
# Tabs content.
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
html.append('<section class="mdl-layout__tab-panel{}" '
|
||||
'id="score-tab-{}">'.format(
|
||||
' is-active' if is_active else '', tab_index))
|
||||
html.append('<div class="page-content">')
|
||||
html.append(self._BuildScoreTab(score_name, ('s{}'.format(tab_index),)))
|
||||
html.append('</div>')
|
||||
html.append('</section>')
|
||||
|
||||
html.append('</main>')
|
||||
html.append('</div>')
|
||||
|
||||
# Add snackbar for notifications.
|
||||
html.append(
|
||||
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
|
||||
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
|
||||
'<div class="mdl-snackbar__text"></div>'
|
||||
'<button type="button" class="mdl-snackbar__action"></button>'
|
||||
'</div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTab(self, score_name, anchor_data):
|
||||
"""Builds the content of a tab."""
|
||||
# Find unique values.
|
||||
scores = self._scores_data_frame[
|
||||
self._scores_data_frame.eval_score_name == score_name]
|
||||
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
|
||||
test_data_gen_configs = sorted(self._FindUniqueTuples(
|
||||
scores, ['test_data_gen', 'test_data_gen_params']))
|
||||
|
||||
html = [
|
||||
'<div class="mdl-grid">',
|
||||
'<div class="mdl-layout-spacer"></div>',
|
||||
'<div class="mdl-cell mdl-cell--10-col">',
|
||||
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
|
||||
'style="width: 100%;">'),
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>APM config / Test data generator</th>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
html.append('<th>{} {}</th>'.format(
|
||||
self._FormatName(test_data_gen_info[0]), test_data_gen_info[1]))
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for apm_config in apm_configs:
|
||||
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
# Add Material Design hosted libs.
|
||||
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
|
||||
'css?family=Roboto:300,400,500,700" type="text/css">')
|
||||
html.append(
|
||||
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.format(
|
||||
dialog_id, self._BuildScoreTableCell(
|
||||
score_name, test_data_gen_info[0], test_data_gen_info[1],
|
||||
apm_config[0])))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
'<link rel="stylesheet" href="https://fonts.googleapis.com/'
|
||||
'icon?family=Material+Icons">')
|
||||
html.append(
|
||||
'<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
|
||||
'material.indigo-pink.min.css">')
|
||||
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
|
||||
'material.min.js"></script>')
|
||||
|
||||
html.append('</table></div><div class="mdl-layout-spacer"></div></div>')
|
||||
# Embed custom JavaScript and CSS files.
|
||||
html.append('<script>')
|
||||
with open(self._JS_FILEPATH) as f:
|
||||
html.append(
|
||||
jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</script>')
|
||||
html.append('<style>')
|
||||
with open(self._CSS_FILEPATH) as f:
|
||||
html.append(
|
||||
csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</style>')
|
||||
|
||||
html.append(self._BuildScoreStatsInspectorDialogs(
|
||||
score_name, apm_configs, test_data_gen_configs,
|
||||
anchor_data))
|
||||
html.append('</head>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTableCell(self, score_name, test_data_gen,
|
||||
test_data_gen_params, apm_config):
|
||||
"""Builds the content of a table cell for a score table."""
|
||||
scores = self._SliceDataForScoreTableCell(
|
||||
score_name, apm_config, test_data_gen, test_data_gen_params)
|
||||
stats = self._ComputeScoreStats(scores)
|
||||
def _BuildBody(self):
|
||||
"""Builds the content of the <body> section."""
|
||||
score_names = self._scores_data_frame[
|
||||
'eval_score_name'].drop_duplicates().values.tolist()
|
||||
|
||||
html = []
|
||||
items_id_prefix = (
|
||||
score_name + test_data_gen + test_data_gen_params + apm_config)
|
||||
if stats['count'] == 1:
|
||||
# Show the only available score.
|
||||
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
|
||||
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
|
||||
item_id, scores['score'].mean()))
|
||||
html.append('<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
|
||||
'</div>'.format(item_id, 'single value'))
|
||||
else:
|
||||
# Show stats.
|
||||
for stat_name in ['min', 'max', 'mean', 'std dev']:
|
||||
item_id = hashlib.md5(
|
||||
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
|
||||
html.append('<div id="stats-{0}">{1:f}</div>'.format(
|
||||
item_id, stats[stat_name]))
|
||||
html.append('<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
|
||||
'</div>'.format(item_id, stat_name))
|
||||
html = [
|
||||
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
|
||||
'mdl-layout--fixed-tabs">'),
|
||||
'<header class="mdl-layout__header">',
|
||||
'<div class="mdl-layout__header-row">',
|
||||
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
|
||||
self._output_filepath),
|
||||
'</div>',
|
||||
]
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialogs(
|
||||
self, score_name, apm_configs, test_data_gen_configs, anchor_data):
|
||||
"""Builds a set of score stats inspector dialogs."""
|
||||
html = []
|
||||
for apm_config in apm_configs:
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0],
|
||||
test_data_gen_info[0], test_data_gen_info[1])
|
||||
|
||||
html.append('<dialog class="mdl-dialog" id="{}" '
|
||||
'style="width: 40%;">'.format(dialog_id))
|
||||
|
||||
# Content.
|
||||
html.append('<div class="mdl-dialog__content">')
|
||||
html.append('<h6><strong>APM config preset</strong>: {}<br/>'
|
||||
'<strong>Test data generator</strong>: {} ({})</h6>'.format(
|
||||
self._FormatName(apm_config[0]),
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append(self._BuildScoreStatsInspectorDialog(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1], anchor_data + (dialog_id,)))
|
||||
# Tab selectors.
|
||||
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
is_active = tab_index == 0
|
||||
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
|
||||
'{}</a>'.format(tab_index,
|
||||
' is-active' if is_active else '',
|
||||
self._FormatName(score_name)))
|
||||
html.append('</div>')
|
||||
|
||||
# Actions.
|
||||
html.append('<div class="mdl-dialog__actions">')
|
||||
html.append('<button type="button" class="mdl-button" '
|
||||
'onclick="closeScoreStatsInspector()">'
|
||||
'Close</button>')
|
||||
html.append('</header>')
|
||||
html.append(
|
||||
'<main class="mdl-layout__content" style="overflow-x: auto;">')
|
||||
|
||||
# Tabs content.
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
html.append('<section class="mdl-layout__tab-panel{}" '
|
||||
'id="score-tab-{}">'.format(
|
||||
' is-active' if is_active else '', tab_index))
|
||||
html.append('<div class="page-content">')
|
||||
html.append(
|
||||
self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
|
||||
html.append('</div>')
|
||||
html.append('</section>')
|
||||
|
||||
html.append('</main>')
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</dialog>')
|
||||
# Add snackbar for notifications.
|
||||
html.append(
|
||||
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
|
||||
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
|
||||
'<div class="mdl-snackbar__text"></div>'
|
||||
'<button type="button" class="mdl-snackbar__action"></button>'
|
||||
'</div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialog(
|
||||
self, score_name, apm_config, test_data_gen, test_data_gen_params,
|
||||
anchor_data):
|
||||
"""Builds one score stats inspector dialog."""
|
||||
scores = self._SliceDataForScoreTableCell(
|
||||
score_name, apm_config, test_data_gen, test_data_gen_params)
|
||||
def _BuildScoreTab(self, score_name, anchor_data):
|
||||
"""Builds the content of a tab."""
|
||||
# Find unique values.
|
||||
scores = self._scores_data_frame[
|
||||
self._scores_data_frame.eval_score_name == score_name]
|
||||
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
|
||||
test_data_gen_configs = sorted(
|
||||
self._FindUniqueTuples(scores,
|
||||
['test_data_gen', 'test_data_gen_params']))
|
||||
|
||||
capture_render_pairs = sorted(self._FindUniqueTuples(
|
||||
scores, ['capture', 'render']))
|
||||
echo_simulators = sorted(self._FindUniqueTuples(scores, ['echo_simulator']))
|
||||
html = [
|
||||
'<div class="mdl-grid">',
|
||||
'<div class="mdl-layout-spacer"></div>',
|
||||
'<div class="mdl-cell mdl-cell--10-col">',
|
||||
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
|
||||
'style="width: 100%;">'),
|
||||
]
|
||||
|
||||
html = ['<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">']
|
||||
# Header.
|
||||
html.append('<thead><tr><th>APM config / Test data generator</th>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
html.append('<th>{} {}</th>'.format(
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
|
||||
for echo_simulator in echo_simulators:
|
||||
html.append('<th>' + self._FormatName(echo_simulator[0]) +'</th>')
|
||||
html.append('</tr></thead>')
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for apm_config in apm_configs:
|
||||
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
html.append(
|
||||
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
|
||||
format(
|
||||
dialog_id,
|
||||
self._BuildScoreTableCell(score_name,
|
||||
test_data_gen_info[0],
|
||||
test_data_gen_info[1],
|
||||
apm_config[0])))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for row, (capture, render) in enumerate(capture_render_pairs):
|
||||
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
|
||||
capture, render))
|
||||
for col, echo_simulator in enumerate(echo_simulators):
|
||||
score_tuple = self._SliceDataForScoreStatsTableCell(
|
||||
scores, capture, render, echo_simulator[0])
|
||||
cell_class = 'r{}c{}'.format(row, col)
|
||||
html.append('<td class="single-score-cell {}">{}</td>'.format(
|
||||
cell_class, self._BuildScoreStatsInspectorTableCell(
|
||||
score_tuple, anchor_data + (cell_class,))))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
html.append(
|
||||
'</table></div><div class="mdl-layout-spacer"></div></div>')
|
||||
|
||||
html.append('</table>')
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
|
||||
test_data_gen_configs,
|
||||
anchor_data))
|
||||
|
||||
# Placeholder for the audio inspector.
|
||||
html.append('<div class="audio-inspector-placeholder"></div>')
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
def _BuildScoreTableCell(self, score_name, test_data_gen,
|
||||
test_data_gen_params, apm_config):
|
||||
"""Builds the content of a table cell for a score table."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
stats = self._ComputeScoreStats(scores)
|
||||
|
||||
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
|
||||
"""Builds the content of a cell of a score stats inspector."""
|
||||
anchor = '&'.join(anchor_data)
|
||||
html = [('<div class="v">{}</div>'
|
||||
'<button class="mdl-button mdl-js-button mdl-button--icon"'
|
||||
' data-anchor="{}">'
|
||||
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
|
||||
'</button>').format(score_tuple.score, anchor)]
|
||||
html = []
|
||||
items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
|
||||
apm_config)
|
||||
if stats['count'] == 1:
|
||||
# Show the only available score.
|
||||
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
|
||||
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
|
||||
item_id, scores['score'].mean()))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
|
||||
'</div>'.format(item_id, 'single value'))
|
||||
else:
|
||||
# Show stats.
|
||||
for stat_name in ['min', 'max', 'mean', 'std dev']:
|
||||
item_id = hashlib.md5(
|
||||
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
|
||||
html.append('<div id="stats-{0}">{1:f}</div>'.format(
|
||||
item_id, stats[stat_name]))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
|
||||
'</div>'.format(item_id, stat_name))
|
||||
|
||||
# Add all the available file paths as hidden data.
|
||||
for field_name in score_tuple.keys():
|
||||
if field_name.endswith('_filepath'):
|
||||
html.append('<input type="hidden" name="{}" value="{}">'.format(
|
||||
field_name, score_tuple[field_name]))
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
|
||||
test_data_gen_configs, anchor_data):
|
||||
"""Builds a set of score stats inspector dialogs."""
|
||||
html = []
|
||||
for apm_config in apm_configs:
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
|
||||
def _SliceDataForScoreTableCell(
|
||||
self, score_name, apm_config, test_data_gen, test_data_gen_params):
|
||||
"""Slices |self._scores_data_frame| to extract the data for a tab."""
|
||||
masks = []
|
||||
masks.append(self._scores_data_frame.eval_score_name == score_name)
|
||||
masks.append(self._scores_data_frame.apm_config == apm_config)
|
||||
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
|
||||
masks.append(
|
||||
self._scores_data_frame.test_data_gen_params == test_data_gen_params)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
return self._scores_data_frame[mask]
|
||||
html.append('<dialog class="mdl-dialog" id="{}" '
|
||||
'style="width: 40%;">'.format(dialog_id))
|
||||
|
||||
@classmethod
|
||||
def _SliceDataForScoreStatsTableCell(
|
||||
cls, scores, capture, render, echo_simulator):
|
||||
"""Slices |scores| to extract the data for a tab."""
|
||||
masks = []
|
||||
# Content.
|
||||
html.append('<div class="mdl-dialog__content">')
|
||||
html.append(
|
||||
'<h6><strong>APM config preset</strong>: {}<br/>'
|
||||
'<strong>Test data generator</strong>: {} ({})</h6>'.
|
||||
format(self._FormatName(apm_config[0]),
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialog(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1], anchor_data + (dialog_id, )))
|
||||
html.append('</div>')
|
||||
|
||||
masks.append(scores.capture == capture)
|
||||
masks.append(scores.render == render)
|
||||
masks.append(scores.echo_simulator == echo_simulator)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
# Actions.
|
||||
html.append('<div class="mdl-dialog__actions">')
|
||||
html.append('<button type="button" class="mdl-button" '
|
||||
'onclick="closeScoreStatsInspector()">'
|
||||
'Close</button>')
|
||||
html.append('</div>')
|
||||
|
||||
sliced_data = scores[mask]
|
||||
assert len(sliced_data) == 1, 'single score is expected'
|
||||
return sliced_data.iloc[0]
|
||||
html.append('</dialog>')
|
||||
|
||||
@classmethod
|
||||
def _FindUniqueTuples(cls, data_frame, fields):
|
||||
"""Slices |data_frame| to a list of fields and finds unique tuples."""
|
||||
return data_frame[fields].drop_duplicates().values.tolist()
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
@classmethod
|
||||
def _ComputeScoreStats(cls, data_frame):
|
||||
"""Computes score stats."""
|
||||
scores = data_frame['score']
|
||||
return {
|
||||
'count': scores.count(),
|
||||
'min': scores.min(),
|
||||
'max': scores.max(),
|
||||
'mean': scores.mean(),
|
||||
'std dev': scores.std(),
|
||||
}
|
||||
def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params,
|
||||
anchor_data):
|
||||
"""Builds one score stats inspector dialog."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
|
||||
@classmethod
|
||||
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params):
|
||||
"""Assigns a unique name to a dialog."""
|
||||
return 'score-stats-dialog-' + hashlib.md5(
|
||||
'score-stats-inspector-{}-{}-{}-{}'.format(
|
||||
score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params).encode('utf-8')).hexdigest()
|
||||
capture_render_pairs = sorted(
|
||||
self._FindUniqueTuples(scores, ['capture', 'render']))
|
||||
echo_simulators = sorted(
|
||||
self._FindUniqueTuples(scores, ['echo_simulator']))
|
||||
|
||||
@classmethod
|
||||
def _Save(cls, output_filepath, html):
|
||||
"""Writes the HTML file.
|
||||
html = [
|
||||
'<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
|
||||
for echo_simulator in echo_simulators:
|
||||
html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for row, (capture, render) in enumerate(capture_render_pairs):
|
||||
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
|
||||
capture, render))
|
||||
for col, echo_simulator in enumerate(echo_simulators):
|
||||
score_tuple = self._SliceDataForScoreStatsTableCell(
|
||||
scores, capture, render, echo_simulator[0])
|
||||
cell_class = 'r{}c{}'.format(row, col)
|
||||
html.append('<td class="single-score-cell {}">{}</td>'.format(
|
||||
cell_class,
|
||||
self._BuildScoreStatsInspectorTableCell(
|
||||
score_tuple, anchor_data + (cell_class, ))))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
html.append('</table>')
|
||||
|
||||
# Placeholder for the audio inspector.
|
||||
html.append('<div class="audio-inspector-placeholder"></div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
|
||||
"""Builds the content of a cell of a score stats inspector."""
|
||||
anchor = '&'.join(anchor_data)
|
||||
html = [('<div class="v">{}</div>'
|
||||
'<button class="mdl-button mdl-js-button mdl-button--icon"'
|
||||
' data-anchor="{}">'
|
||||
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
|
||||
'</button>').format(score_tuple.score, anchor)]
|
||||
|
||||
# Add all the available file paths as hidden data.
|
||||
for field_name in score_tuple.keys():
|
||||
if field_name.endswith('_filepath'):
|
||||
html.append(
|
||||
'<input type="hidden" name="{}" value="{}">'.format(
|
||||
field_name, score_tuple[field_name]))
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _SliceDataForScoreTableCell(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Slices |self._scores_data_frame| to extract the data for a tab."""
|
||||
masks = []
|
||||
masks.append(self._scores_data_frame.eval_score_name == score_name)
|
||||
masks.append(self._scores_data_frame.apm_config == apm_config)
|
||||
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
|
||||
masks.append(self._scores_data_frame.test_data_gen_params ==
|
||||
test_data_gen_params)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
return self._scores_data_frame[mask]
|
||||
|
||||
@classmethod
|
||||
def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
|
||||
echo_simulator):
|
||||
"""Slices |scores| to extract the data for a tab."""
|
||||
masks = []
|
||||
|
||||
masks.append(scores.capture == capture)
|
||||
masks.append(scores.render == render)
|
||||
masks.append(scores.echo_simulator == echo_simulator)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
|
||||
sliced_data = scores[mask]
|
||||
assert len(sliced_data) == 1, 'single score is expected'
|
||||
return sliced_data.iloc[0]
|
||||
|
||||
@classmethod
|
||||
def _FindUniqueTuples(cls, data_frame, fields):
|
||||
"""Slices |data_frame| to a list of fields and finds unique tuples."""
|
||||
return data_frame[fields].drop_duplicates().values.tolist()
|
||||
|
||||
@classmethod
|
||||
def _ComputeScoreStats(cls, data_frame):
|
||||
"""Computes score stats."""
|
||||
scores = data_frame['score']
|
||||
return {
|
||||
'count': scores.count(),
|
||||
'min': scores.min(),
|
||||
'max': scores.max(),
|
||||
'mean': scores.mean(),
|
||||
'std dev': scores.std(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Assigns a unique name to a dialog."""
|
||||
return 'score-stats-dialog-' + hashlib.md5(
|
||||
'score-stats-inspector-{}-{}-{}-{}'.format(
|
||||
score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params).encode('utf-8')).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _Save(cls, output_filepath, html):
|
||||
"""Writes the HTML file.
|
||||
|
||||
Args:
|
||||
output_filepath: output file path.
|
||||
html: string with the HTML content.
|
||||
"""
|
||||
with open(output_filepath, 'w') as f:
|
||||
f.write(html)
|
||||
with open(output_filepath, 'w') as f:
|
||||
f.write(html)
|
||||
|
||||
@classmethod
|
||||
def _FormatName(cls, name):
|
||||
"""Formats a name.
|
||||
@classmethod
|
||||
def _FormatName(cls, name):
|
||||
"""Formats a name.
|
||||
|
||||
Args:
|
||||
name: a string.
|
||||
@ -399,4 +423,4 @@ class HtmlExport(object):
|
||||
Returns:
|
||||
A copy of name in which underscores and dashes are replaced with a space.
|
||||
"""
|
||||
return re.sub(r'[_\-]', ' ', name)
|
||||
return re.sub(r'[_\-]', ' ', name)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
@ -27,60 +26,61 @@ from . import test_data_generation_factory
|
||||
|
||||
|
||||
class TestExport(unittest.TestCase):
|
||||
"""Unit tests for the export module.
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data to export."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data to export."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Run a fake experiment to produce data to export.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=[
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
|
||||
],
|
||||
test_data_generator_names=['identity', 'white_noise'],
|
||||
eval_score_names=['audio_level_peak', 'audio_level_mean'],
|
||||
output_dir=self._tmp_path)
|
||||
# Run a fake experiment to produce data to export.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=[
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
|
||||
],
|
||||
test_data_generator_names=['identity', 'white_noise'],
|
||||
eval_score_names=['audio_level_peak', 'audio_level_mean'],
|
||||
output_dir=self._tmp_path)
|
||||
|
||||
# Export results.
|
||||
p = collect_data.InstanceArgumentsParser()
|
||||
args = p.parse_args(['--output_dir', self._tmp_path])
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
self._data_to_export = collect_data.FindScores(src_path, args)
|
||||
# Export results.
|
||||
p = collect_data.InstanceArgumentsParser()
|
||||
args = p.parse_args(['--output_dir', self._tmp_path])
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
self._data_to_export = collect_data.FindScores(src_path, args)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' + (
|
||||
self._tmp_path))
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def testCreateHtmlReport(self):
|
||||
fn_out = os.path.join(self._tmp_path, 'results.html')
|
||||
exporter = export.HtmlExport(fn_out)
|
||||
exporter.Export(self._data_to_export)
|
||||
def testCreateHtmlReport(self):
|
||||
fn_out = os.path.join(self._tmp_path, 'results.html')
|
||||
exporter = export.HtmlExport(fn_out)
|
||||
exporter.Export(self._data_to_export)
|
||||
|
||||
document = pq.PyQuery(filename=fn_out)
|
||||
self.assertIsInstance(document, pq.PyQuery)
|
||||
# TODO(alessiob): Use PyQuery API to check the HTML file.
|
||||
document = pq.PyQuery(filename=fn_out)
|
||||
self.assertIsInstance(document, pq.PyQuery)
|
||||
# TODO(alessiob): Use PyQuery API to check the HTML file.
|
||||
|
||||
@ -16,62 +16,60 @@ import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import signal_processing
|
||||
|
||||
class ExternalVad(object):
|
||||
|
||||
def __init__(self, path_to_binary, name):
|
||||
"""Args:
|
||||
class ExternalVad(object):
|
||||
def __init__(self, path_to_binary, name):
|
||||
"""Args:
|
||||
path_to_binary: path to binary that accepts '-i <wav>', '-o
|
||||
<float probabilities>'. There must be one float value per
|
||||
10ms audio
|
||||
name: a name to identify the external VAD. Used for saving
|
||||
the output as extvad_output-<name>.
|
||||
"""
|
||||
self._path_to_binary = path_to_binary
|
||||
self.name = name
|
||||
assert os.path.exists(self._path_to_binary), (
|
||||
self._path_to_binary)
|
||||
self._vad_output = None
|
||||
self._path_to_binary = path_to_binary
|
||||
self.name = name
|
||||
assert os.path.exists(self._path_to_binary), (self._path_to_binary)
|
||||
self._vad_output = None
|
||||
|
||||
def Run(self, wav_file_path):
|
||||
_signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
|
||||
if _signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel'
|
||||
' annotations not implemented')
|
||||
if _signal.frame_rate != 48000:
|
||||
raise NotImplementedError('Frame rates '
|
||||
'other than 48000 not implemented')
|
||||
def Run(self, wav_file_path):
|
||||
_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
wav_file_path)
|
||||
if _signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel'
|
||||
' annotations not implemented')
|
||||
if _signal.frame_rate != 48000:
|
||||
raise NotImplementedError('Frame rates '
|
||||
'other than 48000 not implemented')
|
||||
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
try:
|
||||
output_file_path = os.path.join(
|
||||
tmp_path, self.name + '_vad.tmp')
|
||||
subprocess.call([
|
||||
self._path_to_binary,
|
||||
'-i', wav_file_path,
|
||||
'-o', output_file_path
|
||||
])
|
||||
self._vad_output = np.fromfile(output_file_path, np.float32)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the ' + self.name +
|
||||
' VAD (' + e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
try:
|
||||
output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
|
||||
subprocess.call([
|
||||
self._path_to_binary, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
])
|
||||
self._vad_output = np.fromfile(output_file_path, np.float32)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the ' + self.name + ' VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
def GetVadOutput(self):
|
||||
assert self._vad_output is not None
|
||||
return self._vad_output
|
||||
def GetVadOutput(self):
|
||||
assert self._vad_output is not None
|
||||
return self._vad_output
|
||||
|
||||
@classmethod
|
||||
def ConstructVadDict(cls, vad_paths, vad_names):
|
||||
external_vads = {}
|
||||
for path, name in zip(vad_paths, vad_names):
|
||||
external_vads[name] = ExternalVad(path, name)
|
||||
return external_vads
|
||||
@classmethod
|
||||
def ConstructVadDict(cls, vad_paths, vad_names):
|
||||
external_vads = {}
|
||||
for path, name in zip(vad_paths, vad_names):
|
||||
external_vads[name] = ExternalVad(path, name)
|
||||
return external_vads
|
||||
|
||||
@ -9,16 +9,17 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', required=True)
|
||||
parser.add_argument('-o', required=True)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', required=True)
|
||||
parser.add_argument('-o', required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
array = np.arange(100, dtype=np.float32)
|
||||
array.tofile(open(args.o, 'w'))
|
||||
array = np.arange(100, dtype=np.float32)
|
||||
array.tofile(open(args.o, 'w'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Input mixer module.
|
||||
"""
|
||||
|
||||
@ -17,24 +16,24 @@ from . import signal_processing
|
||||
|
||||
|
||||
class ApmInputMixer(object):
|
||||
"""Class to mix a set of audio segments down to the APM input."""
|
||||
"""Class to mix a set of audio segments down to the APM input."""
|
||||
|
||||
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
|
||||
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def HardClippingLogMessage(cls):
|
||||
"""Returns the log message used when hard clipping is detected in the mix.
|
||||
@classmethod
|
||||
def HardClippingLogMessage(cls):
|
||||
"""Returns the log message used when hard clipping is detected in the mix.
|
||||
|
||||
This method is mainly intended to be used by the unit tests.
|
||||
"""
|
||||
return cls._HARD_CLIPPING_LOG_MSG
|
||||
return cls._HARD_CLIPPING_LOG_MSG
|
||||
|
||||
@classmethod
|
||||
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
|
||||
"""Mixes capture and echo.
|
||||
@classmethod
|
||||
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
|
||||
"""Mixes capture and echo.
|
||||
|
||||
Creates the overall capture input for APM by mixing the "echo-free" capture
|
||||
signal with the echo signal (e.g., echo simulated via the
|
||||
@ -58,38 +57,41 @@ class ApmInputMixer(object):
|
||||
Returns:
|
||||
Path to the mix audio track file.
|
||||
"""
|
||||
if echo_filepath is None:
|
||||
return capture_input_filepath
|
||||
if echo_filepath is None:
|
||||
return capture_input_filepath
|
||||
|
||||
# Build the mix output file name as a function of the echo file name.
|
||||
# This ensures that if the internal parameters of the echo path simulator
|
||||
# change, no erroneous cache hit occurs.
|
||||
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
|
||||
capture_input_file_name, _ = os.path.splitext(
|
||||
os.path.split(capture_input_filepath)[1])
|
||||
mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
|
||||
capture_input_file_name, echo_file_name))
|
||||
# Build the mix output file name as a function of the echo file name.
|
||||
# This ensures that if the internal parameters of the echo path simulator
|
||||
# change, no erroneous cache hit occurs.
|
||||
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
|
||||
capture_input_file_name, _ = os.path.splitext(
|
||||
os.path.split(capture_input_filepath)[1])
|
||||
mix_filepath = os.path.join(
|
||||
output_path,
|
||||
'mix_capture_{}_{}.wav'.format(capture_input_file_name,
|
||||
echo_file_name))
|
||||
|
||||
# Create the mix if not done yet.
|
||||
mix = None
|
||||
if not os.path.exists(mix_filepath):
|
||||
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
capture_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
|
||||
# Create the mix if not done yet.
|
||||
mix = None
|
||||
if not os.path.exists(mix_filepath):
|
||||
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
capture_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
|
||||
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
echo_free_capture)):
|
||||
raise exceptions.InputMixerException(
|
||||
'echo cannot be shorter than capture')
|
||||
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
echo_free_capture)):
|
||||
raise exceptions.InputMixerException(
|
||||
'echo cannot be shorter than capture')
|
||||
|
||||
mix = echo_free_capture.overlay(echo)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
|
||||
mix = echo_free_capture.overlay(echo)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
|
||||
|
||||
# Check if hard clipping occurs.
|
||||
if mix is None:
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
|
||||
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
|
||||
# Check if hard clipping occurs.
|
||||
if mix is None:
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
|
||||
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
|
||||
|
||||
return mix_filepath
|
||||
return mix_filepath
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the input mixer module.
|
||||
"""
|
||||
|
||||
@ -23,122 +22,119 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestApmInputMixer(unittest.TestCase):
|
||||
"""Unit tests for the ApmInputMixer class.
|
||||
"""Unit tests for the ApmInputMixer class.
|
||||
"""
|
||||
|
||||
# Audio track file names created in setUp().
|
||||
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
|
||||
# Audio track file names created in setUp().
|
||||
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
|
||||
|
||||
# Target peak power level (dBFS) of each audio track file created in setUp().
|
||||
# These values are hand-crafted in order to make saturation happen when
|
||||
# capture and echo_2 are mixed and the contrary for capture and echo_1.
|
||||
# None means that the power is not changed.
|
||||
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
|
||||
# Target peak power level (dBFS) of each audio track file created in setUp().
|
||||
# These values are hand-crafted in order to make saturation happen when
|
||||
# capture and echo_2 are mixed and the contrary for capture and echo_1.
|
||||
# None means that the power is not changed.
|
||||
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
|
||||
|
||||
# Audio track file durations in milliseconds.
|
||||
_DURATIONS = [1000, 1000, 1000, 800, 1200]
|
||||
# Audio track file durations in milliseconds.
|
||||
_DURATIONS = [1000, 1000, 1000, 800, 1200]
|
||||
|
||||
_SAMPLE_RATE = 48000
|
||||
_SAMPLE_RATE = 48000
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create audio track files.
|
||||
self._audio_tracks = {}
|
||||
for filename, peak_power, duration in zip(
|
||||
self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
|
||||
audio_track_filepath = os.path.join(self._tmp_path, '{}.wav'.format(
|
||||
filename))
|
||||
# Create audio track files.
|
||||
self._audio_tracks = {}
|
||||
for filename, peak_power, duration in zip(self._FILENAMES,
|
||||
self._MAX_PEAK_POWER_LEVELS,
|
||||
self._DURATIONS):
|
||||
audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'{}.wav'.format(filename))
|
||||
|
||||
# Create a pure tone with the target peak power level.
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration=duration, sample_rate=self._SAMPLE_RATE)
|
||||
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template)
|
||||
if peak_power is not None:
|
||||
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
|
||||
# Create a pure tone with the target peak power level.
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration=duration, sample_rate=self._SAMPLE_RATE)
|
||||
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template)
|
||||
if peak_power is not None:
|
||||
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
audio_track_filepath, signal)
|
||||
self._audio_tracks[filename] = {
|
||||
'filepath': audio_track_filepath,
|
||||
'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
|
||||
signal)
|
||||
}
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
audio_track_filepath, signal)
|
||||
self._audio_tracks[filename] = {
|
||||
'filepath':
|
||||
audio_track_filepath,
|
||||
'num_samples':
|
||||
signal_processing.SignalProcessingUtils.CountSamples(signal)
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testCheckMixSameDuration(self):
|
||||
"""Checks the duration when mixing capture and echo with same duration."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
def testCheckMixSameDuration(self):
|
||||
"""Checks the duration when mixing capture and echo with same duration."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testRejectShorterEcho(self):
|
||||
"""Rejects echo signals that are shorter than the capture signal."""
|
||||
try:
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['shorter']['filepath'])
|
||||
self.fail('no exception raised')
|
||||
except exceptions.InputMixerException:
|
||||
pass
|
||||
def testRejectShorterEcho(self):
|
||||
"""Rejects echo signals that are shorter than the capture signal."""
|
||||
try:
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['shorter']['filepath'])
|
||||
self.fail('no exception raised')
|
||||
except exceptions.InputMixerException:
|
||||
pass
|
||||
|
||||
def testCheckMixDurationWithLongerEcho(self):
|
||||
"""Checks the duration when mixing an echo longer than the capture."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['longer']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
def testCheckMixDurationWithLongerEcho(self):
|
||||
"""Checks the duration when mixing an echo longer than the capture."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['longer']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testCheckOutputFileNamesConflict(self):
|
||||
"""Checks that different echo files lead to different output file names."""
|
||||
mix1_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix1_filepath))
|
||||
def testCheckOutputFileNamesConflict(self):
|
||||
"""Checks that different echo files lead to different output file names."""
|
||||
mix1_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix1_filepath))
|
||||
|
||||
mix2_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix2_filepath))
|
||||
mix2_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix2_filepath))
|
||||
|
||||
self.assertNotEqual(mix1_filepath, mix2_filepath)
|
||||
self.assertNotEqual(mix1_filepath, mix2_filepath)
|
||||
|
||||
def testHardClippingLogExpected(self):
|
||||
"""Checks that hard clipping warning is raised when occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
logging.warning.assert_called_once_with(
|
||||
input_mixer.ApmInputMixer.HardClippingLogMessage())
|
||||
def testHardClippingLogExpected(self):
|
||||
"""Checks that hard clipping warning is raised when occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
logging.warning.assert_called_once_with(
|
||||
input_mixer.ApmInputMixer.HardClippingLogMessage())
|
||||
|
||||
def testHardClippingLogNotExpected(self):
|
||||
"""Checks that hard clipping warning is not raised when not occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path,
|
||||
self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertNotIn(
|
||||
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
|
||||
logging.warning.call_args_list)
|
||||
def testHardClippingLogNotExpected(self):
|
||||
"""Checks that hard clipping warning is not raised when not occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertNotIn(
|
||||
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
|
||||
logging.warning.call_args_list)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Input signal creator module.
|
||||
"""
|
||||
|
||||
@ -14,12 +13,12 @@ from . import signal_processing
|
||||
|
||||
|
||||
class InputSignalCreator(object):
|
||||
"""Input signal creator class.
|
||||
"""Input signal creator class.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def Create(cls, name, raw_params):
|
||||
"""Creates a input signal and its metadata.
|
||||
@classmethod
|
||||
def Create(cls, name, raw_params):
|
||||
"""Creates a input signal and its metadata.
|
||||
|
||||
Args:
|
||||
name: Input signal creator name.
|
||||
@ -28,29 +27,30 @@ class InputSignalCreator(object):
|
||||
Returns:
|
||||
(AudioSegment, dict) tuple.
|
||||
"""
|
||||
try:
|
||||
signal = {}
|
||||
params = {}
|
||||
try:
|
||||
signal = {}
|
||||
params = {}
|
||||
|
||||
if name == 'pure_tone':
|
||||
params['frequency'] = float(raw_params[0])
|
||||
params['duration'] = int(raw_params[1])
|
||||
signal = cls._CreatePureTone(params['frequency'], params['duration'])
|
||||
else:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid input signal creator name')
|
||||
if name == 'pure_tone':
|
||||
params['frequency'] = float(raw_params[0])
|
||||
params['duration'] = int(raw_params[1])
|
||||
signal = cls._CreatePureTone(params['frequency'],
|
||||
params['duration'])
|
||||
else:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid input signal creator name')
|
||||
|
||||
# Complete metadata.
|
||||
params['signal'] = name
|
||||
# Complete metadata.
|
||||
params['signal'] = name
|
||||
|
||||
return signal, params
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid signal creator parameters: {}'.format(e))
|
||||
return signal, params
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid signal creator parameters: {}'.format(e))
|
||||
|
||||
@classmethod
|
||||
def _CreatePureTone(cls, frequency, duration):
|
||||
"""
|
||||
@classmethod
|
||||
def _CreatePureTone(cls, frequency, duration):
|
||||
"""
|
||||
Generates a pure tone at 48000 Hz.
|
||||
|
||||
Args:
|
||||
@ -60,8 +60,9 @@ class InputSignalCreator(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
assert 0 < frequency <= 24000
|
||||
assert duration > 0
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(duration)
|
||||
return signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, frequency)
|
||||
assert 0 < frequency <= 24000
|
||||
assert duration > 0
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration)
|
||||
return signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, frequency)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Signal processing utility module.
|
||||
"""
|
||||
|
||||
@ -16,44 +15,44 @@ import sys
|
||||
import enum
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import pydub
|
||||
import pydub.generators
|
||||
import pydub
|
||||
import pydub.generators
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package pydub')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package pydub')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import scipy.signal
|
||||
import scipy.fftpack
|
||||
import scipy.signal
|
||||
import scipy.fftpack
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import exceptions
|
||||
|
||||
|
||||
class SignalProcessingUtils(object):
|
||||
"""Collection of signal processing utilities.
|
||||
"""Collection of signal processing utilities.
|
||||
"""
|
||||
|
||||
@enum.unique
|
||||
class MixPadding(enum.Enum):
|
||||
NO_PADDING = 0
|
||||
ZERO_PADDING = 1
|
||||
LOOP = 2
|
||||
@enum.unique
|
||||
class MixPadding(enum.Enum):
|
||||
NO_PADDING = 0
|
||||
ZERO_PADDING = 1
|
||||
LOOP = 2
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def LoadWav(cls, filepath, channels=1):
|
||||
"""Loads wav file.
|
||||
@classmethod
|
||||
def LoadWav(cls, filepath, channels=1):
|
||||
"""Loads wav file.
|
||||
|
||||
Args:
|
||||
filepath: path to the wav audio track file to load.
|
||||
@ -62,25 +61,26 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
if not os.path.exists(filepath):
|
||||
logging.error('cannot find the <%s> audio track file', filepath)
|
||||
raise exceptions.FileNotFoundError()
|
||||
return pydub.AudioSegment.from_file(
|
||||
filepath, format='wav', channels=channels)
|
||||
if not os.path.exists(filepath):
|
||||
logging.error('cannot find the <%s> audio track file', filepath)
|
||||
raise exceptions.FileNotFoundError()
|
||||
return pydub.AudioSegment.from_file(filepath,
|
||||
format='wav',
|
||||
channels=channels)
|
||||
|
||||
@classmethod
|
||||
def SaveWav(cls, output_filepath, signal):
|
||||
"""Saves wav file.
|
||||
@classmethod
|
||||
def SaveWav(cls, output_filepath, signal):
|
||||
"""Saves wav file.
|
||||
|
||||
Args:
|
||||
output_filepath: path to the wav audio track file to save.
|
||||
signal: AudioSegment instance.
|
||||
"""
|
||||
return signal.export(output_filepath, format='wav')
|
||||
return signal.export(output_filepath, format='wav')
|
||||
|
||||
@classmethod
|
||||
def CountSamples(cls, signal):
|
||||
"""Number of samples per channel.
|
||||
@classmethod
|
||||
def CountSamples(cls, signal):
|
||||
"""Number of samples per channel.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -88,14 +88,14 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An integer.
|
||||
"""
|
||||
number_of_samples = len(signal.get_array_of_samples())
|
||||
assert signal.channels > 0
|
||||
assert number_of_samples % signal.channels == 0
|
||||
return number_of_samples / signal.channels
|
||||
number_of_samples = len(signal.get_array_of_samples())
|
||||
assert signal.channels > 0
|
||||
assert number_of_samples % signal.channels == 0
|
||||
return number_of_samples / signal.channels
|
||||
|
||||
@classmethod
|
||||
def GenerateSilence(cls, duration=1000, sample_rate=48000):
|
||||
"""Generates silence.
|
||||
@classmethod
|
||||
def GenerateSilence(cls, duration=1000, sample_rate=48000):
|
||||
"""Generates silence.
|
||||
|
||||
This method can also be used to create a template AudioSegment instance.
|
||||
A template can then be used with other Generate*() methods accepting an
|
||||
@ -108,11 +108,11 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
return pydub.AudioSegment.silent(duration, sample_rate)
|
||||
return pydub.AudioSegment.silent(duration, sample_rate)
|
||||
|
||||
@classmethod
|
||||
def GeneratePureTone(cls, template, frequency=440.0):
|
||||
"""Generates a pure tone.
|
||||
@classmethod
|
||||
def GeneratePureTone(cls, template, frequency=440.0):
|
||||
"""Generates a pure tone.
|
||||
|
||||
The pure tone is generated with the same duration and in the same format of
|
||||
the given template signal.
|
||||
@ -124,21 +124,18 @@ class SignalProcessingUtils(object):
|
||||
Return:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
if frequency > template.frame_rate >> 1:
|
||||
raise exceptions.SignalProcessingException('Invalid frequency')
|
||||
if frequency > template.frame_rate >> 1:
|
||||
raise exceptions.SignalProcessingException('Invalid frequency')
|
||||
|
||||
generator = pydub.generators.Sine(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8,
|
||||
freq=frequency)
|
||||
generator = pydub.generators.Sine(sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8,
|
||||
freq=frequency)
|
||||
|
||||
return generator.to_audio_segment(
|
||||
duration=len(template),
|
||||
volume=0.0)
|
||||
return generator.to_audio_segment(duration=len(template), volume=0.0)
|
||||
|
||||
@classmethod
|
||||
def GenerateWhiteNoise(cls, template):
|
||||
"""Generates white noise.
|
||||
@classmethod
|
||||
def GenerateWhiteNoise(cls, template):
|
||||
"""Generates white noise.
|
||||
|
||||
The white noise is generated with the same duration and in the same format
|
||||
of the given template signal.
|
||||
@ -149,33 +146,32 @@ class SignalProcessingUtils(object):
|
||||
Return:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
generator = pydub.generators.WhiteNoise(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8)
|
||||
return generator.to_audio_segment(
|
||||
duration=len(template),
|
||||
volume=0.0)
|
||||
generator = pydub.generators.WhiteNoise(
|
||||
sample_rate=template.frame_rate,
|
||||
bit_depth=template.sample_width * 8)
|
||||
return generator.to_audio_segment(duration=len(template), volume=0.0)
|
||||
|
||||
@classmethod
|
||||
def AudioSegmentToRawData(cls, signal):
|
||||
samples = signal.get_array_of_samples()
|
||||
if samples.typecode != 'h':
|
||||
raise exceptions.SignalProcessingException('Unsupported samples type')
|
||||
return np.array(signal.get_array_of_samples(), np.int16)
|
||||
@classmethod
|
||||
def AudioSegmentToRawData(cls, signal):
|
||||
samples = signal.get_array_of_samples()
|
||||
if samples.typecode != 'h':
|
||||
raise exceptions.SignalProcessingException(
|
||||
'Unsupported samples type')
|
||||
return np.array(signal.get_array_of_samples(), np.int16)
|
||||
|
||||
@classmethod
|
||||
def Fft(cls, signal, normalize=True):
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel FFT not implemented')
|
||||
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
|
||||
if normalize:
|
||||
x /= max(abs(np.max(x)), 1.0)
|
||||
y = scipy.fftpack.fft(x)
|
||||
return y[:len(y) / 2]
|
||||
@classmethod
|
||||
def Fft(cls, signal, normalize=True):
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel FFT not implemented')
|
||||
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
|
||||
if normalize:
|
||||
x /= max(abs(np.max(x)), 1.0)
|
||||
y = scipy.fftpack.fft(x)
|
||||
return y[:len(y) / 2]
|
||||
|
||||
@classmethod
|
||||
def DetectHardClipping(cls, signal, threshold=2):
|
||||
"""Detects hard clipping.
|
||||
@classmethod
|
||||
def DetectHardClipping(cls, signal, threshold=2):
|
||||
"""Detects hard clipping.
|
||||
|
||||
Hard clipping is simply detected by counting samples that touch either the
|
||||
lower or upper bound too many times in a row (according to |threshold|).
|
||||
@ -189,32 +185,33 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
True if hard clipping is detect, False otherwise.
|
||||
"""
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError('multiple-channel clipping not implemented')
|
||||
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
||||
raise exceptions.SignalProcessingException(
|
||||
'hard-clipping detection only supported for 16 bit samples')
|
||||
samples = cls.AudioSegmentToRawData(signal)
|
||||
if signal.channels != 1:
|
||||
raise NotImplementedError(
|
||||
'multiple-channel clipping not implemented')
|
||||
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
|
||||
raise exceptions.SignalProcessingException(
|
||||
'hard-clipping detection only supported for 16 bit samples')
|
||||
samples = cls.AudioSegmentToRawData(signal)
|
||||
|
||||
# Detect adjacent clipped samples.
|
||||
samples_type_info = np.iinfo(samples.dtype)
|
||||
mask_min = samples == samples_type_info.min
|
||||
mask_max = samples == samples_type_info.max
|
||||
# Detect adjacent clipped samples.
|
||||
samples_type_info = np.iinfo(samples.dtype)
|
||||
mask_min = samples == samples_type_info.min
|
||||
mask_max = samples == samples_type_info.max
|
||||
|
||||
def HasLongSequence(vector, min_legth=threshold):
|
||||
"""Returns True if there are one or more long sequences of True flags."""
|
||||
seq_length = 0
|
||||
for b in vector:
|
||||
seq_length = seq_length + 1 if b else 0
|
||||
if seq_length >= min_legth:
|
||||
return True
|
||||
return False
|
||||
def HasLongSequence(vector, min_legth=threshold):
|
||||
"""Returns True if there are one or more long sequences of True flags."""
|
||||
seq_length = 0
|
||||
for b in vector:
|
||||
seq_length = seq_length + 1 if b else 0
|
||||
if seq_length >= min_legth:
|
||||
return True
|
||||
return False
|
||||
|
||||
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
|
||||
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
|
||||
|
||||
@classmethod
|
||||
def ApplyImpulseResponse(cls, signal, impulse_response):
|
||||
"""Applies an impulse response to a signal.
|
||||
@classmethod
|
||||
def ApplyImpulseResponse(cls, signal, impulse_response):
|
||||
"""Applies an impulse response to a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -223,44 +220,48 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
# Get samples.
|
||||
assert signal.channels == 1, (
|
||||
'multiple-channel recordings not supported')
|
||||
samples = signal.get_array_of_samples()
|
||||
# Get samples.
|
||||
assert signal.channels == 1, (
|
||||
'multiple-channel recordings not supported')
|
||||
samples = signal.get_array_of_samples()
|
||||
|
||||
# Convolve.
|
||||
logging.info('applying %d order impulse response to a signal lasting %d ms',
|
||||
len(impulse_response), len(signal))
|
||||
convolved_samples = scipy.signal.fftconvolve(
|
||||
in1=samples,
|
||||
in2=impulse_response,
|
||||
mode='full').astype(np.int16)
|
||||
logging.info('convolution computed')
|
||||
# Convolve.
|
||||
logging.info(
|
||||
'applying %d order impulse response to a signal lasting %d ms',
|
||||
len(impulse_response), len(signal))
|
||||
convolved_samples = scipy.signal.fftconvolve(in1=samples,
|
||||
in2=impulse_response,
|
||||
mode='full').astype(
|
||||
np.int16)
|
||||
logging.info('convolution computed')
|
||||
|
||||
# Cast.
|
||||
convolved_samples = array.array(signal.array_type, convolved_samples)
|
||||
# Cast.
|
||||
convolved_samples = array.array(signal.array_type, convolved_samples)
|
||||
|
||||
# Verify.
|
||||
logging.debug('signal length: %d samples', len(samples))
|
||||
logging.debug('convolved signal length: %d samples', len(convolved_samples))
|
||||
assert len(convolved_samples) > len(samples)
|
||||
# Verify.
|
||||
logging.debug('signal length: %d samples', len(samples))
|
||||
logging.debug('convolved signal length: %d samples',
|
||||
len(convolved_samples))
|
||||
assert len(convolved_samples) > len(samples)
|
||||
|
||||
# Generate convolved signal AudioSegment instance.
|
||||
convolved_signal = pydub.AudioSegment(
|
||||
data=convolved_samples,
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
assert len(convolved_signal) > len(signal)
|
||||
# Generate convolved signal AudioSegment instance.
|
||||
convolved_signal = pydub.AudioSegment(data=convolved_samples,
|
||||
metadata={
|
||||
'sample_width':
|
||||
signal.sample_width,
|
||||
'frame_rate':
|
||||
signal.frame_rate,
|
||||
'frame_width':
|
||||
signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
assert len(convolved_signal) > len(signal)
|
||||
|
||||
return convolved_signal
|
||||
return convolved_signal
|
||||
|
||||
@classmethod
|
||||
def Normalize(cls, signal):
|
||||
"""Normalizes a signal.
|
||||
@classmethod
|
||||
def Normalize(cls, signal):
|
||||
"""Normalizes a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -268,11 +269,11 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An AudioSegment instance.
|
||||
"""
|
||||
return signal.apply_gain(-signal.max_dBFS)
|
||||
return signal.apply_gain(-signal.max_dBFS)
|
||||
|
||||
@classmethod
|
||||
def Copy(cls, signal):
|
||||
"""Makes a copy os a signal.
|
||||
@classmethod
|
||||
def Copy(cls, signal):
|
||||
"""Makes a copy os a signal.
|
||||
|
||||
Args:
|
||||
signal: AudioSegment instance.
|
||||
@ -280,19 +281,21 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An AudioSegment instance.
|
||||
"""
|
||||
return pydub.AudioSegment(
|
||||
data=signal.get_array_of_samples(),
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
return pydub.AudioSegment(data=signal.get_array_of_samples(),
|
||||
metadata={
|
||||
'sample_width': signal.sample_width,
|
||||
'frame_rate': signal.frame_rate,
|
||||
'frame_width': signal.frame_width,
|
||||
'channels': signal.channels,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def MixSignals(cls, signal, noise, target_snr=0.0,
|
||||
pad_noise=MixPadding.NO_PADDING):
|
||||
"""Mixes |signal| and |noise| with a target SNR.
|
||||
@classmethod
|
||||
def MixSignals(cls,
|
||||
signal,
|
||||
noise,
|
||||
target_snr=0.0,
|
||||
pad_noise=MixPadding.NO_PADDING):
|
||||
"""Mixes |signal| and |noise| with a target SNR.
|
||||
|
||||
Mix |signal| and |noise| with a desired SNR by scaling |noise|.
|
||||
If the target SNR is +/- infinite, a copy of signal/noise is returned.
|
||||
@ -312,45 +315,45 @@ class SignalProcessingUtils(object):
|
||||
Returns:
|
||||
An AudioSegment instance.
|
||||
"""
|
||||
# Handle infinite target SNR.
|
||||
if target_snr == -np.Inf:
|
||||
# Return a copy of noise.
|
||||
logging.warning('SNR = -Inf, returning noise')
|
||||
return cls.Copy(noise)
|
||||
elif target_snr == np.Inf:
|
||||
# Return a copy of signal.
|
||||
logging.warning('SNR = +Inf, returning signal')
|
||||
return cls.Copy(signal)
|
||||
# Handle infinite target SNR.
|
||||
if target_snr == -np.Inf:
|
||||
# Return a copy of noise.
|
||||
logging.warning('SNR = -Inf, returning noise')
|
||||
return cls.Copy(noise)
|
||||
elif target_snr == np.Inf:
|
||||
# Return a copy of signal.
|
||||
logging.warning('SNR = +Inf, returning signal')
|
||||
return cls.Copy(signal)
|
||||
|
||||
# Check signal and noise power.
|
||||
signal_power = float(signal.dBFS)
|
||||
noise_power = float(noise.dBFS)
|
||||
if signal_power == -np.Inf:
|
||||
logging.error('signal has -Inf power, cannot mix')
|
||||
raise exceptions.SignalProcessingException(
|
||||
'cannot mix a signal with -Inf power')
|
||||
if noise_power == -np.Inf:
|
||||
logging.error('noise has -Inf power, cannot mix')
|
||||
raise exceptions.SignalProcessingException(
|
||||
'cannot mix a signal with -Inf power')
|
||||
# Check signal and noise power.
|
||||
signal_power = float(signal.dBFS)
|
||||
noise_power = float(noise.dBFS)
|
||||
if signal_power == -np.Inf:
|
||||
logging.error('signal has -Inf power, cannot mix')
|
||||
raise exceptions.SignalProcessingException(
|
||||
'cannot mix a signal with -Inf power')
|
||||
if noise_power == -np.Inf:
|
||||
logging.error('noise has -Inf power, cannot mix')
|
||||
raise exceptions.SignalProcessingException(
|
||||
'cannot mix a signal with -Inf power')
|
||||
|
||||
# Mix.
|
||||
gain_db = signal_power - noise_power - target_snr
|
||||
signal_duration = len(signal)
|
||||
noise_duration = len(noise)
|
||||
if signal_duration <= noise_duration:
|
||||
# Ignore |pad_noise|, |noise| is truncated if longer that |signal|, the
|
||||
# mix will have the same length of |signal|.
|
||||
return signal.overlay(noise.apply_gain(gain_db))
|
||||
elif pad_noise == cls.MixPadding.NO_PADDING:
|
||||
# |signal| is longer than |noise|, but no padding is applied to |noise|.
|
||||
# Truncate |signal|.
|
||||
return noise.overlay(signal, gain_during_overlay=gain_db)
|
||||
elif pad_noise == cls.MixPadding.ZERO_PADDING:
|
||||
# TODO(alessiob): Check that this works as expected.
|
||||
return signal.overlay(noise.apply_gain(gain_db))
|
||||
elif pad_noise == cls.MixPadding.LOOP:
|
||||
# |signal| is longer than |noise|, extend |noise| by looping.
|
||||
return signal.overlay(noise.apply_gain(gain_db), loop=True)
|
||||
else:
|
||||
raise exceptions.SignalProcessingException('invalid padding type')
|
||||
# Mix.
|
||||
gain_db = signal_power - noise_power - target_snr
|
||||
signal_duration = len(signal)
|
||||
noise_duration = len(noise)
|
||||
if signal_duration <= noise_duration:
|
||||
# Ignore |pad_noise|, |noise| is truncated if longer that |signal|, the
|
||||
# mix will have the same length of |signal|.
|
||||
return signal.overlay(noise.apply_gain(gain_db))
|
||||
elif pad_noise == cls.MixPadding.NO_PADDING:
|
||||
# |signal| is longer than |noise|, but no padding is applied to |noise|.
|
||||
# Truncate |signal|.
|
||||
return noise.overlay(signal, gain_during_overlay=gain_db)
|
||||
elif pad_noise == cls.MixPadding.ZERO_PADDING:
|
||||
# TODO(alessiob): Check that this works as expected.
|
||||
return signal.overlay(noise.apply_gain(gain_db))
|
||||
elif pad_noise == cls.MixPadding.LOOP:
|
||||
# |signal| is longer than |noise|, extend |noise| by looping.
|
||||
return signal.overlay(noise.apply_gain(gain_db), loop=True)
|
||||
else:
|
||||
raise exceptions.SignalProcessingException('invalid padding type')
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the signal_processing module.
|
||||
"""
|
||||
|
||||
@ -19,168 +18,166 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestSignalProcessing(unittest.TestCase):
|
||||
"""Unit tests for the signal_processing module.
|
||||
"""Unit tests for the signal_processing module.
|
||||
"""
|
||||
|
||||
def testMixSignals(self):
|
||||
# Generate a template signal with which white noise can be generated.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
def testMixSignals(self):
|
||||
# Generate a template signal with which white noise can be generated.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
|
||||
# Generate two distinct AudioSegment instances with 1 second of white noise.
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
# Generate two distinct AudioSegment instances with 1 second of white noise.
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
|
||||
# Extract samples.
|
||||
signal_samples = signal.get_array_of_samples()
|
||||
noise_samples = noise.get_array_of_samples()
|
||||
# Extract samples.
|
||||
signal_samples = signal.get_array_of_samples()
|
||||
noise_samples = noise.get_array_of_samples()
|
||||
|
||||
# Test target SNR -Inf (noise expected).
|
||||
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, -np.Inf)
|
||||
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
|
||||
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
|
||||
self.assertTrue( # Check samples.
|
||||
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
|
||||
# Test target SNR -Inf (noise expected).
|
||||
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, -np.Inf)
|
||||
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
|
||||
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
|
||||
self.assertTrue( # Check samples.
|
||||
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
|
||||
|
||||
# Test target SNR 0.0 (different data expected).
|
||||
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, 0.0)
|
||||
self.assertTrue(len(signal), len(mix_0)) # Check duration.
|
||||
self.assertTrue(len(noise), len(mix_0))
|
||||
mix_0_samples = mix_0.get_array_of_samples()
|
||||
self.assertTrue(
|
||||
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
|
||||
self.assertTrue(
|
||||
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
|
||||
# Test target SNR 0.0 (different data expected).
|
||||
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, 0.0)
|
||||
self.assertTrue(len(signal), len(mix_0)) # Check duration.
|
||||
self.assertTrue(len(noise), len(mix_0))
|
||||
mix_0_samples = mix_0.get_array_of_samples()
|
||||
self.assertTrue(
|
||||
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
|
||||
self.assertTrue(
|
||||
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
|
||||
|
||||
# Test target SNR +Inf (signal expected).
|
||||
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, np.Inf)
|
||||
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
|
||||
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
|
||||
self.assertTrue( # Check samples.
|
||||
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
|
||||
# Test target SNR +Inf (signal expected).
|
||||
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, noise, np.Inf)
|
||||
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
|
||||
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
|
||||
self.assertTrue( # Check samples.
|
||||
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
|
||||
|
||||
def testMixSignalsMinInfPower(self):
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
def testMixSignalsMinInfPower(self):
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, silence, 0.0)
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal, silence, 0.0)
|
||||
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
silence, signal, 0.0)
|
||||
with self.assertRaises(exceptions.SignalProcessingException):
|
||||
_ = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
silence, signal, 0.0)
|
||||
|
||||
def testMixSignalNoiseDifferentLengths(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
|
||||
def testMixSignalNoiseDifferentLengths(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
|
||||
|
||||
# When the signal is shorter than the noise, the mix length always equals
|
||||
# that of the signal regardless of whether padding is applied.
|
||||
# No noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# When the signal is shorter than the noise, the mix length always equals
|
||||
# that of the signal regardless of whether padding is applied.
|
||||
# No noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal less than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=shorter,
|
||||
noise=longer,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
|
||||
# When the signal is longer than the noise, the mix length depends on
|
||||
# whether padding is applied.
|
||||
# No noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
self.assertEqual(len(longer), len(mix))
|
||||
# When the signal is longer than the noise, the mix length depends on
|
||||
# whether padding is applied.
|
||||
# No noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
NO_PADDING)
|
||||
self.assertEqual(len(shorter), len(mix))
|
||||
# With noise padding, length of signal greater than that of noise.
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
self.assertEqual(len(longer), len(mix))
|
||||
|
||||
def testMixSignalNoisePaddingTypes(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
|
||||
def testMixSignalNoisePaddingTypes(self):
|
||||
# Test signals.
|
||||
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
|
||||
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
|
||||
|
||||
# Zero padding: expect pure tone only in 1-2s.
|
||||
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
|
||||
# Zero padding: expect pure tone only in 1-2s.
|
||||
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
|
||||
ZERO_PADDING)
|
||||
|
||||
# Loop: expect pure tone plus noise in 1-2s.
|
||||
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
|
||||
# Loop: expect pure tone plus noise in 1-2s.
|
||||
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=longer,
|
||||
noise=shorter,
|
||||
target_snr=-6,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
|
||||
|
||||
def Energy(signal):
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
signal).astype(np.float32)
|
||||
return np.sum(samples * samples)
|
||||
def Energy(signal):
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
signal).astype(np.float32)
|
||||
return np.sum(samples * samples)
|
||||
|
||||
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
|
||||
e_mix_loop = Energy(mix_loop[-1000:])
|
||||
self.assertLess(0, e_mix_zero_pad)
|
||||
self.assertLess(e_mix_zero_pad, e_mix_loop)
|
||||
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
|
||||
e_mix_loop = Energy(mix_loop[-1000:])
|
||||
self.assertLess(0, e_mix_zero_pad)
|
||||
self.assertLess(e_mix_zero_pad, e_mix_loop)
|
||||
|
||||
def testMixSignalSnr(self):
|
||||
# Test signals.
|
||||
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
|
||||
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
|
||||
def testMixSignalSnr(self):
|
||||
# Test signals.
|
||||
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
|
||||
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
|
||||
|
||||
def ToneAmplitudes(mix):
|
||||
"""Returns the amplitude of the coefficients #16 and #192, which
|
||||
def ToneAmplitudes(mix):
|
||||
"""Returns the amplitude of the coefficients #16 and #192, which
|
||||
correspond to the tones at 250 and 3k Hz respectively."""
|
||||
mix_fft = np.absolute(signal_processing.SignalProcessingUtils.Fft(mix))
|
||||
return mix_fft[16], mix_fft[192]
|
||||
mix_fft = np.absolute(
|
||||
signal_processing.SignalProcessingUtils.Fft(mix))
|
||||
return mix_fft[16], mix_fft[192]
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low,
|
||||
noise=tone_high,
|
||||
target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low, noise=tone_high, target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high,
|
||||
noise=tone_low,
|
||||
target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high, noise=tone_low, target_snr=-6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low,
|
||||
noise=tone_high,
|
||||
target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_low, noise=tone_high, target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_high, ampl_low)
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high,
|
||||
noise=tone_low,
|
||||
target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
mix = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
signal=tone_high, noise=tone_low, target_snr=6)
|
||||
ampl_low, ampl_high = ToneAmplitudes(mix)
|
||||
self.assertLess(ampl_low, ampl_high)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""APM module simulator.
|
||||
"""
|
||||
|
||||
@ -25,85 +24,93 @@ from . import test_data_generation
|
||||
|
||||
|
||||
class ApmModuleSimulator(object):
|
||||
"""Audio processing module (APM) simulator class.
|
||||
"""Audio processing module (APM) simulator class.
|
||||
"""
|
||||
|
||||
_TEST_DATA_GENERATOR_CLASSES = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
_TEST_DATA_GENERATOR_CLASSES = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
|
||||
_PREFIX_APM_CONFIG = 'apmcfg-'
|
||||
_PREFIX_CAPTURE = 'capture-'
|
||||
_PREFIX_RENDER = 'render-'
|
||||
_PREFIX_ECHO_SIMULATOR = 'echosim-'
|
||||
_PREFIX_TEST_DATA_GEN = 'datagen-'
|
||||
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
|
||||
_PREFIX_SCORE = 'score-'
|
||||
_PREFIX_APM_CONFIG = 'apmcfg-'
|
||||
_PREFIX_CAPTURE = 'capture-'
|
||||
_PREFIX_RENDER = 'render-'
|
||||
_PREFIX_ECHO_SIMULATOR = 'echosim-'
|
||||
_PREFIX_TEST_DATA_GEN = 'datagen-'
|
||||
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
|
||||
_PREFIX_SCORE = 'score-'
|
||||
|
||||
def __init__(self, test_data_generator_factory, evaluation_score_factory,
|
||||
ap_wrapper, evaluator, external_vads=None):
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._test_data_generator_factory = test_data_generator_factory
|
||||
self._evaluation_score_factory = evaluation_score_factory
|
||||
self._audioproc_wrapper = ap_wrapper
|
||||
self._evaluator = evaluator
|
||||
self._annotator = annotations.AudioAnnotationsExtractor(
|
||||
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
|
||||
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
|
||||
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
|
||||
external_vads
|
||||
)
|
||||
def __init__(self,
|
||||
test_data_generator_factory,
|
||||
evaluation_score_factory,
|
||||
ap_wrapper,
|
||||
evaluator,
|
||||
external_vads=None):
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._test_data_generator_factory = test_data_generator_factory
|
||||
self._evaluation_score_factory = evaluation_score_factory
|
||||
self._audioproc_wrapper = ap_wrapper
|
||||
self._evaluator = evaluator
|
||||
self._annotator = annotations.AudioAnnotationsExtractor(
|
||||
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
|
||||
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
|
||||
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
|
||||
external_vads)
|
||||
|
||||
# Init.
|
||||
self._test_data_generator_factory.SetOutputDirectoryPrefix(
|
||||
self._PREFIX_TEST_DATA_GEN_PARAMS)
|
||||
self._evaluation_score_factory.SetScoreFilenamePrefix(
|
||||
self._PREFIX_SCORE)
|
||||
# Init.
|
||||
self._test_data_generator_factory.SetOutputDirectoryPrefix(
|
||||
self._PREFIX_TEST_DATA_GEN_PARAMS)
|
||||
self._evaluation_score_factory.SetScoreFilenamePrefix(
|
||||
self._PREFIX_SCORE)
|
||||
|
||||
# Properties for each run.
|
||||
self._base_output_path = None
|
||||
self._output_cache_path = None
|
||||
self._test_data_generators = None
|
||||
self._evaluation_score_workers = None
|
||||
self._config_filepaths = None
|
||||
self._capture_input_filepaths = None
|
||||
self._render_input_filepaths = None
|
||||
self._echo_path_simulator_class = None
|
||||
# Properties for each run.
|
||||
self._base_output_path = None
|
||||
self._output_cache_path = None
|
||||
self._test_data_generators = None
|
||||
self._evaluation_score_workers = None
|
||||
self._config_filepaths = None
|
||||
self._capture_input_filepaths = None
|
||||
self._render_input_filepaths = None
|
||||
self._echo_path_simulator_class = None
|
||||
|
||||
@classmethod
|
||||
def GetPrefixApmConfig(cls):
|
||||
return cls._PREFIX_APM_CONFIG
|
||||
@classmethod
|
||||
def GetPrefixApmConfig(cls):
|
||||
return cls._PREFIX_APM_CONFIG
|
||||
|
||||
@classmethod
|
||||
def GetPrefixCapture(cls):
|
||||
return cls._PREFIX_CAPTURE
|
||||
@classmethod
|
||||
def GetPrefixCapture(cls):
|
||||
return cls._PREFIX_CAPTURE
|
||||
|
||||
@classmethod
|
||||
def GetPrefixRender(cls):
|
||||
return cls._PREFIX_RENDER
|
||||
@classmethod
|
||||
def GetPrefixRender(cls):
|
||||
return cls._PREFIX_RENDER
|
||||
|
||||
@classmethod
|
||||
def GetPrefixEchoSimulator(cls):
|
||||
return cls._PREFIX_ECHO_SIMULATOR
|
||||
@classmethod
|
||||
def GetPrefixEchoSimulator(cls):
|
||||
return cls._PREFIX_ECHO_SIMULATOR
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGenerator(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN
|
||||
@classmethod
|
||||
def GetPrefixTestDataGenerator(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGeneratorParameters(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN_PARAMS
|
||||
@classmethod
|
||||
def GetPrefixTestDataGeneratorParameters(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN_PARAMS
|
||||
|
||||
@classmethod
|
||||
def GetPrefixScore(cls):
|
||||
return cls._PREFIX_SCORE
|
||||
@classmethod
|
||||
def GetPrefixScore(cls):
|
||||
return cls._PREFIX_SCORE
|
||||
|
||||
def Run(self, config_filepaths, capture_input_filepaths,
|
||||
test_data_generator_names, eval_score_names, output_dir,
|
||||
render_input_filepaths=None, echo_path_simulator_name=(
|
||||
echo_path_simulation.NoEchoPathSimulator.NAME)):
|
||||
"""Runs the APM simulation.
|
||||
def Run(self,
|
||||
config_filepaths,
|
||||
capture_input_filepaths,
|
||||
test_data_generator_names,
|
||||
eval_score_names,
|
||||
output_dir,
|
||||
render_input_filepaths=None,
|
||||
echo_path_simulator_name=(
|
||||
echo_path_simulation.NoEchoPathSimulator.NAME)):
|
||||
"""Runs the APM simulation.
|
||||
|
||||
Initializes paths and required instances, then runs all the simulations.
|
||||
The render input can be optionally added. If added, the number of capture
|
||||
@ -120,132 +127,140 @@ class ApmModuleSimulator(object):
|
||||
echo_path_simulator_name: name of the echo path simulator to use when
|
||||
render input is provided.
|
||||
"""
|
||||
assert render_input_filepaths is None or (
|
||||
len(capture_input_filepaths) == len(render_input_filepaths)), (
|
||||
'render input set size not matching input set size')
|
||||
assert render_input_filepaths is None or echo_path_simulator_name in (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
|
||||
'invalid echo path simulator')
|
||||
self._base_output_path = os.path.abspath(output_dir)
|
||||
assert render_input_filepaths is None or (
|
||||
len(capture_input_filepaths) == len(render_input_filepaths)), (
|
||||
'render input set size not matching input set size')
|
||||
assert render_input_filepaths is None or echo_path_simulator_name in (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
|
||||
'invalid echo path simulator')
|
||||
self._base_output_path = os.path.abspath(output_dir)
|
||||
|
||||
# Output path used to cache the data shared across simulations.
|
||||
self._output_cache_path = os.path.join(self._base_output_path, '_cache')
|
||||
# Output path used to cache the data shared across simulations.
|
||||
self._output_cache_path = os.path.join(self._base_output_path,
|
||||
'_cache')
|
||||
|
||||
# Instance test data generators.
|
||||
self._test_data_generators = [self._test_data_generator_factory.GetInstance(
|
||||
test_data_generators_class=(
|
||||
self._TEST_DATA_GENERATOR_CLASSES[name])) for name in (
|
||||
test_data_generator_names)]
|
||||
# Instance test data generators.
|
||||
self._test_data_generators = [
|
||||
self._test_data_generator_factory.GetInstance(
|
||||
test_data_generators_class=(
|
||||
self._TEST_DATA_GENERATOR_CLASSES[name]))
|
||||
for name in (test_data_generator_names)
|
||||
]
|
||||
|
||||
# Instance evaluation score workers.
|
||||
self._evaluation_score_workers = [
|
||||
self._evaluation_score_factory.GetInstance(
|
||||
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name]) for (
|
||||
name) in eval_score_names]
|
||||
# Instance evaluation score workers.
|
||||
self._evaluation_score_workers = [
|
||||
self._evaluation_score_factory.GetInstance(
|
||||
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
|
||||
for (name) in eval_score_names
|
||||
]
|
||||
|
||||
# Set APM configuration file paths.
|
||||
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
|
||||
# Set APM configuration file paths.
|
||||
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
|
||||
|
||||
# Set probing signal file paths.
|
||||
if render_input_filepaths is None:
|
||||
# Capture input only.
|
||||
self._capture_input_filepaths = self._CreatePathsCollection(
|
||||
capture_input_filepaths)
|
||||
self._render_input_filepaths = None
|
||||
else:
|
||||
# Set both capture and render input signals.
|
||||
self._SetTestInputSignalFilePaths(
|
||||
capture_input_filepaths, render_input_filepaths)
|
||||
# Set probing signal file paths.
|
||||
if render_input_filepaths is None:
|
||||
# Capture input only.
|
||||
self._capture_input_filepaths = self._CreatePathsCollection(
|
||||
capture_input_filepaths)
|
||||
self._render_input_filepaths = None
|
||||
else:
|
||||
# Set both capture and render input signals.
|
||||
self._SetTestInputSignalFilePaths(capture_input_filepaths,
|
||||
render_input_filepaths)
|
||||
|
||||
# Set the echo path simulator class.
|
||||
self._echo_path_simulator_class = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES[
|
||||
echo_path_simulator_name])
|
||||
# Set the echo path simulator class.
|
||||
self._echo_path_simulator_class = (
|
||||
echo_path_simulation.EchoPathSimulator.
|
||||
REGISTERED_CLASSES[echo_path_simulator_name])
|
||||
|
||||
self._SimulateAll()
|
||||
self._SimulateAll()
|
||||
|
||||
def _SimulateAll(self):
|
||||
"""Runs all the simulations.
|
||||
def _SimulateAll(self):
|
||||
"""Runs all the simulations.
|
||||
|
||||
Iterates over the combinations of APM configurations, probing signals, and
|
||||
test data generators. This method is mainly responsible for the creation of
|
||||
the cache and output directories required in order to call _Simulate().
|
||||
"""
|
||||
without_render_input = self._render_input_filepaths is None
|
||||
without_render_input = self._render_input_filepaths is None
|
||||
|
||||
# Try different APM config files.
|
||||
for config_name in self._config_filepaths:
|
||||
config_filepath = self._config_filepaths[config_name]
|
||||
# Try different APM config files.
|
||||
for config_name in self._config_filepaths:
|
||||
config_filepath = self._config_filepaths[config_name]
|
||||
|
||||
# Try different capture-render pairs.
|
||||
for capture_input_name in self._capture_input_filepaths:
|
||||
# Output path for the capture signal annotations.
|
||||
capture_annotations_cache_path = os.path.join(
|
||||
self._output_cache_path,
|
||||
self._PREFIX_CAPTURE + capture_input_name)
|
||||
data_access.MakeDirectory(capture_annotations_cache_path)
|
||||
# Try different capture-render pairs.
|
||||
for capture_input_name in self._capture_input_filepaths:
|
||||
# Output path for the capture signal annotations.
|
||||
capture_annotations_cache_path = os.path.join(
|
||||
self._output_cache_path,
|
||||
self._PREFIX_CAPTURE + capture_input_name)
|
||||
data_access.MakeDirectory(capture_annotations_cache_path)
|
||||
|
||||
# Capture.
|
||||
capture_input_filepath = self._capture_input_filepaths[
|
||||
capture_input_name]
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
# If the input signal file does not exist, try to create using the
|
||||
# available input signal creators.
|
||||
self._CreateInputSignal(capture_input_filepath)
|
||||
assert os.path.exists(capture_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
capture_input_filepath, capture_annotations_cache_path)
|
||||
# Capture.
|
||||
capture_input_filepath = self._capture_input_filepaths[
|
||||
capture_input_name]
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
# If the input signal file does not exist, try to create using the
|
||||
# available input signal creators.
|
||||
self._CreateInputSignal(capture_input_filepath)
|
||||
assert os.path.exists(capture_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
capture_input_filepath, capture_annotations_cache_path)
|
||||
|
||||
# Render and simulated echo path (optional).
|
||||
render_input_filepath = None if without_render_input else (
|
||||
self._render_input_filepaths[capture_input_name])
|
||||
render_input_name = '(none)' if without_render_input else (
|
||||
self._ExtractFileName(render_input_filepath))
|
||||
echo_path_simulator = (
|
||||
echo_path_simulation_factory.EchoPathSimulatorFactory.GetInstance(
|
||||
self._echo_path_simulator_class, render_input_filepath))
|
||||
# Render and simulated echo path (optional).
|
||||
render_input_filepath = None if without_render_input else (
|
||||
self._render_input_filepaths[capture_input_name])
|
||||
render_input_name = '(none)' if without_render_input else (
|
||||
self._ExtractFileName(render_input_filepath))
|
||||
echo_path_simulator = (echo_path_simulation_factory.
|
||||
EchoPathSimulatorFactory.GetInstance(
|
||||
self._echo_path_simulator_class,
|
||||
render_input_filepath))
|
||||
|
||||
# Try different test data generators.
|
||||
for test_data_generators in self._test_data_generators:
|
||||
logging.info('APM config preset: <%s>, capture: <%s>, render: <%s>,'
|
||||
'test data generator: <%s>, echo simulator: <%s>',
|
||||
config_name, capture_input_name, render_input_name,
|
||||
test_data_generators.NAME, echo_path_simulator.NAME)
|
||||
# Try different test data generators.
|
||||
for test_data_generators in self._test_data_generators:
|
||||
logging.info(
|
||||
'APM config preset: <%s>, capture: <%s>, render: <%s>,'
|
||||
'test data generator: <%s>, echo simulator: <%s>',
|
||||
config_name, capture_input_name, render_input_name,
|
||||
test_data_generators.NAME, echo_path_simulator.NAME)
|
||||
|
||||
# Output path for the generated test data.
|
||||
test_data_cache_path = os.path.join(
|
||||
capture_annotations_cache_path,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(test_data_cache_path)
|
||||
logging.debug('test data cache path: <%s>', test_data_cache_path)
|
||||
# Output path for the generated test data.
|
||||
test_data_cache_path = os.path.join(
|
||||
capture_annotations_cache_path,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(test_data_cache_path)
|
||||
logging.debug('test data cache path: <%s>',
|
||||
test_data_cache_path)
|
||||
|
||||
# Output path for the echo simulator and APM input mixer output.
|
||||
echo_test_data_cache_path = os.path.join(
|
||||
test_data_cache_path, 'echosim-{}'.format(
|
||||
echo_path_simulator.NAME))
|
||||
data_access.MakeDirectory(echo_test_data_cache_path)
|
||||
logging.debug('echo test data cache path: <%s>',
|
||||
echo_test_data_cache_path)
|
||||
# Output path for the echo simulator and APM input mixer output.
|
||||
echo_test_data_cache_path = os.path.join(
|
||||
test_data_cache_path,
|
||||
'echosim-{}'.format(echo_path_simulator.NAME))
|
||||
data_access.MakeDirectory(echo_test_data_cache_path)
|
||||
logging.debug('echo test data cache path: <%s>',
|
||||
echo_test_data_cache_path)
|
||||
|
||||
# Full output path.
|
||||
output_path = os.path.join(
|
||||
self._base_output_path,
|
||||
self._PREFIX_APM_CONFIG + config_name,
|
||||
self._PREFIX_CAPTURE + capture_input_name,
|
||||
self._PREFIX_RENDER + render_input_name,
|
||||
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(output_path)
|
||||
logging.debug('output path: <%s>', output_path)
|
||||
# Full output path.
|
||||
output_path = os.path.join(
|
||||
self._base_output_path,
|
||||
self._PREFIX_APM_CONFIG + config_name,
|
||||
self._PREFIX_CAPTURE + capture_input_name,
|
||||
self._PREFIX_RENDER + render_input_name,
|
||||
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(output_path)
|
||||
logging.debug('output path: <%s>', output_path)
|
||||
|
||||
self._Simulate(test_data_generators, capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path,
|
||||
config_filepath, echo_path_simulator)
|
||||
self._Simulate(test_data_generators,
|
||||
capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path,
|
||||
config_filepath, echo_path_simulator)
|
||||
|
||||
@staticmethod
|
||||
def _CreateInputSignal(input_signal_filepath):
|
||||
"""Creates a missing input signal file.
|
||||
@staticmethod
|
||||
def _CreateInputSignal(input_signal_filepath):
|
||||
"""Creates a missing input signal file.
|
||||
|
||||
The file name is parsed to extract input signal creator and params. If a
|
||||
creator is matched and the parameters are valid, a new signal is generated
|
||||
@ -257,30 +272,33 @@ class ApmModuleSimulator(object):
|
||||
Raises:
|
||||
InputSignalCreatorException
|
||||
"""
|
||||
filename = os.path.splitext(os.path.split(input_signal_filepath)[-1])[0]
|
||||
filename_parts = filename.split('-')
|
||||
filename = os.path.splitext(
|
||||
os.path.split(input_signal_filepath)[-1])[0]
|
||||
filename_parts = filename.split('-')
|
||||
|
||||
if len(filename_parts) < 2:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Cannot parse input signal file name')
|
||||
if len(filename_parts) < 2:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Cannot parse input signal file name')
|
||||
|
||||
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||
filename_parts[0], filename_parts[1].split('_'))
|
||||
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||
filename_parts[0], filename_parts[1].split('_'))
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
input_signal_filepath, signal)
|
||||
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
input_signal_filepath, signal)
|
||||
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
|
||||
|
||||
def _ExtractCaptureAnnotations(self, input_filepath, output_path,
|
||||
annotation_name=""):
|
||||
self._annotator.Extract(input_filepath)
|
||||
self._annotator.Save(output_path, annotation_name)
|
||||
def _ExtractCaptureAnnotations(self,
|
||||
input_filepath,
|
||||
output_path,
|
||||
annotation_name=""):
|
||||
self._annotator.Extract(input_filepath)
|
||||
self._annotator.Save(output_path, annotation_name)
|
||||
|
||||
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path, config_filepath,
|
||||
echo_path_simulator):
|
||||
"""Runs a single set of simulation.
|
||||
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path, config_filepath,
|
||||
echo_path_simulator):
|
||||
"""Runs a single set of simulation.
|
||||
|
||||
Simulates a given combination of APM configuration, probing signal, and
|
||||
test data generator. It iterates over the test data generator
|
||||
@ -298,90 +316,92 @@ class ApmModuleSimulator(object):
|
||||
config_filepath: APM configuration file to test.
|
||||
echo_path_simulator: EchoPathSimulator instance.
|
||||
"""
|
||||
# Generate pairs of noisy input and reference signal files.
|
||||
test_data_generators.Generate(
|
||||
input_signal_filepath=clean_capture_input_filepath,
|
||||
test_data_cache_path=test_data_cache_path,
|
||||
base_output_path=output_path)
|
||||
# Generate pairs of noisy input and reference signal files.
|
||||
test_data_generators.Generate(
|
||||
input_signal_filepath=clean_capture_input_filepath,
|
||||
test_data_cache_path=test_data_cache_path,
|
||||
base_output_path=output_path)
|
||||
|
||||
# Extract metadata linked to the clean input file (if any).
|
||||
apm_input_metadata = None
|
||||
try:
|
||||
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||
clean_capture_input_filepath)
|
||||
except IOError as e:
|
||||
apm_input_metadata = {}
|
||||
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||
apm_input_metadata['test_data_gen_config'] = None
|
||||
# Extract metadata linked to the clean input file (if any).
|
||||
apm_input_metadata = None
|
||||
try:
|
||||
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||
clean_capture_input_filepath)
|
||||
except IOError as e:
|
||||
apm_input_metadata = {}
|
||||
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||
apm_input_metadata['test_data_gen_config'] = None
|
||||
|
||||
# For each test data pair, simulate a call and evaluate.
|
||||
for config_name in test_data_generators.config_names:
|
||||
logging.info(' - test data generator config: <%s>', config_name)
|
||||
apm_input_metadata['test_data_gen_config'] = config_name
|
||||
# For each test data pair, simulate a call and evaluate.
|
||||
for config_name in test_data_generators.config_names:
|
||||
logging.info(' - test data generator config: <%s>', config_name)
|
||||
apm_input_metadata['test_data_gen_config'] = config_name
|
||||
|
||||
# Paths to the test data generator output.
|
||||
# Note that the reference signal does not depend on the render input
|
||||
# which is optional.
|
||||
noisy_capture_input_filepath = (
|
||||
test_data_generators.noisy_signal_filepaths[config_name])
|
||||
reference_signal_filepath = (
|
||||
test_data_generators.reference_signal_filepaths[config_name])
|
||||
# Paths to the test data generator output.
|
||||
# Note that the reference signal does not depend on the render input
|
||||
# which is optional.
|
||||
noisy_capture_input_filepath = (
|
||||
test_data_generators.noisy_signal_filepaths[config_name])
|
||||
reference_signal_filepath = (
|
||||
test_data_generators.reference_signal_filepaths[config_name])
|
||||
|
||||
# Output path for the evaluation (e.g., APM output file).
|
||||
evaluation_output_path = test_data_generators.apm_output_paths[
|
||||
config_name]
|
||||
# Output path for the evaluation (e.g., APM output file).
|
||||
evaluation_output_path = test_data_generators.apm_output_paths[
|
||||
config_name]
|
||||
|
||||
# Paths to the APM input signals.
|
||||
echo_path_filepath = echo_path_simulator.Simulate(
|
||||
echo_test_data_cache_path)
|
||||
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
echo_test_data_cache_path, noisy_capture_input_filepath,
|
||||
echo_path_filepath)
|
||||
# Paths to the APM input signals.
|
||||
echo_path_filepath = echo_path_simulator.Simulate(
|
||||
echo_test_data_cache_path)
|
||||
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
echo_test_data_cache_path, noisy_capture_input_filepath,
|
||||
echo_path_filepath)
|
||||
|
||||
# Extract annotations for the APM input mix.
|
||||
apm_input_basepath, apm_input_filename = os.path.split(
|
||||
apm_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
apm_input_filepath, apm_input_basepath,
|
||||
os.path.splitext(apm_input_filename)[0] + '-')
|
||||
# Extract annotations for the APM input mix.
|
||||
apm_input_basepath, apm_input_filename = os.path.split(
|
||||
apm_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
apm_input_filepath, apm_input_basepath,
|
||||
os.path.splitext(apm_input_filename)[0] + '-')
|
||||
|
||||
# Simulate a call using APM.
|
||||
self._audioproc_wrapper.Run(
|
||||
config_filepath=config_filepath,
|
||||
capture_input_filepath=apm_input_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path)
|
||||
# Simulate a call using APM.
|
||||
self._audioproc_wrapper.Run(
|
||||
config_filepath=config_filepath,
|
||||
capture_input_filepath=apm_input_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path)
|
||||
|
||||
try:
|
||||
# Evaluate.
|
||||
self._evaluator.Run(
|
||||
evaluation_score_workers=self._evaluation_score_workers,
|
||||
apm_input_metadata=apm_input_metadata,
|
||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||
reference_input_filepath=reference_signal_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path,
|
||||
)
|
||||
try:
|
||||
# Evaluate.
|
||||
self._evaluator.Run(
|
||||
evaluation_score_workers=self._evaluation_score_workers,
|
||||
apm_input_metadata=apm_input_metadata,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
reference_input_filepath=reference_signal_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path,
|
||||
)
|
||||
|
||||
# Save simulation metadata.
|
||||
data_access.Metadata.SaveAudioTestDataPaths(
|
||||
output_path=evaluation_output_path,
|
||||
clean_capture_input_filepath=clean_capture_input_filepath,
|
||||
echo_free_capture_filepath=noisy_capture_input_filepath,
|
||||
echo_filepath=echo_path_filepath,
|
||||
render_filepath=render_input_filepath,
|
||||
capture_filepath=apm_input_filepath,
|
||||
apm_output_filepath=self._audioproc_wrapper.output_filepath,
|
||||
apm_reference_filepath=reference_signal_filepath,
|
||||
apm_config_filepath=config_filepath,
|
||||
)
|
||||
except exceptions.EvaluationScoreException as e:
|
||||
logging.warning('the evaluation failed: %s', e.message)
|
||||
continue
|
||||
# Save simulation metadata.
|
||||
data_access.Metadata.SaveAudioTestDataPaths(
|
||||
output_path=evaluation_output_path,
|
||||
clean_capture_input_filepath=clean_capture_input_filepath,
|
||||
echo_free_capture_filepath=noisy_capture_input_filepath,
|
||||
echo_filepath=echo_path_filepath,
|
||||
render_filepath=render_input_filepath,
|
||||
capture_filepath=apm_input_filepath,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
apm_reference_filepath=reference_signal_filepath,
|
||||
apm_config_filepath=config_filepath,
|
||||
)
|
||||
except exceptions.EvaluationScoreException as e:
|
||||
logging.warning('the evaluation failed: %s', e.message)
|
||||
continue
|
||||
|
||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||
render_input_filepaths):
|
||||
"""Sets input and render input file paths collections.
|
||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||
render_input_filepaths):
|
||||
"""Sets input and render input file paths collections.
|
||||
|
||||
Pairs the input and render input files by storing the file paths into two
|
||||
collections. The key is the file name of the input file.
|
||||
@ -390,20 +410,20 @@ class ApmModuleSimulator(object):
|
||||
capture_input_filepaths: list of file paths.
|
||||
render_input_filepaths: list of file paths.
|
||||
"""
|
||||
self._capture_input_filepaths = {}
|
||||
self._render_input_filepaths = {}
|
||||
assert len(capture_input_filepaths) == len(render_input_filepaths)
|
||||
for capture_input_filepath, render_input_filepath in zip(
|
||||
capture_input_filepaths, render_input_filepaths):
|
||||
name = self._ExtractFileName(capture_input_filepath)
|
||||
self._capture_input_filepaths[name] = os.path.abspath(
|
||||
capture_input_filepath)
|
||||
self._render_input_filepaths[name] = os.path.abspath(
|
||||
render_input_filepath)
|
||||
self._capture_input_filepaths = {}
|
||||
self._render_input_filepaths = {}
|
||||
assert len(capture_input_filepaths) == len(render_input_filepaths)
|
||||
for capture_input_filepath, render_input_filepath in zip(
|
||||
capture_input_filepaths, render_input_filepaths):
|
||||
name = self._ExtractFileName(capture_input_filepath)
|
||||
self._capture_input_filepaths[name] = os.path.abspath(
|
||||
capture_input_filepath)
|
||||
self._render_input_filepaths[name] = os.path.abspath(
|
||||
render_input_filepath)
|
||||
|
||||
@classmethod
|
||||
def _CreatePathsCollection(cls, filepaths):
|
||||
"""Creates a collection of file paths.
|
||||
@classmethod
|
||||
def _CreatePathsCollection(cls, filepaths):
|
||||
"""Creates a collection of file paths.
|
||||
|
||||
Given a list of file paths, makes a collection with one item for each file
|
||||
path. The value is absolute path, the key is the file name without
|
||||
@ -415,12 +435,12 @@ class ApmModuleSimulator(object):
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
filepaths_collection = {}
|
||||
for filepath in filepaths:
|
||||
name = cls._ExtractFileName(filepath)
|
||||
filepaths_collection[name] = os.path.abspath(filepath)
|
||||
return filepaths_collection
|
||||
filepaths_collection = {}
|
||||
for filepath in filepaths:
|
||||
name = cls._ExtractFileName(filepath)
|
||||
filepaths_collection[name] = os.path.abspath(filepath)
|
||||
return filepaths_collection
|
||||
|
||||
@classmethod
|
||||
def _ExtractFileName(cls, filepath):
|
||||
return os.path.splitext(os.path.split(filepath)[-1])[0]
|
||||
@classmethod
|
||||
def _ExtractFileName(cls, filepath):
|
||||
return os.path.splitext(os.path.split(filepath)[-1])[0]
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the simulation module.
|
||||
"""
|
||||
|
||||
@ -28,177 +27,177 @@ from . import test_data_generation_factory
|
||||
|
||||
|
||||
class TestApmModuleSimulator(unittest.TestCase):
|
||||
"""Unit tests for the ApmModuleSimulator class.
|
||||
"""Unit tests for the ApmModuleSimulator class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folders and fake audio track."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Create temporary folders and fake audio track."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_audio_track_path, fake_signal)
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._fake_audio_track_path = os.path.join(self._output_path,
|
||||
'fake.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_audio_track_path, fake_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._output_path)
|
||||
shutil.rmtree(self._tmp_path)
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._output_path)
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testSimulation(self):
|
||||
# Instance dependencies to mock and inject.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
def testSimulation(self):
|
||||
# Instance dependencies to mock and inject.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
|
||||
# Instance non-mocked dependencies.
|
||||
test_data_generator_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False))
|
||||
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=test_data_generator_factory,
|
||||
evaluation_score_factory=evaluation_score_factory,
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator,
|
||||
external_vads={'fake': external_vad.ExternalVad(os.path.join(
|
||||
os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
|
||||
)
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level_mean', 'polqa']
|
||||
|
||||
# Run all simulations.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise test data
|
||||
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
|
||||
def testInputSignalCreation(self):
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
# Instance non-mocked dependencies.
|
||||
test_data_generator_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
copy_with_identity=False))
|
||||
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)
|
||||
|
||||
# Inexistent input files to be silently created.
|
||||
input_files = [
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
|
||||
]
|
||||
self.assertFalse(any([os.path.exists(input_file) for input_file in (
|
||||
input_files)]))
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=test_data_generator_factory,
|
||||
evaluation_score_factory=evaluation_score_factory,
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator,
|
||||
external_vads={
|
||||
'fake':
|
||||
external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(__file__),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
})
|
||||
|
||||
# The input files are created during the simulation.
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=['audio_level_peak'],
|
||||
output_dir=self._output_path)
|
||||
self.assertTrue(all([os.path.exists(input_file) for input_file in (
|
||||
input_files)]))
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level_mean', 'polqa']
|
||||
|
||||
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
# Run all simulations.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(__file__), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None
|
||||
)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise test data
|
||||
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
|
||||
eval_scores = ['thd']
|
||||
def testInputSignalCreation(self):
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# Should work.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
self.assertFalse(logging.warning.called)
|
||||
# Inexistent input files to be silently created.
|
||||
input_files = [
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
|
||||
]
|
||||
self.assertFalse(
|
||||
any([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
# Warning expected.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['white_noise'], # Not allowed with THD.
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
logging.warning.assert_called_with('the evaluation failed: %s', (
|
||||
'The THD score cannot be used with any test data generator other than '
|
||||
'"identity"'))
|
||||
# The input files are created during the simulation.
|
||||
simulator.Run(config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=['audio_level_peak'],
|
||||
output_dir=self._output_path)
|
||||
self.assertTrue(
|
||||
all([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
# # Init.
|
||||
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
|
||||
# input_signal_filepath = os.path.join(
|
||||
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
|
||||
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
|
||||
# # Check that the input signal is generated.
|
||||
# self.assertFalse(os.path.exists(input_signal_filepath))
|
||||
# generator.Generate(
|
||||
# input_signal_filepath=input_signal_filepath,
|
||||
# test_data_cache_path=self._test_data_cache_path,
|
||||
# base_output_path=self._base_output_path)
|
||||
# self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# # Check input signal properties.
|
||||
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
# input_signal_filepath)
|
||||
# self.assertEqual(1000, len(input_signal))
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
|
||||
eval_scores = ['thd']
|
||||
|
||||
# Should work.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
self.assertFalse(logging.warning.called)
|
||||
|
||||
# Warning expected.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['white_noise'], # Not allowed with THD.
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
logging.warning.assert_called_with('the evaluation failed: %s', (
|
||||
'The THD score cannot be used with any test data generator other than '
|
||||
'"identity"'))
|
||||
|
||||
# # Init.
|
||||
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
|
||||
# input_signal_filepath = os.path.join(
|
||||
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
|
||||
|
||||
# # Check that the input signal is generated.
|
||||
# self.assertFalse(os.path.exists(input_signal_filepath))
|
||||
# generator.Generate(
|
||||
# input_signal_filepath=input_signal_filepath,
|
||||
# test_data_cache_path=self._test_data_cache_path,
|
||||
# base_output_path=self._base_output_path)
|
||||
# self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# # Check input signal properties.
|
||||
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
# input_signal_filepath)
|
||||
# self.assertEqual(1000, len(input_signal))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Test data generators producing signals pairs intended to be used to
|
||||
test the APM module. Each pair consists of a noisy input and a reference signal.
|
||||
The former is used as APM input and it is generated by adding noise to a
|
||||
@ -27,10 +26,10 @@ import shutil
|
||||
import sys
|
||||
|
||||
try:
|
||||
import scipy.io
|
||||
import scipy.io
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access
|
||||
from . import exceptions
|
||||
@ -38,7 +37,7 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestDataGenerator(object):
|
||||
"""Abstract class responsible for the generation of noisy signals.
|
||||
"""Abstract class responsible for the generation of noisy signals.
|
||||
|
||||
Given a clean signal, it generates two streams named noisy signal and
|
||||
reference. The former is the clean signal deteriorated by the noise source,
|
||||
@ -50,24 +49,24 @@ class TestDataGenerator(object):
|
||||
An test data generator generates one or more pairs.
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
self._output_directory_prefix = output_directory_prefix
|
||||
# Init dictionaries with one entry for each test data generator
|
||||
# configuration (e.g., different SNRs).
|
||||
# Noisy audio track files (stored separately in a cache folder).
|
||||
self._noisy_signal_filepaths = None
|
||||
# Path to be used for the APM simulation output files.
|
||||
self._apm_output_paths = None
|
||||
# Reference audio track files (stored separately in a cache folder).
|
||||
self._reference_signal_filepaths = None
|
||||
self.Clear()
|
||||
def __init__(self, output_directory_prefix):
|
||||
self._output_directory_prefix = output_directory_prefix
|
||||
# Init dictionaries with one entry for each test data generator
|
||||
# configuration (e.g., different SNRs).
|
||||
# Noisy audio track files (stored separately in a cache folder).
|
||||
self._noisy_signal_filepaths = None
|
||||
# Path to be used for the APM simulation output files.
|
||||
self._apm_output_paths = None
|
||||
# Reference audio track files (stored separately in a cache folder).
|
||||
self._reference_signal_filepaths = None
|
||||
self.Clear()
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers a TestDataGenerator implementation.
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers a TestDataGenerator implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend
|
||||
TestDataGenerator.
|
||||
@ -77,28 +76,28 @@ class TestDataGenerator(object):
|
||||
class IdentityGenerator(TestDataGenerator):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
@property
|
||||
def config_names(self):
|
||||
return self._noisy_signal_filepaths.keys()
|
||||
@property
|
||||
def config_names(self):
|
||||
return self._noisy_signal_filepaths.keys()
|
||||
|
||||
@property
|
||||
def noisy_signal_filepaths(self):
|
||||
return self._noisy_signal_filepaths
|
||||
@property
|
||||
def noisy_signal_filepaths(self):
|
||||
return self._noisy_signal_filepaths
|
||||
|
||||
@property
|
||||
def apm_output_paths(self):
|
||||
return self._apm_output_paths
|
||||
@property
|
||||
def apm_output_paths(self):
|
||||
return self._apm_output_paths
|
||||
|
||||
@property
|
||||
def reference_signal_filepaths(self):
|
||||
return self._reference_signal_filepaths
|
||||
@property
|
||||
def reference_signal_filepaths(self):
|
||||
return self._reference_signal_filepaths
|
||||
|
||||
def Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Generates a set of noisy input and reference audiotrack file pairs.
|
||||
def Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates a set of noisy input and reference audiotrack file pairs.
|
||||
|
||||
This method initializes an empty set of pairs and calls the _Generate()
|
||||
method implemented in a concrete class.
|
||||
@ -109,26 +108,26 @@ class TestDataGenerator(object):
|
||||
files.
|
||||
base_output_path: base path where output is written.
|
||||
"""
|
||||
self.Clear()
|
||||
self._Generate(
|
||||
input_signal_filepath, test_data_cache_path, base_output_path)
|
||||
self.Clear()
|
||||
self._Generate(input_signal_filepath, test_data_cache_path,
|
||||
base_output_path)
|
||||
|
||||
def Clear(self):
|
||||
"""Clears the generated output path dictionaries.
|
||||
def Clear(self):
|
||||
"""Clears the generated output path dictionaries.
|
||||
"""
|
||||
self._noisy_signal_filepaths = {}
|
||||
self._apm_output_paths = {}
|
||||
self._reference_signal_filepaths = {}
|
||||
self._noisy_signal_filepaths = {}
|
||||
self._apm_output_paths = {}
|
||||
self._reference_signal_filepaths = {}
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Abstract method to be implemented in each concrete class.
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Abstract method to be implemented in each concrete class.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
|
||||
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
|
||||
snr_value_pairs):
|
||||
"""Adds noisy-reference signal pairs.
|
||||
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
|
||||
snr_value_pairs):
|
||||
"""Adds noisy-reference signal pairs.
|
||||
|
||||
Args:
|
||||
base_output_path: noisy tracks base output path.
|
||||
@ -136,22 +135,22 @@ class TestDataGenerator(object):
|
||||
by noisy track name and SNR level.
|
||||
snr_value_pairs: list of SNR pairs.
|
||||
"""
|
||||
for noise_track_name in noisy_mix_filepaths:
|
||||
for snr_noisy, snr_refence in snr_value_pairs:
|
||||
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
|
||||
noise_track_name, snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[
|
||||
noise_track_name][snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[
|
||||
noise_track_name][snr_refence],
|
||||
output_path=output_path)
|
||||
for noise_track_name in noisy_mix_filepaths:
|
||||
for snr_noisy, snr_refence in snr_value_pairs:
|
||||
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
|
||||
noise_track_name, snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
|
||||
[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[
|
||||
noise_track_name][snr_refence],
|
||||
output_path=output_path)
|
||||
|
||||
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
|
||||
reference_signal_filepath, output_path):
|
||||
"""Adds one noisy-reference signal pair.
|
||||
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
|
||||
reference_signal_filepath, output_path):
|
||||
"""Adds one noisy-reference signal pair.
|
||||
|
||||
Args:
|
||||
config_name: name of the APM configuration.
|
||||
@ -159,264 +158,275 @@ class TestDataGenerator(object):
|
||||
reference_signal_filepath: path to reference audio track file.
|
||||
output_path: APM output path.
|
||||
"""
|
||||
assert config_name not in self._noisy_signal_filepaths
|
||||
self._noisy_signal_filepaths[config_name] = os.path.abspath(
|
||||
noisy_signal_filepath)
|
||||
self._apm_output_paths[config_name] = os.path.abspath(output_path)
|
||||
self._reference_signal_filepaths[config_name] = os.path.abspath(
|
||||
reference_signal_filepath)
|
||||
assert config_name not in self._noisy_signal_filepaths
|
||||
self._noisy_signal_filepaths[config_name] = os.path.abspath(
|
||||
noisy_signal_filepath)
|
||||
self._apm_output_paths[config_name] = os.path.abspath(output_path)
|
||||
self._reference_signal_filepaths[config_name] = os.path.abspath(
|
||||
reference_signal_filepath)
|
||||
|
||||
def _MakeDir(self, base_output_path, test_data_generator_config_name):
|
||||
output_path = os.path.join(
|
||||
base_output_path,
|
||||
self._output_directory_prefix + test_data_generator_config_name)
|
||||
data_access.MakeDirectory(output_path)
|
||||
return output_path
|
||||
def _MakeDir(self, base_output_path, test_data_generator_config_name):
|
||||
output_path = os.path.join(
|
||||
base_output_path,
|
||||
self._output_directory_prefix + test_data_generator_config_name)
|
||||
data_access.MakeDirectory(output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class IdentityTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds no noise.
|
||||
"""Generator that adds no noise.
|
||||
|
||||
Both the noisy and the reference signals are the input signal.
|
||||
"""
|
||||
|
||||
NAME = 'identity'
|
||||
NAME = 'identity'
|
||||
|
||||
def __init__(self, output_directory_prefix, copy_with_identity):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._copy_with_identity = copy_with_identity
|
||||
def __init__(self, output_directory_prefix, copy_with_identity):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
@property
|
||||
def copy_with_identity(self):
|
||||
return self._copy_with_identity
|
||||
@property
|
||||
def copy_with_identity(self):
|
||||
return self._copy_with_identity
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
config_name = 'default'
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
config_name = 'default'
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
|
||||
if self._copy_with_identity:
|
||||
input_signal_filepath_new = os.path.join(
|
||||
test_data_cache_path, os.path.split(input_signal_filepath)[1])
|
||||
logging.info('copying ' + input_signal_filepath + ' to ' + (
|
||||
input_signal_filepath_new))
|
||||
shutil.copy(input_signal_filepath, input_signal_filepath_new)
|
||||
input_signal_filepath = input_signal_filepath_new
|
||||
if self._copy_with_identity:
|
||||
input_signal_filepath_new = os.path.join(
|
||||
test_data_cache_path,
|
||||
os.path.split(input_signal_filepath)[1])
|
||||
logging.info('copying ' + input_signal_filepath + ' to ' +
|
||||
(input_signal_filepath_new))
|
||||
shutil.copy(input_signal_filepath, input_signal_filepath_new)
|
||||
input_signal_filepath = input_signal_filepath_new
|
||||
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=input_signal_filepath,
|
||||
reference_signal_filepath=input_signal_filepath,
|
||||
output_path=output_path)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=input_signal_filepath,
|
||||
reference_signal_filepath=input_signal_filepath,
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class WhiteNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds white noise.
|
||||
"""Generator that adds white noise.
|
||||
"""
|
||||
|
||||
NAME = 'white_noise'
|
||||
NAME = 'white_noise'
|
||||
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Create the noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
input_signal)
|
||||
# Create the noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
input_signal)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths = {}
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths = {}
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noisy-reference signal pairs.
|
||||
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
|
||||
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
|
||||
output_path=output_path)
|
||||
# Add all the noisy-reference signal pairs.
|
||||
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
|
||||
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
# TODO(alessiob): remove comment when class implemented.
|
||||
# @TestDataGenerator.RegisterClass
|
||||
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds narrow-band noise.
|
||||
"""Generator that adds narrow-band noise.
|
||||
"""
|
||||
|
||||
NAME = 'narrow_band_noise'
|
||||
NAME = 'narrow_band_noise'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
# TODO(alessiob): implement.
|
||||
pass
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# TODO(alessiob): implement.
|
||||
pass
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds noise loops.
|
||||
"""Generator that adds noise loops.
|
||||
|
||||
This generator uses all the wav files in a given path (default: noise_tracks/)
|
||||
and mixes them to the clean speech with different target SNRs (hard-coded).
|
||||
"""
|
||||
|
||||
NAME = 'additive_noise'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
NAME = 'additive_noise'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
DEFAULT_NOISE_TRACKS_PATH = os.path.join(
|
||||
os.path.dirname(__file__), os.pardir, 'noise_tracks')
|
||||
DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
|
||||
os.pardir, 'noise_tracks')
|
||||
|
||||
# TODO(alessiob): Make the list of SNR pairs customizable.
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
# TODO(alessiob): Make the list of SNR pairs customizable.
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
|
||||
def __init__(self, output_directory_prefix, noise_tracks_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._noise_tracks_file_names = [n for n in os.listdir(
|
||||
self._noise_tracks_path) if n.lower().endswith('.wav')]
|
||||
if len(self._noise_tracks_file_names) == 0:
|
||||
raise exceptions.InitializationException(
|
||||
'No wav files found in the noise tracks path %s' % (
|
||||
self._noise_tracks_path))
|
||||
def __init__(self, output_directory_prefix, noise_tracks_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._noise_tracks_file_names = [
|
||||
n for n in os.listdir(self._noise_tracks_path)
|
||||
if n.lower().endswith('.wav')
|
||||
]
|
||||
if len(self._noise_tracks_file_names) == 0:
|
||||
raise exceptions.InitializationException(
|
||||
'No wav files found in the noise tracks path %s' %
|
||||
(self._noise_tracks_path))
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Generates test data pairs using environmental noise.
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using environmental noise.
|
||||
|
||||
For each noise track and pair of SNR values, the following two audio tracks
|
||||
are created: the noisy signal and the reference signal. The former is
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for noise_track_filename in self._noise_tracks_file_names:
|
||||
# Load the noise track.
|
||||
noise_track_name, _ = os.path.splitext(noise_track_filename)
|
||||
noise_track_filepath = os.path.join(
|
||||
self._noise_tracks_path, noise_track_filename)
|
||||
if not os.path.exists(noise_track_filepath):
|
||||
logging.error('cannot find the <%s> noise track', noise_track_filename)
|
||||
raise exceptions.FileNotFoundError()
|
||||
noisy_mix_filepaths = {}
|
||||
for noise_track_filename in self._noise_tracks_file_names:
|
||||
# Load the noise track.
|
||||
noise_track_name, _ = os.path.splitext(noise_track_filename)
|
||||
noise_track_filepath = os.path.join(self._noise_tracks_path,
|
||||
noise_track_filename)
|
||||
if not os.path.exists(noise_track_filepath):
|
||||
logging.error('cannot find the <%s> noise track',
|
||||
noise_track_filename)
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[noise_track_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(noise_track_name, snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[noise_track_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
noise_track_name, snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal,
|
||||
noise_signal,
|
||||
snr,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.
|
||||
MixPadding.LOOP)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[noise_track_name][snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[noise_track_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(
|
||||
base_output_path, noisy_mix_filepaths, self._SNR_VALUE_PAIRS)
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds reverberation noise.
|
||||
"""Generator that adds reverberation noise.
|
||||
|
||||
TODO(alessiob): Make this class more generic since the impulse response can be
|
||||
anything (not just reverberation); call it e.g.,
|
||||
ConvolutionalNoiseTestDataGenerator.
|
||||
"""
|
||||
|
||||
NAME = 'reverberation'
|
||||
NAME = 'reverberation'
|
||||
|
||||
_IMPULSE_RESPONSES = {
|
||||
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
|
||||
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
|
||||
}
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH = None
|
||||
_IMPULSE_RESPONSES = {
|
||||
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
|
||||
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
|
||||
}
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH = None
|
||||
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 5 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[3, 8], # Smallest noise.
|
||||
[-3, 2], # Largest noise.
|
||||
]
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 5 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[3, 8], # Smallest noise.
|
||||
[-3, 2], # Largest noise.
|
||||
]
|
||||
|
||||
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix, aechen_ir_database_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
def __init__(self, output_directory_prefix, aechen_ir_database_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
|
||||
def _Generate(
|
||||
self, input_signal_filepath, test_data_cache_path, base_output_path):
|
||||
"""Generates test data pairs using reverberation noise.
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using reverberation noise.
|
||||
|
||||
For each impulse response, one noise track is created. For each impulse
|
||||
response and pair of SNR values, the following 2 audio tracks are
|
||||
@ -424,61 +434,64 @@ class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for impulse_response_name in self._IMPULSE_RESPONSES:
|
||||
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name)
|
||||
noise_track_filepath = os.path.join(
|
||||
test_data_cache_path, noise_track_filename)
|
||||
noise_signal = None
|
||||
try:
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
self._IMPULSE_RESPONSES[impulse_response_name])
|
||||
noise_signal = self._GenerateNoiseTrack(
|
||||
noise_track_filepath, input_signal, impulse_response_filepath)
|
||||
assert noise_signal is not None
|
||||
noisy_mix_filepaths = {}
|
||||
for impulse_response_name in self._IMPULSE_RESPONSES:
|
||||
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name)
|
||||
noise_track_filepath = os.path.join(test_data_cache_path,
|
||||
noise_track_filename)
|
||||
noise_signal = None
|
||||
try:
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
self._IMPULSE_RESPONSES[impulse_response_name])
|
||||
noise_signal = self._GenerateNoiseTrack(
|
||||
noise_track_filepath, input_signal,
|
||||
impulse_response_filepath)
|
||||
assert noise_signal is not None
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[impulse_response_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name, snr))
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[impulse_response_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name, snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[impulse_response_name][snr] = noisy_signal_filepath
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[impulse_response_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
|
||||
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
|
||||
impulse_response_filepath):
|
||||
"""Generates noise track.
|
||||
"""Generates noise track.
|
||||
|
||||
Generate a signal by convolving input_signal with the impulse response in
|
||||
impulse_response_filepath; then save to noise_track_filepath.
|
||||
@ -491,21 +504,23 @@ class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
# Load impulse response.
|
||||
data = scipy.io.loadmat(impulse_response_filepath)
|
||||
impulse_response = data['h_air'].flatten()
|
||||
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
|
||||
logging.info('truncating impulse response from %d to %d samples',
|
||||
len(impulse_response), self._MAX_IMPULSE_RESPONSE_LENGTH)
|
||||
impulse_response = impulse_response[:self._MAX_IMPULSE_RESPONSE_LENGTH]
|
||||
# Load impulse response.
|
||||
data = scipy.io.loadmat(impulse_response_filepath)
|
||||
impulse_response = data['h_air'].flatten()
|
||||
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
|
||||
logging.info('truncating impulse response from %d to %d samples',
|
||||
len(impulse_response),
|
||||
self._MAX_IMPULSE_RESPONSE_LENGTH)
|
||||
impulse_response = impulse_response[:self.
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH]
|
||||
|
||||
# Apply impulse response.
|
||||
processed_signal = (
|
||||
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
input_signal, impulse_response))
|
||||
# Apply impulse response.
|
||||
processed_signal = (
|
||||
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
input_signal, impulse_response))
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noise_track_filepath, processed_signal)
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noise_track_filepath, processed_signal)
|
||||
|
||||
return processed_signal
|
||||
return processed_signal
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""TestDataGenerator factory class.
|
||||
"""
|
||||
|
||||
@ -16,15 +15,15 @@ from . import test_data_generation
|
||||
|
||||
|
||||
class TestDataGeneratorFactory(object):
|
||||
"""Factory class used to create test data generators.
|
||||
"""Factory class used to create test data generators.
|
||||
|
||||
Usage: Create a factory passing parameters to the ctor with which the
|
||||
generators will be produced.
|
||||
"""
|
||||
|
||||
def __init__(self, aechen_ir_database_path, noise_tracks_path,
|
||||
copy_with_identity):
|
||||
"""Ctor.
|
||||
def __init__(self, aechen_ir_database_path, noise_tracks_path,
|
||||
copy_with_identity):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
aechen_ir_database_path: Path to the Aechen Impulse Response database.
|
||||
@ -32,16 +31,16 @@ class TestDataGeneratorFactory(object):
|
||||
copy_with_identity: Flag indicating whether the identity generator has to
|
||||
make copies of the clean speech input files.
|
||||
"""
|
||||
self._output_directory_prefix = None
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._copy_with_identity = copy_with_identity
|
||||
self._output_directory_prefix = None
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
def SetOutputDirectoryPrefix(self, prefix):
|
||||
self._output_directory_prefix = prefix
|
||||
def SetOutputDirectoryPrefix(self, prefix):
|
||||
self._output_directory_prefix = prefix
|
||||
|
||||
def GetInstance(self, test_data_generators_class):
|
||||
"""Creates an TestDataGenerator instance given a class object.
|
||||
def GetInstance(self, test_data_generators_class):
|
||||
"""Creates an TestDataGenerator instance given a class object.
|
||||
|
||||
Args:
|
||||
test_data_generators_class: TestDataGenerator class object (not an
|
||||
@ -50,22 +49,23 @@ class TestDataGeneratorFactory(object):
|
||||
Returns:
|
||||
TestDataGenerator instance.
|
||||
"""
|
||||
if self._output_directory_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The output directory prefix for test data generators is not set')
|
||||
logging.debug('factory producing %s', test_data_generators_class)
|
||||
if self._output_directory_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The output directory prefix for test data generators is not set'
|
||||
)
|
||||
logging.debug('factory producing %s', test_data_generators_class)
|
||||
|
||||
if test_data_generators_class == (
|
||||
test_data_generation.IdentityTestDataGenerator):
|
||||
return test_data_generation.IdentityTestDataGenerator(
|
||||
self._output_directory_prefix, self._copy_with_identity)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.ReverberationTestDataGenerator):
|
||||
return test_data_generation.ReverberationTestDataGenerator(
|
||||
self._output_directory_prefix, self._aechen_ir_database_path)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.AdditiveNoiseTestDataGenerator):
|
||||
return test_data_generation.AdditiveNoiseTestDataGenerator(
|
||||
self._output_directory_prefix, self._noise_tracks_path)
|
||||
else:
|
||||
return test_data_generators_class(self._output_directory_prefix)
|
||||
if test_data_generators_class == (
|
||||
test_data_generation.IdentityTestDataGenerator):
|
||||
return test_data_generation.IdentityTestDataGenerator(
|
||||
self._output_directory_prefix, self._copy_with_identity)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.ReverberationTestDataGenerator):
|
||||
return test_data_generation.ReverberationTestDataGenerator(
|
||||
self._output_directory_prefix, self._aechen_ir_database_path)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.AdditiveNoiseTestDataGenerator):
|
||||
return test_data_generation.AdditiveNoiseTestDataGenerator(
|
||||
self._output_directory_prefix, self._noise_tracks_path)
|
||||
else:
|
||||
return test_data_generators_class(self._output_directory_prefix)
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
@ -23,141 +22,143 @@ from . import signal_processing
|
||||
|
||||
|
||||
class TestTestDataGenerators(unittest.TestCase):
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._test_data_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
def setUp(self):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._test_data_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(os.path.join(
|
||||
self._fake_air_db_path, impulse_response_mat_file_name), data)
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(
|
||||
os.path.join(self._fake_air_db_path,
|
||||
impulse_response_mat_file_name), data)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._test_data_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._test_data_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Check that there is at least one registered test data generator.
|
||||
registered_classes = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
# Check that there is at least one registered test data generator.
|
||||
registered_classes = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance generators factory.
|
||||
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=self._fake_air_db_path,
|
||||
noise_tracks_path=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH,
|
||||
copy_with_identity=False)
|
||||
generators_factory.SetOutputDirectoryPrefix('datagen-')
|
||||
# Instance generators factory.
|
||||
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=self._fake_air_db_path,
|
||||
noise_tracks_path=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH,
|
||||
copy_with_identity=False)
|
||||
generators_factory.SetOutputDirectoryPrefix('datagen-')
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(
|
||||
os.getcwd(), 'probing_signals', 'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# Load input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
# Load input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
|
||||
# Generate the noisy input - reference pairs.
|
||||
generator.Generate(
|
||||
input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
# Generate the noisy input - reference pairs.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
|
||||
# Perform checks.
|
||||
self._CheckGeneratedPairsListSizes(generator)
|
||||
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
|
||||
self._CheckGeneratedPairsOutputPaths(generator)
|
||||
# Perform checks.
|
||||
self._CheckGeneratedPairsListSizes(generator)
|
||||
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
|
||||
self._CheckGeneratedPairsOutputPaths(generator)
|
||||
|
||||
def testTestidentityDataGenerator(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
def testTestidentityDataGenerator(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(
|
||||
os.getcwd(), 'probing_signals', 'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
def GetNoiseReferenceFilePaths(identity_generator):
|
||||
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
|
||||
reference_signal_filepaths = identity_generator.reference_signal_filepaths
|
||||
assert noisy_signal_filepaths.keys() == reference_signal_filepaths.keys()
|
||||
assert len(noisy_signal_filepaths.keys()) == 1
|
||||
key = noisy_signal_filepaths.keys()[0]
|
||||
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
|
||||
def GetNoiseReferenceFilePaths(identity_generator):
|
||||
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
|
||||
reference_signal_filepaths = identity_generator.reference_signal_filepaths
|
||||
assert noisy_signal_filepaths.keys(
|
||||
) == reference_signal_filepaths.keys()
|
||||
assert len(noisy_signal_filepaths.keys()) == 1
|
||||
key = noisy_signal_filepaths.keys()[0]
|
||||
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
|
||||
|
||||
# Test the |copy_with_identity| flag.
|
||||
for copy_with_identity in [False, True]:
|
||||
# Instance the generator through the factory.
|
||||
factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='', noise_tracks_path='',
|
||||
copy_with_identity=copy_with_identity)
|
||||
factory.SetOutputDirectoryPrefix('datagen-')
|
||||
generator = factory.GetInstance(
|
||||
test_data_generation.IdentityTestDataGenerator)
|
||||
# Check |copy_with_identity| is set correctly.
|
||||
self.assertEqual(copy_with_identity, generator.copy_with_identity)
|
||||
# Test the |copy_with_identity| flag.
|
||||
for copy_with_identity in [False, True]:
|
||||
# Instance the generator through the factory.
|
||||
factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=copy_with_identity)
|
||||
factory.SetOutputDirectoryPrefix('datagen-')
|
||||
generator = factory.GetInstance(
|
||||
test_data_generation.IdentityTestDataGenerator)
|
||||
# Check |copy_with_identity| is set correctly.
|
||||
self.assertEqual(copy_with_identity, generator.copy_with_identity)
|
||||
|
||||
# Generate test data and extract the paths to the noise and the reference
|
||||
# files.
|
||||
generator.Generate(
|
||||
input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
noisy_signal_filepath, reference_signal_filepath = (
|
||||
GetNoiseReferenceFilePaths(generator))
|
||||
# Generate test data and extract the paths to the noise and the reference
|
||||
# files.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
noisy_signal_filepath, reference_signal_filepath = (
|
||||
GetNoiseReferenceFilePaths(generator))
|
||||
|
||||
# Check that a copy is made if and only if |copy_with_identity| is True.
|
||||
if copy_with_identity:
|
||||
self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
|
||||
else:
|
||||
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertEqual(reference_signal_filepath, input_signal_filepath)
|
||||
# Check that a copy is made if and only if |copy_with_identity| is True.
|
||||
if copy_with_identity:
|
||||
self.assertNotEqual(noisy_signal_filepath,
|
||||
input_signal_filepath)
|
||||
self.assertNotEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
else:
|
||||
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
|
||||
def _CheckGeneratedPairsListSizes(self, generator):
|
||||
config_names = generator.config_names
|
||||
number_of_pairs = len(config_names)
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.noisy_signal_filepaths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.apm_output_paths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.reference_signal_filepaths))
|
||||
def _CheckGeneratedPairsListSizes(self, generator):
|
||||
config_names = generator.config_names
|
||||
number_of_pairs = len(config_names)
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.noisy_signal_filepaths))
|
||||
self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.reference_signal_filepaths))
|
||||
|
||||
def _CheckGeneratedPairsSignalDurations(
|
||||
self, generator, input_signal):
|
||||
"""Checks duration of the generated signals.
|
||||
def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
|
||||
"""Checks duration of the generated signals.
|
||||
|
||||
Checks that the noisy input and the reference tracks are audio files
|
||||
with duration equal to or greater than that of the input signal.
|
||||
@ -166,41 +167,41 @@ class TestTestDataGenerators(unittest.TestCase):
|
||||
generator: TestDataGenerator instance.
|
||||
input_signal: AudioSegment instance.
|
||||
"""
|
||||
input_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
|
||||
input_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
|
||||
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
# Load the noisy input file.
|
||||
noisy_signal_filepath = generator.noisy_signal_filepaths[
|
||||
config_name]
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noisy_signal_filepath)
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
# Load the noisy input file.
|
||||
noisy_signal_filepath = generator.noisy_signal_filepaths[
|
||||
config_name]
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noisy_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
noisy_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
|
||||
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
|
||||
# Check noisy input signal length.
|
||||
noisy_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(noisy_signal))
|
||||
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
|
||||
|
||||
# Load the reference file.
|
||||
reference_signal_filepath = generator.reference_signal_filepaths[
|
||||
config_name]
|
||||
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
reference_signal_filepath)
|
||||
# Load the reference file.
|
||||
reference_signal_filepath = generator.reference_signal_filepaths[
|
||||
config_name]
|
||||
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
reference_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
reference_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
reference_signal))
|
||||
self.assertGreaterEqual(reference_signal_length, input_signal_length)
|
||||
# Check noisy input signal length.
|
||||
reference_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(reference_signal))
|
||||
self.assertGreaterEqual(reference_signal_length,
|
||||
input_signal_length)
|
||||
|
||||
def _CheckGeneratedPairsOutputPaths(self, generator):
|
||||
"""Checks that the output path created by the generator exists.
|
||||
def _CheckGeneratedPairsOutputPaths(self, generator):
|
||||
"""Checks that the output path created by the generator exists.
|
||||
|
||||
Args:
|
||||
generator: TestDataGenerator instance.
|
||||
"""
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
output_path = generator.apm_output_paths[config_name]
|
||||
self.assertTrue(os.path.exists(output_path))
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
output_path = generator.apm_output_paths[config_name]
|
||||
self.assertTrue(os.path.exists(output_path))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Plots statistics from WebRTC integration test logs.
|
||||
|
||||
Usage: $ python plot_webrtc_test_logs.py filename.txt
|
||||
@ -52,43 +51,43 @@ AVG_DELTA_FRAME_SIZE = ('avg_delta_frame_size_bytes',
|
||||
|
||||
# Settings.
|
||||
SETTINGS = [
|
||||
WIDTH,
|
||||
HEIGHT,
|
||||
FILENAME,
|
||||
NUM_FRAMES,
|
||||
WIDTH,
|
||||
HEIGHT,
|
||||
FILENAME,
|
||||
NUM_FRAMES,
|
||||
]
|
||||
|
||||
# Settings, options for x-axis.
|
||||
X_SETTINGS = [
|
||||
CORES,
|
||||
FRAMERATE,
|
||||
DENOISING,
|
||||
RESILIENCE,
|
||||
ERROR_CONCEALMENT,
|
||||
BITRATE, # TODO(asapersson): Needs to be last.
|
||||
CORES,
|
||||
FRAMERATE,
|
||||
DENOISING,
|
||||
RESILIENCE,
|
||||
ERROR_CONCEALMENT,
|
||||
BITRATE, # TODO(asapersson): Needs to be last.
|
||||
]
|
||||
|
||||
# Settings, options for subplots.
|
||||
SUBPLOT_SETTINGS = [
|
||||
CODEC_TYPE,
|
||||
ENCODER_IMPLEMENTATION_NAME,
|
||||
DECODER_IMPLEMENTATION_NAME,
|
||||
CODEC_IMPLEMENTATION_NAME,
|
||||
CODEC_TYPE,
|
||||
ENCODER_IMPLEMENTATION_NAME,
|
||||
DECODER_IMPLEMENTATION_NAME,
|
||||
CODEC_IMPLEMENTATION_NAME,
|
||||
] + X_SETTINGS
|
||||
|
||||
# Results.
|
||||
RESULTS = [
|
||||
PSNR,
|
||||
SSIM,
|
||||
ENC_BITRATE,
|
||||
NUM_DROPPED_FRAMES,
|
||||
TIME_TO_TARGET,
|
||||
ENCODE_SPEED_FPS,
|
||||
DECODE_SPEED_FPS,
|
||||
QP,
|
||||
CPU_USAGE,
|
||||
AVG_KEY_FRAME_SIZE,
|
||||
AVG_DELTA_FRAME_SIZE,
|
||||
PSNR,
|
||||
SSIM,
|
||||
ENC_BITRATE,
|
||||
NUM_DROPPED_FRAMES,
|
||||
TIME_TO_TARGET,
|
||||
ENCODE_SPEED_FPS,
|
||||
DECODE_SPEED_FPS,
|
||||
QP,
|
||||
CPU_USAGE,
|
||||
AVG_KEY_FRAME_SIZE,
|
||||
AVG_DELTA_FRAME_SIZE,
|
||||
]
|
||||
|
||||
METRICS_TO_PARSE = SETTINGS + SUBPLOT_SETTINGS + RESULTS
|
||||
@ -102,7 +101,7 @@ GRID_COLOR = [0.45, 0.45, 0.45]
|
||||
|
||||
|
||||
def ParseSetting(filename, setting):
|
||||
"""Parses setting from file.
|
||||
"""Parses setting from file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
@ -111,36 +110,36 @@ def ParseSetting(filename, setting):
|
||||
Returns:
|
||||
A list holding parsed settings, e.g. ['width: 128.0', 'width: 160.0'] """
|
||||
|
||||
settings = []
|
||||
settings = []
|
||||
|
||||
settings_file = open(filename)
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_START, line):
|
||||
# Parse event.
|
||||
parsed = {}
|
||||
while True:
|
||||
settings_file = open(filename)
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_END, line):
|
||||
# Add parsed setting to list.
|
||||
if setting in parsed:
|
||||
s = setting + ': ' + str(parsed[setting])
|
||||
if s not in settings:
|
||||
settings.append(s)
|
||||
break
|
||||
break
|
||||
if re.search(r'%s' % EVENT_START, line):
|
||||
# Parse event.
|
||||
parsed = {}
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_END, line):
|
||||
# Add parsed setting to list.
|
||||
if setting in parsed:
|
||||
s = setting + ': ' + str(parsed[setting])
|
||||
if s not in settings:
|
||||
settings.append(s)
|
||||
break
|
||||
|
||||
TryFindMetric(parsed, line)
|
||||
TryFindMetric(parsed, line)
|
||||
|
||||
settings_file.close()
|
||||
return settings
|
||||
settings_file.close()
|
||||
return settings
|
||||
|
||||
|
||||
def ParseMetrics(filename, setting1, setting2):
|
||||
"""Parses metrics from file.
|
||||
"""Parses metrics from file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
@ -175,82 +174,82 @@ def ParseMetrics(filename, setting1, setting2):
|
||||
}
|
||||
} """
|
||||
|
||||
metrics = {}
|
||||
metrics = {}
|
||||
|
||||
# Parse events.
|
||||
settings_file = open(filename)
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_START, line):
|
||||
# Parse event.
|
||||
parsed = {}
|
||||
while True:
|
||||
# Parse events.
|
||||
settings_file = open(filename)
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_END, line):
|
||||
# Add parsed values to metrics.
|
||||
key1 = setting1 + ': ' + str(parsed[setting1])
|
||||
key2 = setting2 + ': ' + str(parsed[setting2])
|
||||
if key1 not in metrics:
|
||||
metrics[key1] = {}
|
||||
if key2 not in metrics[key1]:
|
||||
metrics[key1][key2] = {}
|
||||
break
|
||||
if re.search(r'%s' % EVENT_START, line):
|
||||
# Parse event.
|
||||
parsed = {}
|
||||
while True:
|
||||
line = settings_file.readline()
|
||||
if not line:
|
||||
break
|
||||
if re.search(r'%s' % EVENT_END, line):
|
||||
# Add parsed values to metrics.
|
||||
key1 = setting1 + ': ' + str(parsed[setting1])
|
||||
key2 = setting2 + ': ' + str(parsed[setting2])
|
||||
if key1 not in metrics:
|
||||
metrics[key1] = {}
|
||||
if key2 not in metrics[key1]:
|
||||
metrics[key1][key2] = {}
|
||||
|
||||
for label in parsed:
|
||||
if label not in metrics[key1][key2]:
|
||||
metrics[key1][key2][label] = []
|
||||
metrics[key1][key2][label].append(parsed[label])
|
||||
for label in parsed:
|
||||
if label not in metrics[key1][key2]:
|
||||
metrics[key1][key2][label] = []
|
||||
metrics[key1][key2][label].append(parsed[label])
|
||||
|
||||
break
|
||||
break
|
||||
|
||||
TryFindMetric(parsed, line)
|
||||
TryFindMetric(parsed, line)
|
||||
|
||||
settings_file.close()
|
||||
return metrics
|
||||
settings_file.close()
|
||||
return metrics
|
||||
|
||||
|
||||
def TryFindMetric(parsed, line):
|
||||
for metric in METRICS_TO_PARSE:
|
||||
name = metric[0]
|
||||
label = metric[1]
|
||||
if re.search(r'%s' % name, line):
|
||||
found, value = GetMetric(name, line)
|
||||
if found:
|
||||
parsed[label] = value
|
||||
return
|
||||
for metric in METRICS_TO_PARSE:
|
||||
name = metric[0]
|
||||
label = metric[1]
|
||||
if re.search(r'%s' % name, line):
|
||||
found, value = GetMetric(name, line)
|
||||
if found:
|
||||
parsed[label] = value
|
||||
return
|
||||
|
||||
|
||||
def GetMetric(name, string):
|
||||
# Float (e.g. bitrate = 98.8253).
|
||||
pattern = r'%s\s*[:=]\s*([+-]?\d+\.*\d*)' % name
|
||||
m = re.search(r'%s' % pattern, string)
|
||||
if m is not None:
|
||||
return StringToFloat(m.group(1))
|
||||
# Float (e.g. bitrate = 98.8253).
|
||||
pattern = r'%s\s*[:=]\s*([+-]?\d+\.*\d*)' % name
|
||||
m = re.search(r'%s' % pattern, string)
|
||||
if m is not None:
|
||||
return StringToFloat(m.group(1))
|
||||
|
||||
# Alphanumeric characters (e.g. codec type : VP8).
|
||||
pattern = r'%s\s*[:=]\s*(\w+)' % name
|
||||
m = re.search(r'%s' % pattern, string)
|
||||
if m is not None:
|
||||
return True, m.group(1)
|
||||
# Alphanumeric characters (e.g. codec type : VP8).
|
||||
pattern = r'%s\s*[:=]\s*(\w+)' % name
|
||||
m = re.search(r'%s' % pattern, string)
|
||||
if m is not None:
|
||||
return True, m.group(1)
|
||||
|
||||
return False, -1
|
||||
return False, -1
|
||||
|
||||
|
||||
def StringToFloat(value):
|
||||
try:
|
||||
value = float(value)
|
||||
except ValueError:
|
||||
print "Not a float, skipped %s" % value
|
||||
return False, -1
|
||||
try:
|
||||
value = float(value)
|
||||
except ValueError:
|
||||
print "Not a float, skipped %s" % value
|
||||
return False, -1
|
||||
|
||||
return True, value
|
||||
return True, value
|
||||
|
||||
|
||||
def Plot(y_metric, x_metric, metrics):
|
||||
"""Plots y_metric vs x_metric per key in metrics.
|
||||
"""Plots y_metric vs x_metric per key in metrics.
|
||||
|
||||
For example:
|
||||
y_metric = 'PSNR (dB)'
|
||||
@ -266,26 +265,31 @@ def Plot(y_metric, x_metric, metrics):
|
||||
},
|
||||
}
|
||||
"""
|
||||
for key in sorted(metrics):
|
||||
data = metrics[key]
|
||||
if y_metric not in data:
|
||||
print "Failed to find metric: %s" % y_metric
|
||||
continue
|
||||
for key in sorted(metrics):
|
||||
data = metrics[key]
|
||||
if y_metric not in data:
|
||||
print "Failed to find metric: %s" % y_metric
|
||||
continue
|
||||
|
||||
y = numpy.array(data[y_metric])
|
||||
x = numpy.array(data[x_metric])
|
||||
if len(y) != len(x):
|
||||
print "Length mismatch for %s, %s" % (y, x)
|
||||
continue
|
||||
y = numpy.array(data[y_metric])
|
||||
x = numpy.array(data[x_metric])
|
||||
if len(y) != len(x):
|
||||
print "Length mismatch for %s, %s" % (y, x)
|
||||
continue
|
||||
|
||||
label = y_metric + ' - ' + str(key)
|
||||
label = y_metric + ' - ' + str(key)
|
||||
|
||||
plt.plot(x, y, label=label, linewidth=1.5, marker='o', markersize=5,
|
||||
markeredgewidth=0.0)
|
||||
plt.plot(x,
|
||||
y,
|
||||
label=label,
|
||||
linewidth=1.5,
|
||||
marker='o',
|
||||
markersize=5,
|
||||
markeredgewidth=0.0)
|
||||
|
||||
|
||||
def PlotFigure(settings, y_metrics, x_metric, metrics, title):
|
||||
"""Plots metrics in y_metrics list. One figure is plotted and each entry
|
||||
"""Plots metrics in y_metrics list. One figure is plotted and each entry
|
||||
in the list is plotted in a subplot (and sorted per settings).
|
||||
|
||||
For example:
|
||||
@ -295,136 +299,140 @@ def PlotFigure(settings, y_metrics, x_metric, metrics, title):
|
||||
|
||||
"""
|
||||
|
||||
plt.figure()
|
||||
plt.suptitle(title, fontsize='large', fontweight='bold')
|
||||
settings.sort()
|
||||
rows = len(settings)
|
||||
cols = 1
|
||||
pos = 1
|
||||
while pos <= rows:
|
||||
plt.rc('grid', color=GRID_COLOR)
|
||||
ax = plt.subplot(rows, cols, pos)
|
||||
plt.grid()
|
||||
plt.setp(ax.get_xticklabels(), visible=(pos == rows), fontsize='large')
|
||||
plt.setp(ax.get_yticklabels(), fontsize='large')
|
||||
setting = settings[pos - 1]
|
||||
Plot(y_metrics[pos - 1], x_metric, metrics[setting])
|
||||
if setting.startswith(WIDTH[1]):
|
||||
plt.title(setting, fontsize='medium')
|
||||
plt.legend(fontsize='large', loc='best')
|
||||
pos += 1
|
||||
plt.figure()
|
||||
plt.suptitle(title, fontsize='large', fontweight='bold')
|
||||
settings.sort()
|
||||
rows = len(settings)
|
||||
cols = 1
|
||||
pos = 1
|
||||
while pos <= rows:
|
||||
plt.rc('grid', color=GRID_COLOR)
|
||||
ax = plt.subplot(rows, cols, pos)
|
||||
plt.grid()
|
||||
plt.setp(ax.get_xticklabels(), visible=(pos == rows), fontsize='large')
|
||||
plt.setp(ax.get_yticklabels(), fontsize='large')
|
||||
setting = settings[pos - 1]
|
||||
Plot(y_metrics[pos - 1], x_metric, metrics[setting])
|
||||
if setting.startswith(WIDTH[1]):
|
||||
plt.title(setting, fontsize='medium')
|
||||
plt.legend(fontsize='large', loc='best')
|
||||
pos += 1
|
||||
|
||||
plt.xlabel(x_metric, fontsize='large')
|
||||
plt.subplots_adjust(left=0.06, right=0.98, bottom=0.05, top=0.94, hspace=0.08)
|
||||
plt.xlabel(x_metric, fontsize='large')
|
||||
plt.subplots_adjust(left=0.06,
|
||||
right=0.98,
|
||||
bottom=0.05,
|
||||
top=0.94,
|
||||
hspace=0.08)
|
||||
|
||||
|
||||
def GetTitle(filename, setting):
|
||||
title = ''
|
||||
if setting != CODEC_IMPLEMENTATION_NAME[1] and setting != CODEC_TYPE[1]:
|
||||
codec_types = ParseSetting(filename, CODEC_TYPE[1])
|
||||
for i in range(0, len(codec_types)):
|
||||
title += codec_types[i] + ', '
|
||||
title = ''
|
||||
if setting != CODEC_IMPLEMENTATION_NAME[1] and setting != CODEC_TYPE[1]:
|
||||
codec_types = ParseSetting(filename, CODEC_TYPE[1])
|
||||
for i in range(0, len(codec_types)):
|
||||
title += codec_types[i] + ', '
|
||||
|
||||
if setting != CORES[1]:
|
||||
cores = ParseSetting(filename, CORES[1])
|
||||
for i in range(0, len(cores)):
|
||||
title += cores[i].split('.')[0] + ', '
|
||||
if setting != CORES[1]:
|
||||
cores = ParseSetting(filename, CORES[1])
|
||||
for i in range(0, len(cores)):
|
||||
title += cores[i].split('.')[0] + ', '
|
||||
|
||||
if setting != FRAMERATE[1]:
|
||||
framerate = ParseSetting(filename, FRAMERATE[1])
|
||||
for i in range(0, len(framerate)):
|
||||
title += framerate[i].split('.')[0] + ', '
|
||||
if setting != FRAMERATE[1]:
|
||||
framerate = ParseSetting(filename, FRAMERATE[1])
|
||||
for i in range(0, len(framerate)):
|
||||
title += framerate[i].split('.')[0] + ', '
|
||||
|
||||
if (setting != CODEC_IMPLEMENTATION_NAME[1] and
|
||||
setting != ENCODER_IMPLEMENTATION_NAME[1]):
|
||||
enc_names = ParseSetting(filename, ENCODER_IMPLEMENTATION_NAME[1])
|
||||
for i in range(0, len(enc_names)):
|
||||
title += enc_names[i] + ', '
|
||||
if (setting != CODEC_IMPLEMENTATION_NAME[1]
|
||||
and setting != ENCODER_IMPLEMENTATION_NAME[1]):
|
||||
enc_names = ParseSetting(filename, ENCODER_IMPLEMENTATION_NAME[1])
|
||||
for i in range(0, len(enc_names)):
|
||||
title += enc_names[i] + ', '
|
||||
|
||||
if (setting != CODEC_IMPLEMENTATION_NAME[1] and
|
||||
setting != DECODER_IMPLEMENTATION_NAME[1]):
|
||||
dec_names = ParseSetting(filename, DECODER_IMPLEMENTATION_NAME[1])
|
||||
for i in range(0, len(dec_names)):
|
||||
title += dec_names[i] + ', '
|
||||
if (setting != CODEC_IMPLEMENTATION_NAME[1]
|
||||
and setting != DECODER_IMPLEMENTATION_NAME[1]):
|
||||
dec_names = ParseSetting(filename, DECODER_IMPLEMENTATION_NAME[1])
|
||||
for i in range(0, len(dec_names)):
|
||||
title += dec_names[i] + ', '
|
||||
|
||||
filenames = ParseSetting(filename, FILENAME[1])
|
||||
title += filenames[0].split('_')[0]
|
||||
filenames = ParseSetting(filename, FILENAME[1])
|
||||
title += filenames[0].split('_')[0]
|
||||
|
||||
num_frames = ParseSetting(filename, NUM_FRAMES[1])
|
||||
for i in range(0, len(num_frames)):
|
||||
title += ' (' + num_frames[i].split('.')[0] + ')'
|
||||
num_frames = ParseSetting(filename, NUM_FRAMES[1])
|
||||
for i in range(0, len(num_frames)):
|
||||
title += ' (' + num_frames[i].split('.')[0] + ')'
|
||||
|
||||
return title
|
||||
return title
|
||||
|
||||
|
||||
def ToString(input_list):
|
||||
return ToStringWithoutMetric(input_list, ('', ''))
|
||||
return ToStringWithoutMetric(input_list, ('', ''))
|
||||
|
||||
|
||||
def ToStringWithoutMetric(input_list, metric):
|
||||
i = 1
|
||||
output_str = ""
|
||||
for m in input_list:
|
||||
if m != metric:
|
||||
output_str = output_str + ("%s. %s\n" % (i, m[1]))
|
||||
i += 1
|
||||
return output_str
|
||||
i = 1
|
||||
output_str = ""
|
||||
for m in input_list:
|
||||
if m != metric:
|
||||
output_str = output_str + ("%s. %s\n" % (i, m[1]))
|
||||
i += 1
|
||||
return output_str
|
||||
|
||||
|
||||
def GetIdx(text_list):
|
||||
return int(raw_input(text_list)) - 1
|
||||
return int(raw_input(text_list)) - 1
|
||||
|
||||
|
||||
def main():
|
||||
filename = sys.argv[1]
|
||||
filename = sys.argv[1]
|
||||
|
||||
# Setup.
|
||||
idx_metric = GetIdx("Choose metric:\n0. All\n%s" % ToString(RESULTS))
|
||||
if idx_metric == -1:
|
||||
# Plot all metrics. One subplot for each metric.
|
||||
# Per subplot: metric vs bitrate (per resolution).
|
||||
cores = ParseSetting(filename, CORES[1])
|
||||
setting1 = CORES[1]
|
||||
setting2 = WIDTH[1]
|
||||
sub_keys = [cores[0]] * len(Y_METRICS)
|
||||
y_metrics = Y_METRICS
|
||||
x_metric = BITRATE[1]
|
||||
else:
|
||||
resolutions = ParseSetting(filename, WIDTH[1])
|
||||
idx = GetIdx("Select metric for x-axis:\n%s" % ToString(X_SETTINGS))
|
||||
if X_SETTINGS[idx] == BITRATE:
|
||||
idx = GetIdx("Plot per:\n%s" % ToStringWithoutMetric(SUBPLOT_SETTINGS,
|
||||
BITRATE))
|
||||
idx_setting = METRICS_TO_PARSE.index(SUBPLOT_SETTINGS[idx])
|
||||
# Plot one metric. One subplot for each resolution.
|
||||
# Per subplot: metric vs bitrate (per setting).
|
||||
setting1 = WIDTH[1]
|
||||
setting2 = METRICS_TO_PARSE[idx_setting][1]
|
||||
sub_keys = resolutions
|
||||
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
|
||||
x_metric = BITRATE[1]
|
||||
# Setup.
|
||||
idx_metric = GetIdx("Choose metric:\n0. All\n%s" % ToString(RESULTS))
|
||||
if idx_metric == -1:
|
||||
# Plot all metrics. One subplot for each metric.
|
||||
# Per subplot: metric vs bitrate (per resolution).
|
||||
cores = ParseSetting(filename, CORES[1])
|
||||
setting1 = CORES[1]
|
||||
setting2 = WIDTH[1]
|
||||
sub_keys = [cores[0]] * len(Y_METRICS)
|
||||
y_metrics = Y_METRICS
|
||||
x_metric = BITRATE[1]
|
||||
else:
|
||||
# Plot one metric. One subplot for each resolution.
|
||||
# Per subplot: metric vs setting (per bitrate).
|
||||
setting1 = WIDTH[1]
|
||||
setting2 = BITRATE[1]
|
||||
sub_keys = resolutions
|
||||
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
|
||||
x_metric = X_SETTINGS[idx][1]
|
||||
resolutions = ParseSetting(filename, WIDTH[1])
|
||||
idx = GetIdx("Select metric for x-axis:\n%s" % ToString(X_SETTINGS))
|
||||
if X_SETTINGS[idx] == BITRATE:
|
||||
idx = GetIdx("Plot per:\n%s" %
|
||||
ToStringWithoutMetric(SUBPLOT_SETTINGS, BITRATE))
|
||||
idx_setting = METRICS_TO_PARSE.index(SUBPLOT_SETTINGS[idx])
|
||||
# Plot one metric. One subplot for each resolution.
|
||||
# Per subplot: metric vs bitrate (per setting).
|
||||
setting1 = WIDTH[1]
|
||||
setting2 = METRICS_TO_PARSE[idx_setting][1]
|
||||
sub_keys = resolutions
|
||||
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
|
||||
x_metric = BITRATE[1]
|
||||
else:
|
||||
# Plot one metric. One subplot for each resolution.
|
||||
# Per subplot: metric vs setting (per bitrate).
|
||||
setting1 = WIDTH[1]
|
||||
setting2 = BITRATE[1]
|
||||
sub_keys = resolutions
|
||||
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
|
||||
x_metric = X_SETTINGS[idx][1]
|
||||
|
||||
metrics = ParseMetrics(filename, setting1, setting2)
|
||||
metrics = ParseMetrics(filename, setting1, setting2)
|
||||
|
||||
# Stretch fig size.
|
||||
figsize = plt.rcParams["figure.figsize"]
|
||||
figsize[0] *= FIG_SIZE_SCALE_FACTOR_X
|
||||
figsize[1] *= FIG_SIZE_SCALE_FACTOR_Y
|
||||
plt.rcParams["figure.figsize"] = figsize
|
||||
# Stretch fig size.
|
||||
figsize = plt.rcParams["figure.figsize"]
|
||||
figsize[0] *= FIG_SIZE_SCALE_FACTOR_X
|
||||
figsize[1] *= FIG_SIZE_SCALE_FACTOR_Y
|
||||
plt.rcParams["figure.figsize"] = figsize
|
||||
|
||||
PlotFigure(sub_keys, y_metrics, x_metric, metrics,
|
||||
GetTitle(filename, setting2))
|
||||
PlotFigure(sub_keys, y_metrics, x_metric, metrics,
|
||||
GetTitle(filename, setting2))
|
||||
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -20,146 +20,145 @@ from presubmit_test_mocks import MockInputApi, MockOutputApi, MockFile, MockChan
|
||||
|
||||
|
||||
class CheckBugEntryFieldTest(unittest.TestCase):
|
||||
def testCommitMessageBugEntryWithNoError(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['webrtc:1234'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
def testCommitMessageBugEntryWithNoError(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['webrtc:1234'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
def testCommitMessageBugEntryReturnError(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['webrtc:1234', 'webrtc=4321'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(('Bogus Bug entry: webrtc=4321. Please specify'
|
||||
' the issue tracker prefix and the issue number,'
|
||||
' separated by a colon, e.g. webrtc:123 or'
|
||||
' chromium:12345.'), str(errors[0]))
|
||||
def testCommitMessageBugEntryReturnError(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['webrtc:1234', 'webrtc=4321'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(('Bogus Bug entry: webrtc=4321. Please specify'
|
||||
' the issue tracker prefix and the issue number,'
|
||||
' separated by a colon, e.g. webrtc:123 or'
|
||||
' chromium:12345.'), str(errors[0]))
|
||||
|
||||
def testCommitMessageBugEntryWithoutPrefix(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['1234'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(('Bug entry requires issue tracker prefix, '
|
||||
'e.g. webrtc:1234'), str(errors[0]))
|
||||
def testCommitMessageBugEntryWithoutPrefix(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['1234'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(('Bug entry requires issue tracker prefix, '
|
||||
'e.g. webrtc:1234'), str(errors[0]))
|
||||
|
||||
def testCommitMessageBugEntryIsNone(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['None'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
def testCommitMessageBugEntryIsNone(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change = MockChange([], ['None'])
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
def testCommitMessageBugEntrySupportInternalBugReference(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change.BUG = 'b/12345'
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
mock_input_api.change.BUG = 'b/12345, webrtc:1234'
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
def testCommitMessageBugEntrySupportInternalBugReference(self):
|
||||
mock_input_api = MockInputApi()
|
||||
mock_output_api = MockOutputApi()
|
||||
mock_input_api.change.BUG = 'b/12345'
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
mock_input_api.change.BUG = 'b/12345, webrtc:1234'
|
||||
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
|
||||
mock_output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
|
||||
class CheckNewlineAtTheEndOfProtoFilesTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.proto_file_path = os.path.join(self.tmp_dir, 'foo.proto')
|
||||
self.input_api = MockInputApi()
|
||||
self.output_api = MockOutputApi()
|
||||
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.proto_file_path = os.path.join(self.tmp_dir, 'foo.proto')
|
||||
self.input_api = MockInputApi()
|
||||
self.output_api = MockOutputApi()
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
def testErrorIfProtoFileDoesNotEndWithNewline(self):
|
||||
self._GenerateProtoWithoutNewlineAtTheEnd()
|
||||
self.input_api.files = [MockFile(self.proto_file_path)]
|
||||
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(
|
||||
self.input_api, self.output_api, lambda x: True)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(
|
||||
'File %s must end with exactly one newline.' %
|
||||
self.proto_file_path, str(errors[0]))
|
||||
|
||||
def testErrorIfProtoFileDoesNotEndWithNewline(self):
|
||||
self._GenerateProtoWithoutNewlineAtTheEnd()
|
||||
self.input_api.files = [MockFile(self.proto_file_path)]
|
||||
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(self.input_api,
|
||||
self.output_api,
|
||||
lambda x: True)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertEqual(
|
||||
'File %s must end with exactly one newline.' % self.proto_file_path,
|
||||
str(errors[0]))
|
||||
def testNoErrorIfProtoFileEndsWithNewline(self):
|
||||
self._GenerateProtoWithNewlineAtTheEnd()
|
||||
self.input_api.files = [MockFile(self.proto_file_path)]
|
||||
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(
|
||||
self.input_api, self.output_api, lambda x: True)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
def testNoErrorIfProtoFileEndsWithNewline(self):
|
||||
self._GenerateProtoWithNewlineAtTheEnd()
|
||||
self.input_api.files = [MockFile(self.proto_file_path)]
|
||||
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(self.input_api,
|
||||
self.output_api,
|
||||
lambda x: True)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
def _GenerateProtoWithNewlineAtTheEnd(self):
|
||||
with open(self.proto_file_path, 'w') as f:
|
||||
f.write(textwrap.dedent("""
|
||||
def _GenerateProtoWithNewlineAtTheEnd(self):
|
||||
with open(self.proto_file_path, 'w') as f:
|
||||
f.write(
|
||||
textwrap.dedent("""
|
||||
syntax = "proto2";
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
package webrtc.audioproc;
|
||||
"""))
|
||||
|
||||
def _GenerateProtoWithoutNewlineAtTheEnd(self):
|
||||
with open(self.proto_file_path, 'w') as f:
|
||||
f.write(textwrap.dedent("""
|
||||
def _GenerateProtoWithoutNewlineAtTheEnd(self):
|
||||
with open(self.proto_file_path, 'w') as f:
|
||||
f.write(
|
||||
textwrap.dedent("""
|
||||
syntax = "proto2";
|
||||
option optimize_for = LITE_RUNTIME;
|
||||
package webrtc.audioproc;"""))
|
||||
|
||||
|
||||
class CheckNoMixingSourcesTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.file_path = os.path.join(self.tmp_dir, 'BUILD.gn')
|
||||
self.input_api = MockInputApi()
|
||||
self.output_api = MockOutputApi()
|
||||
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.file_path = os.path.join(self.tmp_dir, 'BUILD.gn')
|
||||
self.input_api = MockInputApi()
|
||||
self.output_api = MockOutputApi()
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
def testErrorIfCAndCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.cc', 'bar.h'])
|
||||
|
||||
def testErrorIfCAndCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.cc', 'bar.h'])
|
||||
def testErrorIfCAndObjCAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.m', 'bar.h'])
|
||||
|
||||
def testErrorIfCAndObjCAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.m', 'bar.h'])
|
||||
def testErrorIfCAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.mm', 'bar.h'])
|
||||
|
||||
def testErrorIfCAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.mm', 'bar.h'])
|
||||
def testErrorIfCppAndObjCAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.m', 'bar.h'])
|
||||
|
||||
def testErrorIfCppAndObjCAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.m', 'bar.h'])
|
||||
def testErrorIfCppAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.mm', 'bar.h'])
|
||||
|
||||
def testErrorIfCppAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.mm', 'bar.h'])
|
||||
def testNoErrorIfOnlyC(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.c', 'bar.c', 'bar.h'])
|
||||
|
||||
def testNoErrorIfOnlyC(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.c', 'bar.c', 'bar.h'])
|
||||
def testNoErrorIfOnlyCpp(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.cc', 'bar.cc', 'bar.h'])
|
||||
|
||||
def testNoErrorIfOnlyCpp(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.cc', 'bar.cc', 'bar.h'])
|
||||
def testNoErrorIfOnlyObjC(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.m', 'bar.h'])
|
||||
|
||||
def testNoErrorIfOnlyObjC(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.m', 'bar.h'])
|
||||
def testNoErrorIfOnlyObjCpp(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.mm', 'bar.mm', 'bar.h'])
|
||||
|
||||
def testNoErrorIfOnlyObjCpp(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.mm', 'bar.mm', 'bar.h'])
|
||||
def testNoErrorIfObjCAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.mm', 'bar.h'])
|
||||
|
||||
def testNoErrorIfObjCAndObjCppAreMixed(self):
|
||||
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.mm', 'bar.h'])
|
||||
|
||||
def testNoErrorIfSourcesAreInExclusiveIfBranches(self):
|
||||
self._GenerateBuildFile(textwrap.dedent("""
|
||||
def testNoErrorIfSourcesAreInExclusiveIfBranches(self):
|
||||
self._GenerateBuildFile(
|
||||
textwrap.dedent("""
|
||||
rtc_library("bar_foo") {
|
||||
if (is_win) {
|
||||
sources = [
|
||||
@ -185,14 +184,15 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
"""))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(0, len(errors))
|
||||
|
||||
def testErrorIfSourcesAreNotInExclusiveIfBranches(self):
|
||||
self._GenerateBuildFile(textwrap.dedent("""
|
||||
def testErrorIfSourcesAreNotInExclusiveIfBranches(self):
|
||||
self._GenerateBuildFile(
|
||||
textwrap.dedent("""
|
||||
rtc_library("bar_foo") {
|
||||
if (is_win) {
|
||||
sources = [
|
||||
@ -224,21 +224,23 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
"""))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertTrue('bar.cc' in str(errors[0]))
|
||||
self.assertTrue('bar.mm' in str(errors[0]))
|
||||
self.assertTrue('foo.cc' in str(errors[0]))
|
||||
self.assertTrue('foo.mm' in str(errors[0]))
|
||||
self.assertTrue('bar.m' in str(errors[0]))
|
||||
self.assertTrue('bar.c' in str(errors[0]))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(1, len(errors))
|
||||
self.assertTrue('bar.cc' in str(errors[0]))
|
||||
self.assertTrue('bar.mm' in str(errors[0]))
|
||||
self.assertTrue('foo.cc' in str(errors[0]))
|
||||
self.assertTrue('foo.mm' in str(errors[0]))
|
||||
self.assertTrue('bar.m' in str(errors[0]))
|
||||
self.assertTrue('bar.c' in str(errors[0]))
|
||||
|
||||
def _AssertNumberOfErrorsWithSources(self, number_of_errors, sources):
|
||||
assert len(sources) == 3, 'This function accepts a list of 3 source files'
|
||||
self._GenerateBuildFile(textwrap.dedent("""
|
||||
def _AssertNumberOfErrorsWithSources(self, number_of_errors, sources):
|
||||
assert len(
|
||||
sources) == 3, 'This function accepts a list of 3 source files'
|
||||
self._GenerateBuildFile(
|
||||
textwrap.dedent("""
|
||||
rtc_static_library("bar_foo") {
|
||||
sources = [
|
||||
"%s",
|
||||
@ -254,20 +256,20 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
""" % (tuple(sources) * 2)))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(number_of_errors, len(errors))
|
||||
if number_of_errors == 1:
|
||||
for source in sources:
|
||||
if not source.endswith('.h'):
|
||||
self.assertTrue(source in str(errors[0]))
|
||||
self.input_api.files = [MockFile(self.file_path)]
|
||||
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
|
||||
[MockFile(self.file_path)],
|
||||
self.output_api)
|
||||
self.assertEqual(number_of_errors, len(errors))
|
||||
if number_of_errors == 1:
|
||||
for source in sources:
|
||||
if not source.endswith('.h'):
|
||||
self.assertTrue(source in str(errors[0]))
|
||||
|
||||
def _GenerateBuildFile(self, content):
|
||||
with open(self.file_path, 'w') as f:
|
||||
f.write(content)
|
||||
def _GenerateBuildFile(self, content):
|
||||
with open(self.file_path, 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -14,118 +14,125 @@ import re
|
||||
|
||||
|
||||
class MockInputApi(object):
|
||||
"""Mock class for the InputApi class.
|
||||
"""Mock class for the InputApi class.
|
||||
|
||||
This class can be used for unittests for presubmit by initializing the files
|
||||
attribute as the list of changed files.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.change = MockChange([], [])
|
||||
self.files = []
|
||||
self.presubmit_local_path = os.path.dirname(__file__)
|
||||
def __init__(self):
|
||||
self.change = MockChange([], [])
|
||||
self.files = []
|
||||
self.presubmit_local_path = os.path.dirname(__file__)
|
||||
|
||||
def AffectedSourceFiles(self, file_filter=None):
|
||||
return self.AffectedFiles(file_filter=file_filter)
|
||||
def AffectedSourceFiles(self, file_filter=None):
|
||||
return self.AffectedFiles(file_filter=file_filter)
|
||||
|
||||
def AffectedFiles(self, file_filter=None, include_deletes=False):
|
||||
# pylint: disable=unused-argument
|
||||
return self.files
|
||||
def AffectedFiles(self, file_filter=None, include_deletes=False):
|
||||
# pylint: disable=unused-argument
|
||||
return self.files
|
||||
|
||||
@classmethod
|
||||
def FilterSourceFile(cls, affected_file, files_to_check=(),
|
||||
files_to_skip=()):
|
||||
# pylint: disable=unused-argument
|
||||
return True
|
||||
@classmethod
|
||||
def FilterSourceFile(cls,
|
||||
affected_file,
|
||||
files_to_check=(),
|
||||
files_to_skip=()):
|
||||
# pylint: disable=unused-argument
|
||||
return True
|
||||
|
||||
def PresubmitLocalPath(self):
|
||||
return self.presubmit_local_path
|
||||
def PresubmitLocalPath(self):
|
||||
return self.presubmit_local_path
|
||||
|
||||
def ReadFile(self, affected_file, mode='rU'):
|
||||
filename = affected_file.AbsoluteLocalPath()
|
||||
for f in self.files:
|
||||
if f.LocalPath() == filename:
|
||||
with open(filename, mode) as f:
|
||||
return f.read()
|
||||
# Otherwise, file is not in our mock API.
|
||||
raise IOError, "No such file or directory: '%s'" % filename
|
||||
def ReadFile(self, affected_file, mode='rU'):
|
||||
filename = affected_file.AbsoluteLocalPath()
|
||||
for f in self.files:
|
||||
if f.LocalPath() == filename:
|
||||
with open(filename, mode) as f:
|
||||
return f.read()
|
||||
# Otherwise, file is not in our mock API.
|
||||
raise IOError, "No such file or directory: '%s'" % filename
|
||||
|
||||
|
||||
class MockOutputApi(object):
|
||||
"""Mock class for the OutputApi class.
|
||||
"""Mock class for the OutputApi class.
|
||||
|
||||
An instance of this class can be passed to presubmit unittests for outputing
|
||||
various types of results.
|
||||
"""
|
||||
|
||||
class PresubmitResult(object):
|
||||
def __init__(self, message, items=None, long_text=''):
|
||||
self.message = message
|
||||
self.items = items
|
||||
self.long_text = long_text
|
||||
class PresubmitResult(object):
|
||||
def __init__(self, message, items=None, long_text=''):
|
||||
self.message = message
|
||||
self.items = items
|
||||
self.long_text = long_text
|
||||
|
||||
def __repr__(self):
|
||||
return self.message
|
||||
def __repr__(self):
|
||||
return self.message
|
||||
|
||||
class PresubmitError(PresubmitResult):
|
||||
def __init__(self, message, items=None, long_text=''):
|
||||
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
|
||||
self.type = 'error'
|
||||
class PresubmitError(PresubmitResult):
|
||||
def __init__(self, message, items=None, long_text=''):
|
||||
MockOutputApi.PresubmitResult.__init__(self, message, items,
|
||||
long_text)
|
||||
self.type = 'error'
|
||||
|
||||
|
||||
class MockChange(object):
|
||||
"""Mock class for Change class.
|
||||
"""Mock class for Change class.
|
||||
|
||||
This class can be used in presubmit unittests to mock the query of the
|
||||
current change.
|
||||
"""
|
||||
|
||||
def __init__(self, changed_files, bugs_from_description, tags=None):
|
||||
self._changed_files = changed_files
|
||||
self._bugs_from_description = bugs_from_description
|
||||
self.tags = dict() if not tags else tags
|
||||
def __init__(self, changed_files, bugs_from_description, tags=None):
|
||||
self._changed_files = changed_files
|
||||
self._bugs_from_description = bugs_from_description
|
||||
self.tags = dict() if not tags else tags
|
||||
|
||||
def BugsFromDescription(self):
|
||||
return self._bugs_from_description
|
||||
def BugsFromDescription(self):
|
||||
return self._bugs_from_description
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Return tags directly as attributes on the object."""
|
||||
if not re.match(r"^[A-Z_]*$", attr):
|
||||
raise AttributeError(self, attr)
|
||||
return self.tags.get(attr)
|
||||
def __getattr__(self, attr):
|
||||
"""Return tags directly as attributes on the object."""
|
||||
if not re.match(r"^[A-Z_]*$", attr):
|
||||
raise AttributeError(self, attr)
|
||||
return self.tags.get(attr)
|
||||
|
||||
|
||||
class MockFile(object):
|
||||
"""Mock class for the File class.
|
||||
"""Mock class for the File class.
|
||||
|
||||
This class can be used to form the mock list of changed files in
|
||||
MockInputApi for presubmit unittests.
|
||||
"""
|
||||
|
||||
def __init__(self, local_path, new_contents=None, old_contents=None,
|
||||
action='A'):
|
||||
if new_contents is None:
|
||||
new_contents = ["Data"]
|
||||
self._local_path = local_path
|
||||
self._new_contents = new_contents
|
||||
self._changed_contents = [(i + 1, l) for i, l in enumerate(new_contents)]
|
||||
self._action = action
|
||||
self._old_contents = old_contents
|
||||
def __init__(self,
|
||||
local_path,
|
||||
new_contents=None,
|
||||
old_contents=None,
|
||||
action='A'):
|
||||
if new_contents is None:
|
||||
new_contents = ["Data"]
|
||||
self._local_path = local_path
|
||||
self._new_contents = new_contents
|
||||
self._changed_contents = [(i + 1, l)
|
||||
for i, l in enumerate(new_contents)]
|
||||
self._action = action
|
||||
self._old_contents = old_contents
|
||||
|
||||
def Action(self):
|
||||
return self._action
|
||||
def Action(self):
|
||||
return self._action
|
||||
|
||||
def ChangedContents(self):
|
||||
return self._changed_contents
|
||||
def ChangedContents(self):
|
||||
return self._changed_contents
|
||||
|
||||
def NewContents(self):
|
||||
return self._new_contents
|
||||
def NewContents(self):
|
||||
return self._new_contents
|
||||
|
||||
def LocalPath(self):
|
||||
return self._local_path
|
||||
def LocalPath(self):
|
||||
return self._local_path
|
||||
|
||||
def AbsoluteLocalPath(self):
|
||||
return self._local_path
|
||||
def AbsoluteLocalPath(self):
|
||||
return self._local_path
|
||||
|
||||
def OldContents(self):
|
||||
return self._old_contents
|
||||
def OldContents(self):
|
||||
return self._old_contents
|
||||
|
||||
@ -18,7 +18,6 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Chrome browsertests will throw away stderr; avoid that output gets lost.
|
||||
@ -26,131 +25,154 @@ sys.stderr = sys.stdout
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
"""Registers the command-line options."""
|
||||
usage = 'usage: %prog [options]'
|
||||
parser = optparse.OptionParser(usage=usage)
|
||||
"""Registers the command-line options."""
|
||||
usage = 'usage: %prog [options]'
|
||||
parser = optparse.OptionParser(usage=usage)
|
||||
|
||||
parser.add_option('--label', type='string', default='MY_TEST',
|
||||
help=('Label of the test, used to identify different '
|
||||
'tests. Default: %default'))
|
||||
parser.add_option('--ref_video', type='string',
|
||||
help='Reference video to compare with (YUV).')
|
||||
parser.add_option('--test_video', type='string',
|
||||
help=('Test video to be compared with the reference '
|
||||
'video (YUV).'))
|
||||
parser.add_option('--frame_analyzer', type='string',
|
||||
help='Path to the frame analyzer executable.')
|
||||
parser.add_option('--aligned_output_file', type='string',
|
||||
help='Path for output aligned YUV or Y4M file.')
|
||||
parser.add_option('--vmaf', type='string',
|
||||
help='Path to VMAF executable.')
|
||||
parser.add_option('--vmaf_model', type='string',
|
||||
help='Path to VMAF model.')
|
||||
parser.add_option('--vmaf_phone_model', action='store_true',
|
||||
help='Whether to use phone model in VMAF.')
|
||||
parser.add_option('--yuv_frame_width', type='int', default=640,
|
||||
help='Width of the YUV file\'s frames. Default: %default')
|
||||
parser.add_option('--yuv_frame_height', type='int', default=480,
|
||||
help='Height of the YUV file\'s frames. Default: %default')
|
||||
parser.add_option('--chartjson_result_file', type='str', default=None,
|
||||
help='Where to store perf results in chartjson format.')
|
||||
options, _ = parser.parse_args()
|
||||
parser.add_option('--label',
|
||||
type='string',
|
||||
default='MY_TEST',
|
||||
help=('Label of the test, used to identify different '
|
||||
'tests. Default: %default'))
|
||||
parser.add_option('--ref_video',
|
||||
type='string',
|
||||
help='Reference video to compare with (YUV).')
|
||||
parser.add_option('--test_video',
|
||||
type='string',
|
||||
help=('Test video to be compared with the reference '
|
||||
'video (YUV).'))
|
||||
parser.add_option('--frame_analyzer',
|
||||
type='string',
|
||||
help='Path to the frame analyzer executable.')
|
||||
parser.add_option('--aligned_output_file',
|
||||
type='string',
|
||||
help='Path for output aligned YUV or Y4M file.')
|
||||
parser.add_option('--vmaf', type='string', help='Path to VMAF executable.')
|
||||
parser.add_option('--vmaf_model',
|
||||
type='string',
|
||||
help='Path to VMAF model.')
|
||||
parser.add_option('--vmaf_phone_model',
|
||||
action='store_true',
|
||||
help='Whether to use phone model in VMAF.')
|
||||
parser.add_option(
|
||||
'--yuv_frame_width',
|
||||
type='int',
|
||||
default=640,
|
||||
help='Width of the YUV file\'s frames. Default: %default')
|
||||
parser.add_option(
|
||||
'--yuv_frame_height',
|
||||
type='int',
|
||||
default=480,
|
||||
help='Height of the YUV file\'s frames. Default: %default')
|
||||
parser.add_option('--chartjson_result_file',
|
||||
type='str',
|
||||
default=None,
|
||||
help='Where to store perf results in chartjson format.')
|
||||
options, _ = parser.parse_args()
|
||||
|
||||
if not options.ref_video:
|
||||
parser.error('You must provide a path to the reference video!')
|
||||
if not os.path.exists(options.ref_video):
|
||||
parser.error('Cannot find the reference video at %s' % options.ref_video)
|
||||
if not options.ref_video:
|
||||
parser.error('You must provide a path to the reference video!')
|
||||
if not os.path.exists(options.ref_video):
|
||||
parser.error('Cannot find the reference video at %s' %
|
||||
options.ref_video)
|
||||
|
||||
if not options.test_video:
|
||||
parser.error('You must provide a path to the test video!')
|
||||
if not os.path.exists(options.test_video):
|
||||
parser.error('Cannot find the test video at %s' % options.test_video)
|
||||
if not options.test_video:
|
||||
parser.error('You must provide a path to the test video!')
|
||||
if not os.path.exists(options.test_video):
|
||||
parser.error('Cannot find the test video at %s' % options.test_video)
|
||||
|
||||
if not options.frame_analyzer:
|
||||
parser.error('You must provide the path to the frame analyzer executable!')
|
||||
if not os.path.exists(options.frame_analyzer):
|
||||
parser.error('Cannot find frame analyzer executable at %s!' %
|
||||
options.frame_analyzer)
|
||||
if not options.frame_analyzer:
|
||||
parser.error(
|
||||
'You must provide the path to the frame analyzer executable!')
|
||||
if not os.path.exists(options.frame_analyzer):
|
||||
parser.error('Cannot find frame analyzer executable at %s!' %
|
||||
options.frame_analyzer)
|
||||
|
||||
if options.vmaf and not options.vmaf_model:
|
||||
parser.error('You must provide a path to a VMAF model to use VMAF.')
|
||||
if options.vmaf and not options.vmaf_model:
|
||||
parser.error('You must provide a path to a VMAF model to use VMAF.')
|
||||
|
||||
return options
|
||||
|
||||
return options
|
||||
|
||||
def _DevNull():
|
||||
"""On Windows, sometimes the inherited stdin handle from the parent process
|
||||
"""On Windows, sometimes the inherited stdin handle from the parent process
|
||||
fails. Workaround this by passing null to stdin to the subprocesses commands.
|
||||
This function can be used to create the null file handler.
|
||||
"""
|
||||
return open(os.devnull, 'r')
|
||||
return open(os.devnull, 'r')
|
||||
|
||||
|
||||
def _RunFrameAnalyzer(options, yuv_directory=None):
|
||||
"""Run frame analyzer to compare the videos and print output."""
|
||||
cmd = [
|
||||
options.frame_analyzer,
|
||||
'--label=%s' % options.label,
|
||||
'--reference_file=%s' % options.ref_video,
|
||||
'--test_file=%s' % options.test_video,
|
||||
'--width=%d' % options.yuv_frame_width,
|
||||
'--height=%d' % options.yuv_frame_height,
|
||||
]
|
||||
if options.chartjson_result_file:
|
||||
cmd.append('--chartjson_result_file=%s' % options.chartjson_result_file)
|
||||
if options.aligned_output_file:
|
||||
cmd.append('--aligned_output_file=%s' % options.aligned_output_file)
|
||||
if yuv_directory:
|
||||
cmd.append('--yuv_directory=%s' % yuv_directory)
|
||||
frame_analyzer = subprocess.Popen(cmd, stdin=_DevNull(),
|
||||
stdout=sys.stdout, stderr=sys.stderr)
|
||||
frame_analyzer.wait()
|
||||
if frame_analyzer.returncode != 0:
|
||||
print('Failed to run frame analyzer.')
|
||||
return frame_analyzer.returncode
|
||||
"""Run frame analyzer to compare the videos and print output."""
|
||||
cmd = [
|
||||
options.frame_analyzer,
|
||||
'--label=%s' % options.label,
|
||||
'--reference_file=%s' % options.ref_video,
|
||||
'--test_file=%s' % options.test_video,
|
||||
'--width=%d' % options.yuv_frame_width,
|
||||
'--height=%d' % options.yuv_frame_height,
|
||||
]
|
||||
if options.chartjson_result_file:
|
||||
cmd.append('--chartjson_result_file=%s' %
|
||||
options.chartjson_result_file)
|
||||
if options.aligned_output_file:
|
||||
cmd.append('--aligned_output_file=%s' % options.aligned_output_file)
|
||||
if yuv_directory:
|
||||
cmd.append('--yuv_directory=%s' % yuv_directory)
|
||||
frame_analyzer = subprocess.Popen(cmd,
|
||||
stdin=_DevNull(),
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr)
|
||||
frame_analyzer.wait()
|
||||
if frame_analyzer.returncode != 0:
|
||||
print('Failed to run frame analyzer.')
|
||||
return frame_analyzer.returncode
|
||||
|
||||
|
||||
def _RunVmaf(options, yuv_directory, logfile):
|
||||
""" Run VMAF to compare videos and print output.
|
||||
""" Run VMAF to compare videos and print output.
|
||||
|
||||
The yuv_directory is assumed to have been populated with a reference and test
|
||||
video in .yuv format, with names according to the label.
|
||||
"""
|
||||
cmd = [
|
||||
options.vmaf,
|
||||
'yuv420p',
|
||||
str(options.yuv_frame_width),
|
||||
str(options.yuv_frame_height),
|
||||
os.path.join(yuv_directory, "ref.yuv"),
|
||||
os.path.join(yuv_directory, "test.yuv"),
|
||||
options.vmaf_model,
|
||||
'--log',
|
||||
logfile,
|
||||
'--log-fmt',
|
||||
'json',
|
||||
]
|
||||
if options.vmaf_phone_model:
|
||||
cmd.append('--phone-model')
|
||||
cmd = [
|
||||
options.vmaf,
|
||||
'yuv420p',
|
||||
str(options.yuv_frame_width),
|
||||
str(options.yuv_frame_height),
|
||||
os.path.join(yuv_directory, "ref.yuv"),
|
||||
os.path.join(yuv_directory, "test.yuv"),
|
||||
options.vmaf_model,
|
||||
'--log',
|
||||
logfile,
|
||||
'--log-fmt',
|
||||
'json',
|
||||
]
|
||||
if options.vmaf_phone_model:
|
||||
cmd.append('--phone-model')
|
||||
|
||||
vmaf = subprocess.Popen(cmd, stdin=_DevNull(),
|
||||
stdout=sys.stdout, stderr=sys.stderr)
|
||||
vmaf.wait()
|
||||
if vmaf.returncode != 0:
|
||||
print('Failed to run VMAF.')
|
||||
return 1
|
||||
vmaf = subprocess.Popen(cmd,
|
||||
stdin=_DevNull(),
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr)
|
||||
vmaf.wait()
|
||||
if vmaf.returncode != 0:
|
||||
print('Failed to run VMAF.')
|
||||
return 1
|
||||
|
||||
# Read per-frame scores from VMAF output and print.
|
||||
with open(logfile) as f:
|
||||
vmaf_data = json.load(f)
|
||||
vmaf_scores = []
|
||||
for frame in vmaf_data['frames']:
|
||||
vmaf_scores.append(frame['metrics']['vmaf'])
|
||||
print('RESULT VMAF: %s=' % options.label, vmaf_scores)
|
||||
# Read per-frame scores from VMAF output and print.
|
||||
with open(logfile) as f:
|
||||
vmaf_data = json.load(f)
|
||||
vmaf_scores = []
|
||||
for frame in vmaf_data['frames']:
|
||||
vmaf_scores.append(frame['metrics']['vmaf'])
|
||||
print('RESULT VMAF: %s=' % options.label, vmaf_scores)
|
||||
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
"""The main function.
|
||||
"""The main function.
|
||||
|
||||
A simple invocation is:
|
||||
./webrtc/rtc_tools/compare_videos.py
|
||||
@ -161,27 +183,28 @@ def main():
|
||||
Running vmaf requires the following arguments:
|
||||
--vmaf, --vmaf_model, --yuv_frame_width, --yuv_frame_height
|
||||
"""
|
||||
options = _ParseArgs()
|
||||
options = _ParseArgs()
|
||||
|
||||
if options.vmaf:
|
||||
try:
|
||||
# Directory to save temporary YUV files for VMAF in frame_analyzer.
|
||||
yuv_directory = tempfile.mkdtemp()
|
||||
_, vmaf_logfile = tempfile.mkstemp()
|
||||
if options.vmaf:
|
||||
try:
|
||||
# Directory to save temporary YUV files for VMAF in frame_analyzer.
|
||||
yuv_directory = tempfile.mkdtemp()
|
||||
_, vmaf_logfile = tempfile.mkstemp()
|
||||
|
||||
# Run frame analyzer to compare the videos and print output.
|
||||
if _RunFrameAnalyzer(options, yuv_directory=yuv_directory) != 0:
|
||||
return 1
|
||||
# Run frame analyzer to compare the videos and print output.
|
||||
if _RunFrameAnalyzer(options, yuv_directory=yuv_directory) != 0:
|
||||
return 1
|
||||
|
||||
# Run VMAF for further video comparison and print output.
|
||||
return _RunVmaf(options, yuv_directory, vmaf_logfile)
|
||||
finally:
|
||||
shutil.rmtree(yuv_directory)
|
||||
os.remove(vmaf_logfile)
|
||||
else:
|
||||
return _RunFrameAnalyzer(options)
|
||||
# Run VMAF for further video comparison and print output.
|
||||
return _RunVmaf(options, yuv_directory, vmaf_logfile)
|
||||
finally:
|
||||
shutil.rmtree(yuv_directory)
|
||||
os.remove(vmaf_logfile)
|
||||
else:
|
||||
return _RunFrameAnalyzer(options)
|
||||
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -39,52 +39,59 @@ MICROSECONDS_IN_SECOND = 1e6
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Plots metrics exported from WebRTC perf tests')
|
||||
parser.add_argument('-m', '--metrics', type=str, nargs='*',
|
||||
help='Metrics to plot. If nothing specified then will plot all available')
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Plots metrics exported from WebRTC perf tests')
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
'--metrics',
|
||||
type=str,
|
||||
nargs='*',
|
||||
help=
|
||||
'Metrics to plot. If nothing specified then will plot all available')
|
||||
args = parser.parse_args()
|
||||
|
||||
metrics_to_plot = set()
|
||||
if args.metrics:
|
||||
for metric in args.metrics:
|
||||
metrics_to_plot.add(metric)
|
||||
metrics_to_plot = set()
|
||||
if args.metrics:
|
||||
for metric in args.metrics:
|
||||
metrics_to_plot.add(metric)
|
||||
|
||||
metrics = []
|
||||
for line in fileinput.input('-'):
|
||||
line = line.strip()
|
||||
if line.startswith(LINE_PREFIX):
|
||||
line = line.replace(LINE_PREFIX, '')
|
||||
metrics.append(json.loads(line))
|
||||
else:
|
||||
print line
|
||||
metrics = []
|
||||
for line in fileinput.input('-'):
|
||||
line = line.strip()
|
||||
if line.startswith(LINE_PREFIX):
|
||||
line = line.replace(LINE_PREFIX, '')
|
||||
metrics.append(json.loads(line))
|
||||
else:
|
||||
print line
|
||||
|
||||
for metric in metrics:
|
||||
if len(metrics_to_plot) > 0 and metric[GRAPH_NAME] not in metrics_to_plot:
|
||||
continue
|
||||
for metric in metrics:
|
||||
if len(metrics_to_plot
|
||||
) > 0 and metric[GRAPH_NAME] not in metrics_to_plot:
|
||||
continue
|
||||
|
||||
figure = plt.figure()
|
||||
figure.canvas.set_window_title(metric[TRACE_NAME])
|
||||
figure = plt.figure()
|
||||
figure.canvas.set_window_title(metric[TRACE_NAME])
|
||||
|
||||
x_values = []
|
||||
y_values = []
|
||||
start_x = None
|
||||
samples = metric['samples']
|
||||
samples.sort(key=lambda x: x['time'])
|
||||
for sample in samples:
|
||||
if start_x is None:
|
||||
start_x = sample['time']
|
||||
# Time is us, we want to show it in seconds.
|
||||
x_values.append((sample['time'] - start_x) / MICROSECONDS_IN_SECOND)
|
||||
y_values.append(sample['value'])
|
||||
x_values = []
|
||||
y_values = []
|
||||
start_x = None
|
||||
samples = metric['samples']
|
||||
samples.sort(key=lambda x: x['time'])
|
||||
for sample in samples:
|
||||
if start_x is None:
|
||||
start_x = sample['time']
|
||||
# Time is us, we want to show it in seconds.
|
||||
x_values.append(
|
||||
(sample['time'] - start_x) / MICROSECONDS_IN_SECOND)
|
||||
y_values.append(sample['value'])
|
||||
|
||||
plt.ylabel('%s (%s)' % (metric[GRAPH_NAME], metric[UNITS]))
|
||||
plt.xlabel('time (s)')
|
||||
plt.title(metric[GRAPH_NAME])
|
||||
plt.plot(x_values, y_values)
|
||||
plt.ylabel('%s (%s)' % (metric[GRAPH_NAME], metric[UNITS]))
|
||||
plt.xlabel('time (s)')
|
||||
plt.title(metric[GRAPH_NAME])
|
||||
plt.plot(x_values, y_values)
|
||||
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -10,21 +10,21 @@
|
||||
import network_tester_config_pb2
|
||||
|
||||
|
||||
def AddConfig(all_configs,
|
||||
packet_send_interval_ms,
|
||||
packet_size,
|
||||
def AddConfig(all_configs, packet_send_interval_ms, packet_size,
|
||||
execution_time_ms):
|
||||
config = all_configs.configs.add()
|
||||
config.packet_send_interval_ms = packet_send_interval_ms
|
||||
config.packet_size = packet_size
|
||||
config.execution_time_ms = execution_time_ms
|
||||
config = all_configs.configs.add()
|
||||
config.packet_send_interval_ms = packet_send_interval_ms
|
||||
config.packet_size = packet_size
|
||||
config.execution_time_ms = execution_time_ms
|
||||
|
||||
|
||||
def main():
|
||||
all_configs = network_tester_config_pb2.NetworkTesterAllConfigs()
|
||||
AddConfig(all_configs, 10, 50, 200)
|
||||
AddConfig(all_configs, 10, 100, 200)
|
||||
with open("network_tester_config.dat", 'wb') as f:
|
||||
f.write(all_configs.SerializeToString())
|
||||
all_configs = network_tester_config_pb2.NetworkTesterAllConfigs()
|
||||
AddConfig(all_configs, 10, 50, 200)
|
||||
AddConfig(all_configs, 10, 100, 200)
|
||||
with open("network_tester_config.dat", 'wb') as f:
|
||||
f.write(all_configs.SerializeToString())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -20,128 +20,131 @@ import matplotlib.pyplot as plt
|
||||
|
||||
import network_tester_packet_pb2
|
||||
|
||||
|
||||
def GetSize(file_to_parse):
|
||||
data = file_to_parse.read(1)
|
||||
if data == '':
|
||||
return 0
|
||||
return struct.unpack('<b', data)[0]
|
||||
data = file_to_parse.read(1)
|
||||
if data == '':
|
||||
return 0
|
||||
return struct.unpack('<b', data)[0]
|
||||
|
||||
|
||||
def ParsePacketLog(packet_log_file_to_parse):
|
||||
packets = []
|
||||
with open(packet_log_file_to_parse, 'rb') as file_to_parse:
|
||||
while True:
|
||||
size = GetSize(file_to_parse)
|
||||
if size == 0:
|
||||
break
|
||||
try:
|
||||
packet = network_tester_packet_pb2.NetworkTesterPacket()
|
||||
packet.ParseFromString(file_to_parse.read(size))
|
||||
packets.append(packet)
|
||||
except IOError:
|
||||
break
|
||||
return packets
|
||||
packets = []
|
||||
with open(packet_log_file_to_parse, 'rb') as file_to_parse:
|
||||
while True:
|
||||
size = GetSize(file_to_parse)
|
||||
if size == 0:
|
||||
break
|
||||
try:
|
||||
packet = network_tester_packet_pb2.NetworkTesterPacket()
|
||||
packet.ParseFromString(file_to_parse.read(size))
|
||||
packets.append(packet)
|
||||
except IOError:
|
||||
break
|
||||
return packets
|
||||
|
||||
|
||||
def GetTimeAxis(packets):
|
||||
first_arrival_time = packets[0].arrival_timestamp
|
||||
return [(packet.arrival_timestamp - first_arrival_time) / 1000000.0
|
||||
for packet in packets]
|
||||
first_arrival_time = packets[0].arrival_timestamp
|
||||
return [(packet.arrival_timestamp - first_arrival_time) / 1000000.0
|
||||
for packet in packets]
|
||||
|
||||
|
||||
def CreateSendTimeDiffPlot(packets, plot):
|
||||
first_send_time_diff = (
|
||||
packets[0].arrival_timestamp - packets[0].send_timestamp)
|
||||
y = [(packet.arrival_timestamp - packet.send_timestamp) - first_send_time_diff
|
||||
for packet in packets]
|
||||
plot.grid(True)
|
||||
plot.set_title("SendTime difference [us]")
|
||||
plot.plot(GetTimeAxis(packets), y)
|
||||
first_send_time_diff = (packets[0].arrival_timestamp -
|
||||
packets[0].send_timestamp)
|
||||
y = [(packet.arrival_timestamp - packet.send_timestamp) -
|
||||
first_send_time_diff for packet in packets]
|
||||
plot.grid(True)
|
||||
plot.set_title("SendTime difference [us]")
|
||||
plot.plot(GetTimeAxis(packets), y)
|
||||
|
||||
|
||||
class MovingAverageBitrate(object):
|
||||
def __init__(self):
|
||||
self.packet_window = []
|
||||
self.window_time = 1000000
|
||||
self.bytes = 0
|
||||
self.latest_packet_time = 0
|
||||
self.send_interval = 0
|
||||
|
||||
def __init__(self):
|
||||
self.packet_window = []
|
||||
self.window_time = 1000000
|
||||
self.bytes = 0
|
||||
self.latest_packet_time = 0
|
||||
self.send_interval = 0
|
||||
def RemoveOldPackets(self):
|
||||
for packet in self.packet_window:
|
||||
if (self.latest_packet_time - packet.arrival_timestamp >
|
||||
self.window_time):
|
||||
self.bytes = self.bytes - packet.packet_size
|
||||
self.packet_window.remove(packet)
|
||||
|
||||
def RemoveOldPackets(self):
|
||||
for packet in self.packet_window:
|
||||
if (self.latest_packet_time - packet.arrival_timestamp >
|
||||
self.window_time):
|
||||
self.bytes = self.bytes - packet.packet_size
|
||||
self.packet_window.remove(packet)
|
||||
|
||||
def AddPacket(self, packet):
|
||||
"""This functions returns bits / second"""
|
||||
self.send_interval = packet.arrival_timestamp - self.latest_packet_time
|
||||
self.latest_packet_time = packet.arrival_timestamp
|
||||
self.RemoveOldPackets()
|
||||
self.packet_window.append(packet)
|
||||
self.bytes = self.bytes + packet.packet_size
|
||||
return self.bytes * 8
|
||||
def AddPacket(self, packet):
|
||||
"""This functions returns bits / second"""
|
||||
self.send_interval = packet.arrival_timestamp - self.latest_packet_time
|
||||
self.latest_packet_time = packet.arrival_timestamp
|
||||
self.RemoveOldPackets()
|
||||
self.packet_window.append(packet)
|
||||
self.bytes = self.bytes + packet.packet_size
|
||||
return self.bytes * 8
|
||||
|
||||
|
||||
def CreateReceiveBiratePlot(packets, plot):
|
||||
bitrate = MovingAverageBitrate()
|
||||
y = [bitrate.AddPacket(packet) for packet in packets]
|
||||
plot.grid(True)
|
||||
plot.set_title("Receive birate [bps]")
|
||||
plot.plot(GetTimeAxis(packets), y)
|
||||
bitrate = MovingAverageBitrate()
|
||||
y = [bitrate.AddPacket(packet) for packet in packets]
|
||||
plot.grid(True)
|
||||
plot.set_title("Receive birate [bps]")
|
||||
plot.plot(GetTimeAxis(packets), y)
|
||||
|
||||
|
||||
def CreatePacketlossPlot(packets, plot):
|
||||
packets_look_up = {}
|
||||
first_sequence_number = packets[0].sequence_number
|
||||
last_sequence_number = packets[-1].sequence_number
|
||||
for packet in packets:
|
||||
packets_look_up[packet.sequence_number] = packet
|
||||
y = []
|
||||
x = []
|
||||
first_arrival_time = 0
|
||||
last_arrival_time = 0
|
||||
last_arrival_time_diff = 0
|
||||
for sequence_number in range(first_sequence_number, last_sequence_number + 1):
|
||||
if sequence_number in packets_look_up:
|
||||
y.append(0)
|
||||
if first_arrival_time == 0:
|
||||
first_arrival_time = packets_look_up[sequence_number].arrival_timestamp
|
||||
x_time = (packets_look_up[sequence_number].arrival_timestamp -
|
||||
first_arrival_time)
|
||||
if last_arrival_time != 0:
|
||||
last_arrival_time_diff = x_time - last_arrival_time
|
||||
last_arrival_time = x_time
|
||||
x.append(x_time / 1000000.0)
|
||||
else:
|
||||
if last_arrival_time != 0 and last_arrival_time_diff != 0:
|
||||
x.append((last_arrival_time + last_arrival_time_diff) / 1000000.0)
|
||||
y.append(1)
|
||||
plot.grid(True)
|
||||
plot.set_title("Lost packets [0/1]")
|
||||
plot.plot(x, y)
|
||||
packets_look_up = {}
|
||||
first_sequence_number = packets[0].sequence_number
|
||||
last_sequence_number = packets[-1].sequence_number
|
||||
for packet in packets:
|
||||
packets_look_up[packet.sequence_number] = packet
|
||||
y = []
|
||||
x = []
|
||||
first_arrival_time = 0
|
||||
last_arrival_time = 0
|
||||
last_arrival_time_diff = 0
|
||||
for sequence_number in range(first_sequence_number,
|
||||
last_sequence_number + 1):
|
||||
if sequence_number in packets_look_up:
|
||||
y.append(0)
|
||||
if first_arrival_time == 0:
|
||||
first_arrival_time = packets_look_up[
|
||||
sequence_number].arrival_timestamp
|
||||
x_time = (packets_look_up[sequence_number].arrival_timestamp -
|
||||
first_arrival_time)
|
||||
if last_arrival_time != 0:
|
||||
last_arrival_time_diff = x_time - last_arrival_time
|
||||
last_arrival_time = x_time
|
||||
x.append(x_time / 1000000.0)
|
||||
else:
|
||||
if last_arrival_time != 0 and last_arrival_time_diff != 0:
|
||||
x.append(
|
||||
(last_arrival_time + last_arrival_time_diff) / 1000000.0)
|
||||
y.append(1)
|
||||
plot.grid(True)
|
||||
plot.set_title("Lost packets [0/1]")
|
||||
plot.plot(x, y)
|
||||
|
||||
|
||||
def main():
|
||||
parser = OptionParser()
|
||||
parser.add_option("-f",
|
||||
"--packet_log_file",
|
||||
dest="packet_log_file",
|
||||
help="packet_log file to parse")
|
||||
parser = OptionParser()
|
||||
parser.add_option("-f",
|
||||
"--packet_log_file",
|
||||
dest="packet_log_file",
|
||||
help="packet_log file to parse")
|
||||
|
||||
options = parser.parse_args()[0]
|
||||
options = parser.parse_args()[0]
|
||||
|
||||
packets = ParsePacketLog(options.packet_log_file)
|
||||
f, plots = plt.subplots(3, sharex=True)
|
||||
plt.xlabel('time [sec]')
|
||||
CreateSendTimeDiffPlot(packets, plots[0])
|
||||
CreateReceiveBiratePlot(packets, plots[1])
|
||||
CreatePacketlossPlot(packets, plots[2])
|
||||
f.subplots_adjust(hspace=0.3)
|
||||
plt.show()
|
||||
packets = ParsePacketLog(options.packet_log_file)
|
||||
f, plots = plt.subplots(3, sharex=True)
|
||||
plt.xlabel('time [sec]')
|
||||
CreateSendTimeDiffPlot(packets, plots[0])
|
||||
CreateReceiveBiratePlot(packets, plots[1])
|
||||
CreatePacketlossPlot(packets, plots[2])
|
||||
f.subplots_adjust(hspace=0.3)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Utility functions for calculating statistics.
|
||||
"""
|
||||
|
||||
@ -15,18 +14,17 @@ import sys
|
||||
|
||||
|
||||
def CountReordered(sequence_numbers):
|
||||
"""Returns number of reordered indices.
|
||||
"""Returns number of reordered indices.
|
||||
|
||||
A reordered index is an index `i` for which sequence_numbers[i] >=
|
||||
sequence_numbers[i + 1]
|
||||
"""
|
||||
return sum(1 for (s1, s2) in zip(sequence_numbers,
|
||||
sequence_numbers[1:]) if
|
||||
s1 >= s2)
|
||||
return sum(1 for (s1, s2) in zip(sequence_numbers, sequence_numbers[1:])
|
||||
if s1 >= s2)
|
||||
|
||||
|
||||
def SsrcNormalizedSizeTable(data_points):
|
||||
"""Counts proportion of data for every SSRC.
|
||||
"""Counts proportion of data for every SSRC.
|
||||
|
||||
Args:
|
||||
data_points: list of pb_parse.DataPoint
|
||||
@ -37,14 +35,14 @@ def SsrcNormalizedSizeTable(data_points):
|
||||
SSRC `s` to the total size of all packets.
|
||||
|
||||
"""
|
||||
mapping = collections.defaultdict(int)
|
||||
for point in data_points:
|
||||
mapping[point.ssrc] += point.size
|
||||
return NormalizeCounter(mapping)
|
||||
mapping = collections.defaultdict(int)
|
||||
for point in data_points:
|
||||
mapping[point.ssrc] += point.size
|
||||
return NormalizeCounter(mapping)
|
||||
|
||||
|
||||
def NormalizeCounter(counter):
|
||||
"""Returns a normalized version of the dictionary `counter`.
|
||||
"""Returns a normalized version of the dictionary `counter`.
|
||||
|
||||
Does not modify `counter`.
|
||||
|
||||
@ -52,12 +50,12 @@ def NormalizeCounter(counter):
|
||||
A new dictionary, in which every value in `counter`
|
||||
has been divided by the total to sum up to 1.
|
||||
"""
|
||||
total = sum(counter.values())
|
||||
return {key: counter[key] / total for key in counter}
|
||||
total = sum(counter.values())
|
||||
return {key: counter[key] / total for key in counter}
|
||||
|
||||
|
||||
def Unwrap(data, mod):
|
||||
"""Returns `data` unwrapped modulo `mod`. Does not modify data.
|
||||
"""Returns `data` unwrapped modulo `mod`. Does not modify data.
|
||||
|
||||
Adds integer multiples of mod to all elements of data except the
|
||||
first, such that all pairs of consecutive elements (a, b) satisfy
|
||||
@ -66,22 +64,22 @@ def Unwrap(data, mod):
|
||||
E.g. Unwrap([0, 1, 2, 0, 1, 2, 7, 8], 3) -> [0, 1, 2, 3,
|
||||
4, 5, 4, 5]
|
||||
"""
|
||||
lst = data[:]
|
||||
for i in range(1, len(data)):
|
||||
lst[i] = lst[i - 1] + (lst[i] - lst[i - 1] +
|
||||
mod // 2) % mod - (mod // 2)
|
||||
return lst
|
||||
lst = data[:]
|
||||
for i in range(1, len(data)):
|
||||
lst[i] = lst[i - 1] + (lst[i] - lst[i - 1] + mod // 2) % mod - (mod //
|
||||
2)
|
||||
return lst
|
||||
|
||||
|
||||
def SsrcDirections(data_points):
|
||||
ssrc_is_incoming = {}
|
||||
for point in data_points:
|
||||
ssrc_is_incoming[point.ssrc] = point.incoming
|
||||
return ssrc_is_incoming
|
||||
ssrc_is_incoming = {}
|
||||
for point in data_points:
|
||||
ssrc_is_incoming[point.ssrc] = point.incoming
|
||||
return ssrc_is_incoming
|
||||
|
||||
|
||||
# Python 2/3-compatible input function
|
||||
if sys.version_info[0] <= 2:
|
||||
get_input = raw_input # pylint: disable=invalid-name
|
||||
get_input = raw_input # pylint: disable=invalid-name
|
||||
else:
|
||||
get_input = input # pylint: disable=invalid-name
|
||||
get_input = input # pylint: disable=invalid-name
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Run the tests with
|
||||
|
||||
python misc_test.py
|
||||
@ -22,51 +21,52 @@ import misc
|
||||
|
||||
|
||||
class TestMisc(unittest.TestCase):
|
||||
def testUnwrapMod3(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
unwrapped_3 = misc.Unwrap(data, 3)
|
||||
self.assertEqual([0, 1, 2, 3, 2, 1, 0, -1], unwrapped_3)
|
||||
def testUnwrapMod3(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
unwrapped_3 = misc.Unwrap(data, 3)
|
||||
self.assertEqual([0, 1, 2, 3, 2, 1, 0, -1], unwrapped_3)
|
||||
|
||||
def testUnwrapMod4(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
unwrapped_4 = misc.Unwrap(data, 4)
|
||||
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], unwrapped_4)
|
||||
def testUnwrapMod4(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
unwrapped_4 = misc.Unwrap(data, 4)
|
||||
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], unwrapped_4)
|
||||
|
||||
def testDataShouldNotChangeAfterUnwrap(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
_ = misc.Unwrap(data, 4)
|
||||
def testDataShouldNotChangeAfterUnwrap(self):
|
||||
data = [0, 1, 2, 0, -1, -2, -3, -4]
|
||||
_ = misc.Unwrap(data, 4)
|
||||
|
||||
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], data)
|
||||
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], data)
|
||||
|
||||
def testRandomlyMultiplesOfModAdded(self):
|
||||
# `unwrap` definition says only multiples of mod are added.
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
def testRandomlyMultiplesOfModAdded(self):
|
||||
# `unwrap` definition says only multiples of mod are added.
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
|
||||
for mod in range(1, 100):
|
||||
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
|
||||
for mod in range(1, 100):
|
||||
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
|
||||
|
||||
for (old_a, a) in zip(random_data, random_data_unwrapped_mod):
|
||||
self.assertEqual((old_a - a) % mod, 0)
|
||||
for (old_a, a) in zip(random_data, random_data_unwrapped_mod):
|
||||
self.assertEqual((old_a - a) % mod, 0)
|
||||
|
||||
def testRandomlyAgainstInequalityDefinition(self):
|
||||
# Data has to satisfy -mod/2 <= difference < mod/2 for every
|
||||
# difference between consecutive values after unwrap.
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
def testRandomlyAgainstInequalityDefinition(self):
|
||||
# Data has to satisfy -mod/2 <= difference < mod/2 for every
|
||||
# difference between consecutive values after unwrap.
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
|
||||
for mod in range(1, 100):
|
||||
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
|
||||
for mod in range(1, 100):
|
||||
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
|
||||
|
||||
for (a, b) in zip(random_data_unwrapped_mod,
|
||||
random_data_unwrapped_mod[1:]):
|
||||
self.assertTrue(-mod / 2 <= b - a < mod / 2)
|
||||
for (a, b) in zip(random_data_unwrapped_mod,
|
||||
random_data_unwrapped_mod[1:]):
|
||||
self.assertTrue(-mod / 2 <= b - a < mod / 2)
|
||||
|
||||
def testRandomlyDataShouldNotChangeAfterUnwrap(self):
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
random_data_copy = random_data[:]
|
||||
for mod in range(1, 100):
|
||||
_ = misc.Unwrap(random_data, mod)
|
||||
def testRandomlyDataShouldNotChangeAfterUnwrap(self):
|
||||
random_data = [random.randint(0, 9) for _ in range(100)]
|
||||
random_data_copy = random_data[:]
|
||||
for mod in range(1, 100):
|
||||
_ = misc.Unwrap(random_data, mod)
|
||||
|
||||
self.assertEqual(random_data, random_data_copy)
|
||||
|
||||
self.assertEqual(random_data, random_data_copy)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Parses protobuf RTC dumps."""
|
||||
|
||||
from __future__ import division
|
||||
@ -14,26 +13,26 @@ import pyproto.logging.rtc_event_log.rtc_event_log_pb2 as rtc_pb
|
||||
|
||||
|
||||
class DataPoint(object):
|
||||
"""Simple container class for RTP events."""
|
||||
"""Simple container class for RTP events."""
|
||||
|
||||
def __init__(self, rtp_header_str, packet_size,
|
||||
arrival_timestamp_us, incoming):
|
||||
"""Builds a data point by parsing an RTP header, size and arrival time.
|
||||
def __init__(self, rtp_header_str, packet_size, arrival_timestamp_us,
|
||||
incoming):
|
||||
"""Builds a data point by parsing an RTP header, size and arrival time.
|
||||
|
||||
RTP header structure is defined in RFC 3550 section 5.1.
|
||||
"""
|
||||
self.size = packet_size
|
||||
self.arrival_timestamp_ms = arrival_timestamp_us / 1000
|
||||
self.incoming = incoming
|
||||
header = struct.unpack_from("!HHII", rtp_header_str, 0)
|
||||
(first2header_bytes, self.sequence_number, self.timestamp,
|
||||
self.ssrc) = header
|
||||
self.payload_type = first2header_bytes & 0b01111111
|
||||
self.marker_bit = (first2header_bytes & 0b10000000) >> 7
|
||||
self.size = packet_size
|
||||
self.arrival_timestamp_ms = arrival_timestamp_us / 1000
|
||||
self.incoming = incoming
|
||||
header = struct.unpack_from("!HHII", rtp_header_str, 0)
|
||||
(first2header_bytes, self.sequence_number, self.timestamp,
|
||||
self.ssrc) = header
|
||||
self.payload_type = first2header_bytes & 0b01111111
|
||||
self.marker_bit = (first2header_bytes & 0b10000000) >> 7
|
||||
|
||||
|
||||
def ParseProtobuf(file_path):
|
||||
"""Parses RTC event log from protobuf file.
|
||||
"""Parses RTC event log from protobuf file.
|
||||
|
||||
Args:
|
||||
file_path: path to protobuf file of RTC event stream
|
||||
@ -41,12 +40,12 @@ def ParseProtobuf(file_path):
|
||||
Returns:
|
||||
all RTP packet events from the event stream as a list of DataPoints
|
||||
"""
|
||||
event_stream = rtc_pb.EventStream()
|
||||
with open(file_path, "rb") as f:
|
||||
event_stream.ParseFromString(f.read())
|
||||
event_stream = rtc_pb.EventStream()
|
||||
with open(file_path, "rb") as f:
|
||||
event_stream.ParseFromString(f.read())
|
||||
|
||||
return [DataPoint(event.rtp_packet.header,
|
||||
event.rtp_packet.packet_length,
|
||||
event.timestamp_us, event.rtp_packet.incoming)
|
||||
for event in event_stream.stream
|
||||
if event.HasField("rtp_packet")]
|
||||
return [
|
||||
DataPoint(event.rtp_packet.header, event.rtp_packet.packet_length,
|
||||
event.timestamp_us, event.rtp_packet.incoming)
|
||||
for event in event_stream.stream if event.HasField("rtp_packet")
|
||||
]
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Displays statistics and plots graphs from RTC protobuf dump."""
|
||||
|
||||
from __future__ import division
|
||||
@ -24,13 +23,13 @@ import pb_parse
|
||||
|
||||
|
||||
class RTPStatistics(object):
|
||||
"""Has methods for calculating and plotting RTP stream statistics."""
|
||||
"""Has methods for calculating and plotting RTP stream statistics."""
|
||||
|
||||
BANDWIDTH_SMOOTHING_WINDOW_SIZE = 10
|
||||
PLOT_RESOLUTION_MS = 50
|
||||
BANDWIDTH_SMOOTHING_WINDOW_SIZE = 10
|
||||
PLOT_RESOLUTION_MS = 50
|
||||
|
||||
def __init__(self, data_points):
|
||||
"""Initializes object with data_points and computes simple statistics.
|
||||
def __init__(self, data_points):
|
||||
"""Initializes object with data_points and computes simple statistics.
|
||||
|
||||
Computes percentages of number of packets and packet sizes by
|
||||
SSRC.
|
||||
@ -41,238 +40,245 @@ class RTPStatistics(object):
|
||||
|
||||
"""
|
||||
|
||||
self.data_points = data_points
|
||||
self.ssrc_frequencies = misc.NormalizeCounter(
|
||||
collections.Counter([pt.ssrc for pt in self.data_points]))
|
||||
self.ssrc_size_table = misc.SsrcNormalizedSizeTable(self.data_points)
|
||||
self.bandwidth_kbps = None
|
||||
self.smooth_bw_kbps = None
|
||||
self.data_points = data_points
|
||||
self.ssrc_frequencies = misc.NormalizeCounter(
|
||||
collections.Counter([pt.ssrc for pt in self.data_points]))
|
||||
self.ssrc_size_table = misc.SsrcNormalizedSizeTable(self.data_points)
|
||||
self.bandwidth_kbps = None
|
||||
self.smooth_bw_kbps = None
|
||||
|
||||
def PrintHeaderStatistics(self):
|
||||
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
|
||||
"SeqNo", "TimeStamp", "SendTime", "Size", "PT", "M", "SSRC"))
|
||||
for point in self.data_points:
|
||||
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
|
||||
point.sequence_number, point.timestamp,
|
||||
int(point.arrival_timestamp_ms), point.size, point.payload_type,
|
||||
point.marker_bit, "0x{:x}".format(point.ssrc)))
|
||||
def PrintHeaderStatistics(self):
|
||||
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
|
||||
"SeqNo", "TimeStamp", "SendTime", "Size", "PT", "M", "SSRC"))
|
||||
for point in self.data_points:
|
||||
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
|
||||
point.sequence_number, point.timestamp,
|
||||
int(point.arrival_timestamp_ms), point.size,
|
||||
point.payload_type, point.marker_bit,
|
||||
"0x{:x}".format(point.ssrc)))
|
||||
|
||||
def PrintSsrcInfo(self, ssrc_id, ssrc):
|
||||
"""Prints packet and size statistics for a given SSRC.
|
||||
def PrintSsrcInfo(self, ssrc_id, ssrc):
|
||||
"""Prints packet and size statistics for a given SSRC.
|
||||
|
||||
Args:
|
||||
ssrc_id: textual identifier of SSRC printed beside statistics for it.
|
||||
ssrc: SSRC by which to filter data and display statistics
|
||||
"""
|
||||
filtered_ssrc = [point for point in self.data_points if point.ssrc
|
||||
== ssrc]
|
||||
payloads = misc.NormalizeCounter(
|
||||
collections.Counter([point.payload_type for point in
|
||||
filtered_ssrc]))
|
||||
filtered_ssrc = [
|
||||
point for point in self.data_points if point.ssrc == ssrc
|
||||
]
|
||||
payloads = misc.NormalizeCounter(
|
||||
collections.Counter(
|
||||
[point.payload_type for point in filtered_ssrc]))
|
||||
|
||||
payload_info = "payload type(s): {}".format(
|
||||
", ".join(str(payload) for payload in payloads))
|
||||
print("{} 0x{:x} {}, {:.2f}% packets, {:.2f}% data".format(
|
||||
ssrc_id, ssrc, payload_info, self.ssrc_frequencies[ssrc] * 100,
|
||||
self.ssrc_size_table[ssrc] * 100))
|
||||
print(" packet sizes:")
|
||||
(bin_counts, bin_bounds) = numpy.histogram([point.size for point in
|
||||
filtered_ssrc], bins=5,
|
||||
density=False)
|
||||
bin_proportions = bin_counts / sum(bin_counts)
|
||||
print("\n".join([
|
||||
" {:.1f} - {:.1f}: {:.2f}%".format(bin_bounds[i], bin_bounds[i + 1],
|
||||
bin_proportions[i] * 100)
|
||||
for i in range(len(bin_proportions))
|
||||
]))
|
||||
payload_info = "payload type(s): {}".format(", ".join(
|
||||
str(payload) for payload in payloads))
|
||||
print("{} 0x{:x} {}, {:.2f}% packets, {:.2f}% data".format(
|
||||
ssrc_id, ssrc, payload_info, self.ssrc_frequencies[ssrc] * 100,
|
||||
self.ssrc_size_table[ssrc] * 100))
|
||||
print(" packet sizes:")
|
||||
(bin_counts,
|
||||
bin_bounds) = numpy.histogram([point.size for point in filtered_ssrc],
|
||||
bins=5,
|
||||
density=False)
|
||||
bin_proportions = bin_counts / sum(bin_counts)
|
||||
print("\n".join([
|
||||
" {:.1f} - {:.1f}: {:.2f}%".format(bin_bounds[i],
|
||||
bin_bounds[i + 1],
|
||||
bin_proportions[i] * 100)
|
||||
for i in range(len(bin_proportions))
|
||||
]))
|
||||
|
||||
def ChooseSsrc(self):
|
||||
"""Queries user for SSRC."""
|
||||
def ChooseSsrc(self):
|
||||
"""Queries user for SSRC."""
|
||||
|
||||
if len(self.ssrc_frequencies) == 1:
|
||||
chosen_ssrc = self.ssrc_frequencies.keys()[0]
|
||||
self.PrintSsrcInfo("", chosen_ssrc)
|
||||
return chosen_ssrc
|
||||
if len(self.ssrc_frequencies) == 1:
|
||||
chosen_ssrc = self.ssrc_frequencies.keys()[0]
|
||||
self.PrintSsrcInfo("", chosen_ssrc)
|
||||
return chosen_ssrc
|
||||
|
||||
ssrc_is_incoming = misc.SsrcDirections(self.data_points)
|
||||
incoming = [ssrc for ssrc in ssrc_is_incoming if ssrc_is_incoming[ssrc]]
|
||||
outgoing = [ssrc for ssrc in ssrc_is_incoming if not ssrc_is_incoming[ssrc]]
|
||||
ssrc_is_incoming = misc.SsrcDirections(self.data_points)
|
||||
incoming = [
|
||||
ssrc for ssrc in ssrc_is_incoming if ssrc_is_incoming[ssrc]
|
||||
]
|
||||
outgoing = [
|
||||
ssrc for ssrc in ssrc_is_incoming if not ssrc_is_incoming[ssrc]
|
||||
]
|
||||
|
||||
print("\nIncoming:\n")
|
||||
for (i, ssrc) in enumerate(incoming):
|
||||
self.PrintSsrcInfo(i, ssrc)
|
||||
print("\nIncoming:\n")
|
||||
for (i, ssrc) in enumerate(incoming):
|
||||
self.PrintSsrcInfo(i, ssrc)
|
||||
|
||||
print("\nOutgoing:\n")
|
||||
for (i, ssrc) in enumerate(outgoing):
|
||||
self.PrintSsrcInfo(i + len(incoming), ssrc)
|
||||
print("\nOutgoing:\n")
|
||||
for (i, ssrc) in enumerate(outgoing):
|
||||
self.PrintSsrcInfo(i + len(incoming), ssrc)
|
||||
|
||||
while True:
|
||||
chosen_index = int(misc.get_input("choose one> "))
|
||||
if 0 <= chosen_index < len(self.ssrc_frequencies):
|
||||
return (incoming + outgoing)[chosen_index]
|
||||
else:
|
||||
print("Invalid index!")
|
||||
while True:
|
||||
chosen_index = int(misc.get_input("choose one> "))
|
||||
if 0 <= chosen_index < len(self.ssrc_frequencies):
|
||||
return (incoming + outgoing)[chosen_index]
|
||||
else:
|
||||
print("Invalid index!")
|
||||
|
||||
def FilterSsrc(self, chosen_ssrc):
|
||||
"""Filters and wraps data points.
|
||||
def FilterSsrc(self, chosen_ssrc):
|
||||
"""Filters and wraps data points.
|
||||
|
||||
Removes data points with `ssrc != chosen_ssrc`. Unwraps sequence
|
||||
numbers and timestamps for the chosen selection.
|
||||
"""
|
||||
self.data_points = [point for point in self.data_points if
|
||||
point.ssrc == chosen_ssrc]
|
||||
unwrapped_sequence_numbers = misc.Unwrap(
|
||||
[point.sequence_number for point in self.data_points], 2**16 - 1)
|
||||
for (data_point, sequence_number) in zip(self.data_points,
|
||||
unwrapped_sequence_numbers):
|
||||
data_point.sequence_number = sequence_number
|
||||
self.data_points = [
|
||||
point for point in self.data_points if point.ssrc == chosen_ssrc
|
||||
]
|
||||
unwrapped_sequence_numbers = misc.Unwrap(
|
||||
[point.sequence_number for point in self.data_points], 2**16 - 1)
|
||||
for (data_point, sequence_number) in zip(self.data_points,
|
||||
unwrapped_sequence_numbers):
|
||||
data_point.sequence_number = sequence_number
|
||||
|
||||
unwrapped_timestamps = misc.Unwrap([point.timestamp for point in
|
||||
self.data_points], 2**32 - 1)
|
||||
unwrapped_timestamps = misc.Unwrap(
|
||||
[point.timestamp for point in self.data_points], 2**32 - 1)
|
||||
|
||||
for (data_point, timestamp) in zip(self.data_points,
|
||||
unwrapped_timestamps):
|
||||
data_point.timestamp = timestamp
|
||||
for (data_point, timestamp) in zip(self.data_points,
|
||||
unwrapped_timestamps):
|
||||
data_point.timestamp = timestamp
|
||||
|
||||
def PrintSequenceNumberStatistics(self):
|
||||
seq_no_set = set(point.sequence_number for point in
|
||||
self.data_points)
|
||||
missing_sequence_numbers = max(seq_no_set) - min(seq_no_set) + (
|
||||
1 - len(seq_no_set))
|
||||
print("Missing sequence numbers: {} out of {} ({:.2f}%)".format(
|
||||
missing_sequence_numbers,
|
||||
len(seq_no_set),
|
||||
100 * missing_sequence_numbers / len(seq_no_set)
|
||||
))
|
||||
print("Duplicated packets: {}".format(len(self.data_points) -
|
||||
len(seq_no_set)))
|
||||
print("Reordered packets: {}".format(
|
||||
misc.CountReordered([point.sequence_number for point in
|
||||
self.data_points])))
|
||||
def PrintSequenceNumberStatistics(self):
|
||||
seq_no_set = set(point.sequence_number for point in self.data_points)
|
||||
missing_sequence_numbers = max(seq_no_set) - min(seq_no_set) + (
|
||||
1 - len(seq_no_set))
|
||||
print("Missing sequence numbers: {} out of {} ({:.2f}%)".format(
|
||||
missing_sequence_numbers, len(seq_no_set),
|
||||
100 * missing_sequence_numbers / len(seq_no_set)))
|
||||
print("Duplicated packets: {}".format(
|
||||
len(self.data_points) - len(seq_no_set)))
|
||||
print("Reordered packets: {}".format(
|
||||
misc.CountReordered(
|
||||
[point.sequence_number for point in self.data_points])))
|
||||
|
||||
def EstimateFrequency(self, always_query_sample_rate):
|
||||
"""Estimates frequency and updates data.
|
||||
def EstimateFrequency(self, always_query_sample_rate):
|
||||
"""Estimates frequency and updates data.
|
||||
|
||||
Guesses the most probable frequency by looking at changes in
|
||||
timestamps (RFC 3550 section 5.1), calculates clock drifts and
|
||||
sending time of packets. Updates `self.data_points` with changes
|
||||
in delay and send time.
|
||||
"""
|
||||
delta_timestamp = (self.data_points[-1].timestamp -
|
||||
self.data_points[0].timestamp)
|
||||
delta_arr_timestamp = float((self.data_points[-1].arrival_timestamp_ms -
|
||||
self.data_points[0].arrival_timestamp_ms))
|
||||
freq_est = delta_timestamp / delta_arr_timestamp
|
||||
delta_timestamp = (self.data_points[-1].timestamp -
|
||||
self.data_points[0].timestamp)
|
||||
delta_arr_timestamp = float(
|
||||
(self.data_points[-1].arrival_timestamp_ms -
|
||||
self.data_points[0].arrival_timestamp_ms))
|
||||
freq_est = delta_timestamp / delta_arr_timestamp
|
||||
|
||||
freq_vec = [8, 16, 32, 48, 90]
|
||||
freq = None
|
||||
for f in freq_vec:
|
||||
if abs((freq_est - f) / f) < 0.05:
|
||||
freq = f
|
||||
freq_vec = [8, 16, 32, 48, 90]
|
||||
freq = None
|
||||
for f in freq_vec:
|
||||
if abs((freq_est - f) / f) < 0.05:
|
||||
freq = f
|
||||
|
||||
print("Estimated frequency: {:.3f}kHz".format(freq_est))
|
||||
if freq is None or always_query_sample_rate:
|
||||
if not always_query_sample_rate:
|
||||
print ("Frequency could not be guessed.", end=" ")
|
||||
freq = int(misc.get_input("Input frequency (in kHz)> "))
|
||||
else:
|
||||
print("Guessed frequency: {}kHz".format(freq))
|
||||
print("Estimated frequency: {:.3f}kHz".format(freq_est))
|
||||
if freq is None or always_query_sample_rate:
|
||||
if not always_query_sample_rate:
|
||||
print("Frequency could not be guessed.", end=" ")
|
||||
freq = int(misc.get_input("Input frequency (in kHz)> "))
|
||||
else:
|
||||
print("Guessed frequency: {}kHz".format(freq))
|
||||
|
||||
for point in self.data_points:
|
||||
point.real_send_time_ms = (point.timestamp -
|
||||
self.data_points[0].timestamp) / freq
|
||||
point.delay = point.arrival_timestamp_ms - point.real_send_time_ms
|
||||
for point in self.data_points:
|
||||
point.real_send_time_ms = (point.timestamp -
|
||||
self.data_points[0].timestamp) / freq
|
||||
point.delay = point.arrival_timestamp_ms - point.real_send_time_ms
|
||||
|
||||
def PrintDurationStatistics(self):
|
||||
"""Prints delay, clock drift and bitrate statistics."""
|
||||
def PrintDurationStatistics(self):
|
||||
"""Prints delay, clock drift and bitrate statistics."""
|
||||
|
||||
min_delay = min(point.delay for point in self.data_points)
|
||||
min_delay = min(point.delay for point in self.data_points)
|
||||
|
||||
for point in self.data_points:
|
||||
point.absdelay = point.delay - min_delay
|
||||
for point in self.data_points:
|
||||
point.absdelay = point.delay - min_delay
|
||||
|
||||
stream_duration_sender = self.data_points[-1].real_send_time_ms / 1000
|
||||
print("Stream duration at sender: {:.1f} seconds".format(
|
||||
stream_duration_sender
|
||||
))
|
||||
stream_duration_sender = self.data_points[-1].real_send_time_ms / 1000
|
||||
print("Stream duration at sender: {:.1f} seconds".format(
|
||||
stream_duration_sender))
|
||||
|
||||
arrival_timestamps_ms = [point.arrival_timestamp_ms for point in
|
||||
self.data_points]
|
||||
stream_duration_receiver = (max(arrival_timestamps_ms) -
|
||||
min(arrival_timestamps_ms)) / 1000
|
||||
print("Stream duration at receiver: {:.1f} seconds".format(
|
||||
stream_duration_receiver
|
||||
))
|
||||
arrival_timestamps_ms = [
|
||||
point.arrival_timestamp_ms for point in self.data_points
|
||||
]
|
||||
stream_duration_receiver = (max(arrival_timestamps_ms) -
|
||||
min(arrival_timestamps_ms)) / 1000
|
||||
print("Stream duration at receiver: {:.1f} seconds".format(
|
||||
stream_duration_receiver))
|
||||
|
||||
print("Clock drift: {:.2f}%".format(
|
||||
100 * (stream_duration_receiver / stream_duration_sender - 1)
|
||||
))
|
||||
print("Clock drift: {:.2f}%".format(
|
||||
100 * (stream_duration_receiver / stream_duration_sender - 1)))
|
||||
|
||||
total_size = sum(point.size for point in self.data_points) * 8 / 1000
|
||||
print("Send average bitrate: {:.2f} kbps".format(
|
||||
total_size / stream_duration_sender))
|
||||
total_size = sum(point.size for point in self.data_points) * 8 / 1000
|
||||
print("Send average bitrate: {:.2f} kbps".format(
|
||||
total_size / stream_duration_sender))
|
||||
|
||||
print("Receive average bitrate: {:.2f} kbps".format(
|
||||
total_size / stream_duration_receiver))
|
||||
print("Receive average bitrate: {:.2f} kbps".format(
|
||||
total_size / stream_duration_receiver))
|
||||
|
||||
def RemoveReordered(self):
|
||||
last = self.data_points[0]
|
||||
data_points_ordered = [last]
|
||||
for point in self.data_points[1:]:
|
||||
if point.sequence_number > last.sequence_number and (
|
||||
point.real_send_time_ms > last.real_send_time_ms):
|
||||
data_points_ordered.append(point)
|
||||
last = point
|
||||
self.data_points = data_points_ordered
|
||||
def RemoveReordered(self):
|
||||
last = self.data_points[0]
|
||||
data_points_ordered = [last]
|
||||
for point in self.data_points[1:]:
|
||||
if point.sequence_number > last.sequence_number and (
|
||||
point.real_send_time_ms > last.real_send_time_ms):
|
||||
data_points_ordered.append(point)
|
||||
last = point
|
||||
self.data_points = data_points_ordered
|
||||
|
||||
def ComputeBandwidth(self):
|
||||
"""Computes bandwidth averaged over several consecutive packets.
|
||||
def ComputeBandwidth(self):
|
||||
"""Computes bandwidth averaged over several consecutive packets.
|
||||
|
||||
The number of consecutive packets used in the average is
|
||||
BANDWIDTH_SMOOTHING_WINDOW_SIZE. Averaging is done with
|
||||
numpy.correlate.
|
||||
"""
|
||||
start_ms = self.data_points[0].real_send_time_ms
|
||||
stop_ms = self.data_points[-1].real_send_time_ms
|
||||
(self.bandwidth_kbps, _) = numpy.histogram(
|
||||
[point.real_send_time_ms for point in self.data_points],
|
||||
bins=numpy.arange(start_ms, stop_ms,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS),
|
||||
weights=[point.size * 8 / RTPStatistics.PLOT_RESOLUTION_MS
|
||||
for point in self.data_points]
|
||||
)
|
||||
correlate_filter = (numpy.ones(
|
||||
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE) /
|
||||
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE)
|
||||
self.smooth_bw_kbps = numpy.correlate(self.bandwidth_kbps, correlate_filter)
|
||||
start_ms = self.data_points[0].real_send_time_ms
|
||||
stop_ms = self.data_points[-1].real_send_time_ms
|
||||
(self.bandwidth_kbps, _) = numpy.histogram(
|
||||
[point.real_send_time_ms for point in self.data_points],
|
||||
bins=numpy.arange(start_ms, stop_ms,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS),
|
||||
weights=[
|
||||
point.size * 8 / RTPStatistics.PLOT_RESOLUTION_MS
|
||||
for point in self.data_points
|
||||
])
|
||||
correlate_filter = (
|
||||
numpy.ones(RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE) /
|
||||
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE)
|
||||
self.smooth_bw_kbps = numpy.correlate(self.bandwidth_kbps,
|
||||
correlate_filter)
|
||||
|
||||
def PlotStatistics(self):
|
||||
"""Plots changes in delay and average bandwidth."""
|
||||
def PlotStatistics(self):
|
||||
"""Plots changes in delay and average bandwidth."""
|
||||
|
||||
start_ms = self.data_points[0].real_send_time_ms
|
||||
stop_ms = self.data_points[-1].real_send_time_ms
|
||||
time_axis = numpy.arange(start_ms / 1000, stop_ms / 1000,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS / 1000)
|
||||
start_ms = self.data_points[0].real_send_time_ms
|
||||
stop_ms = self.data_points[-1].real_send_time_ms
|
||||
time_axis = numpy.arange(start_ms / 1000, stop_ms / 1000,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS / 1000)
|
||||
|
||||
delay = CalculateDelay(start_ms, stop_ms,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS,
|
||||
self.data_points)
|
||||
delay = CalculateDelay(start_ms, stop_ms,
|
||||
RTPStatistics.PLOT_RESOLUTION_MS,
|
||||
self.data_points)
|
||||
|
||||
plt.figure(1)
|
||||
plt.plot(time_axis, delay[:len(time_axis)])
|
||||
plt.xlabel("Send time [s]")
|
||||
plt.ylabel("Relative transport delay [ms]")
|
||||
plt.figure(1)
|
||||
plt.plot(time_axis, delay[:len(time_axis)])
|
||||
plt.xlabel("Send time [s]")
|
||||
plt.ylabel("Relative transport delay [ms]")
|
||||
|
||||
plt.figure(2)
|
||||
plt.plot(time_axis[:len(self.smooth_bw_kbps)], self.smooth_bw_kbps)
|
||||
plt.xlabel("Send time [s]")
|
||||
plt.ylabel("Bandwidth [kbps]")
|
||||
plt.figure(2)
|
||||
plt.plot(time_axis[:len(self.smooth_bw_kbps)], self.smooth_bw_kbps)
|
||||
plt.xlabel("Send time [s]")
|
||||
plt.ylabel("Bandwidth [kbps]")
|
||||
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
|
||||
def CalculateDelay(start, stop, step, points):
|
||||
"""Quantizes the time coordinates for the delay.
|
||||
"""Quantizes the time coordinates for the delay.
|
||||
|
||||
Quantizes points by rounding the timestamps downwards to the nearest
|
||||
point in the time sequence start, start+step, start+2*step... Takes
|
||||
@ -280,61 +286,67 @@ def CalculateDelay(start, stop, step, points):
|
||||
masked array, in which time points with no value are masked.
|
||||
|
||||
"""
|
||||
grouped_delays = [[] for _ in numpy.arange(start, stop + step, step)]
|
||||
rounded_value_index = lambda x: int((x - start) / step)
|
||||
for point in points:
|
||||
grouped_delays[rounded_value_index(point.real_send_time_ms)
|
||||
].append(point.absdelay)
|
||||
regularized_delays = [numpy.average(arr) if arr else -1 for arr in
|
||||
grouped_delays]
|
||||
return numpy.ma.masked_values(regularized_delays, -1)
|
||||
grouped_delays = [[] for _ in numpy.arange(start, stop + step, step)]
|
||||
rounded_value_index = lambda x: int((x - start) / step)
|
||||
for point in points:
|
||||
grouped_delays[rounded_value_index(point.real_send_time_ms)].append(
|
||||
point.absdelay)
|
||||
regularized_delays = [
|
||||
numpy.average(arr) if arr else -1 for arr in grouped_delays
|
||||
]
|
||||
return numpy.ma.masked_values(regularized_delays, -1)
|
||||
|
||||
|
||||
def main():
|
||||
usage = "Usage: %prog [options] <filename of rtc event log>"
|
||||
parser = optparse.OptionParser(usage=usage)
|
||||
parser.add_option("--dump_header_to_stdout",
|
||||
default=False, action="store_true",
|
||||
help="print header info to stdout; similar to rtp_analyze")
|
||||
parser.add_option("--query_sample_rate",
|
||||
default=False, action="store_true",
|
||||
help="always query user for real sample rate")
|
||||
usage = "Usage: %prog [options] <filename of rtc event log>"
|
||||
parser = optparse.OptionParser(usage=usage)
|
||||
parser.add_option(
|
||||
"--dump_header_to_stdout",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="print header info to stdout; similar to rtp_analyze")
|
||||
parser.add_option("--query_sample_rate",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="always query user for real sample rate")
|
||||
|
||||
parser.add_option("--working_directory",
|
||||
default=None, action="store",
|
||||
help="directory in which to search for relative paths")
|
||||
parser.add_option("--working_directory",
|
||||
default=None,
|
||||
action="store",
|
||||
help="directory in which to search for relative paths")
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
if len(args) < 1:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
if len(args) < 1:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
input_file = args[0]
|
||||
input_file = args[0]
|
||||
|
||||
if options.working_directory and not os.path.isabs(input_file):
|
||||
input_file = os.path.join(options.working_directory, input_file)
|
||||
if options.working_directory and not os.path.isabs(input_file):
|
||||
input_file = os.path.join(options.working_directory, input_file)
|
||||
|
||||
data_points = pb_parse.ParseProtobuf(input_file)
|
||||
rtp_stats = RTPStatistics(data_points)
|
||||
data_points = pb_parse.ParseProtobuf(input_file)
|
||||
rtp_stats = RTPStatistics(data_points)
|
||||
|
||||
if options.dump_header_to_stdout:
|
||||
print("Printing header info to stdout.", file=sys.stderr)
|
||||
rtp_stats.PrintHeaderStatistics()
|
||||
sys.exit(0)
|
||||
if options.dump_header_to_stdout:
|
||||
print("Printing header info to stdout.", file=sys.stderr)
|
||||
rtp_stats.PrintHeaderStatistics()
|
||||
sys.exit(0)
|
||||
|
||||
chosen_ssrc = rtp_stats.ChooseSsrc()
|
||||
print("Chosen SSRC: 0X{:X}".format(chosen_ssrc))
|
||||
chosen_ssrc = rtp_stats.ChooseSsrc()
|
||||
print("Chosen SSRC: 0X{:X}".format(chosen_ssrc))
|
||||
|
||||
rtp_stats.FilterSsrc(chosen_ssrc)
|
||||
rtp_stats.FilterSsrc(chosen_ssrc)
|
||||
|
||||
print("Statistics:")
|
||||
rtp_stats.PrintSequenceNumberStatistics()
|
||||
rtp_stats.EstimateFrequency(options.query_sample_rate)
|
||||
rtp_stats.PrintDurationStatistics()
|
||||
rtp_stats.RemoveReordered()
|
||||
rtp_stats.ComputeBandwidth()
|
||||
rtp_stats.PlotStatistics()
|
||||
|
||||
print("Statistics:")
|
||||
rtp_stats.PrintSequenceNumberStatistics()
|
||||
rtp_stats.EstimateFrequency(options.query_sample_rate)
|
||||
rtp_stats.PrintDurationStatistics()
|
||||
rtp_stats.RemoveReordered()
|
||||
rtp_stats.ComputeBandwidth()
|
||||
rtp_stats.PlotStatistics()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Run the tests with
|
||||
|
||||
python rtp_analyzer_test.py
|
||||
@ -19,43 +18,43 @@ import unittest
|
||||
|
||||
MISSING_NUMPY = False # pylint: disable=invalid-name
|
||||
try:
|
||||
import numpy
|
||||
import rtp_analyzer
|
||||
import numpy
|
||||
import rtp_analyzer
|
||||
except ImportError:
|
||||
MISSING_NUMPY = True
|
||||
MISSING_NUMPY = True
|
||||
|
||||
FakePoint = collections.namedtuple("FakePoint",
|
||||
["real_send_time_ms", "absdelay"])
|
||||
|
||||
|
||||
class TestDelay(unittest.TestCase):
|
||||
def AssertMaskEqual(self, masked_array, data, mask):
|
||||
self.assertEqual(list(masked_array.data), data)
|
||||
def AssertMaskEqual(self, masked_array, data, mask):
|
||||
self.assertEqual(list(masked_array.data), data)
|
||||
|
||||
if isinstance(masked_array.mask, numpy.bool_):
|
||||
array_mask = masked_array.mask
|
||||
else:
|
||||
array_mask = list(masked_array.mask)
|
||||
self.assertEqual(array_mask, mask)
|
||||
if isinstance(masked_array.mask, numpy.bool_):
|
||||
array_mask = masked_array.mask
|
||||
else:
|
||||
array_mask = list(masked_array.mask)
|
||||
self.assertEqual(array_mask, mask)
|
||||
|
||||
def testCalculateDelaySimple(self):
|
||||
points = [FakePoint(0, 0), FakePoint(1, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 1, 1, points)
|
||||
self.AssertMaskEqual(mask, [0, 0], False)
|
||||
def testCalculateDelaySimple(self):
|
||||
points = [FakePoint(0, 0), FakePoint(1, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 1, 1, points)
|
||||
self.AssertMaskEqual(mask, [0, 0], False)
|
||||
|
||||
def testCalculateDelayMissing(self):
|
||||
points = [FakePoint(0, 0), FakePoint(2, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 2, 1, points)
|
||||
self.AssertMaskEqual(mask, [0, -1, 0], [False, True, False])
|
||||
def testCalculateDelayMissing(self):
|
||||
points = [FakePoint(0, 0), FakePoint(2, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 2, 1, points)
|
||||
self.AssertMaskEqual(mask, [0, -1, 0], [False, True, False])
|
||||
|
||||
def testCalculateDelayBorders(self):
|
||||
points = [FakePoint(0, 0), FakePoint(2, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 3, 2, points)
|
||||
self.AssertMaskEqual(mask, [0, 0, -1], [False, False, True])
|
||||
def testCalculateDelayBorders(self):
|
||||
points = [FakePoint(0, 0), FakePoint(2, 0)]
|
||||
mask = rtp_analyzer.CalculateDelay(0, 3, 2, points)
|
||||
self.AssertMaskEqual(mask, [0, 0, -1], [False, False, True])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if MISSING_NUMPY:
|
||||
print "Missing numpy, skipping test."
|
||||
else:
|
||||
unittest.main()
|
||||
if MISSING_NUMPY:
|
||||
print "Missing numpy, skipping test."
|
||||
else:
|
||||
unittest.main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Builds the AppRTC collider using the golang toolchain.
|
||||
|
||||
The golang toolchain is downloaded by download_apprtc.py. We use that here
|
||||
@ -24,44 +23,44 @@ import sys
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
USAGE_STR = "Usage: {} <apprtc_dir> <go_dir> <output_dir>"
|
||||
|
||||
|
||||
def _ConfigureApprtcServerToDeveloperMode(app_yaml_path):
|
||||
for line in fileinput.input(app_yaml_path, inplace=True):
|
||||
# We can't click past these in browser-based tests, so disable them.
|
||||
line = line.replace('BYPASS_JOIN_CONFIRMATION: false',
|
||||
'BYPASS_JOIN_CONFIRMATION: true')
|
||||
sys.stdout.write(line)
|
||||
for line in fileinput.input(app_yaml_path, inplace=True):
|
||||
# We can't click past these in browser-based tests, so disable them.
|
||||
line = line.replace('BYPASS_JOIN_CONFIRMATION: false',
|
||||
'BYPASS_JOIN_CONFIRMATION: true')
|
||||
sys.stdout.write(line)
|
||||
|
||||
|
||||
def main(argv):
|
||||
if len(argv) != 4:
|
||||
return USAGE_STR.format(argv[0])
|
||||
if len(argv) != 4:
|
||||
return USAGE_STR.format(argv[0])
|
||||
|
||||
apprtc_dir = os.path.abspath(argv[1])
|
||||
go_root_dir = os.path.abspath(argv[2])
|
||||
golang_workspace = os.path.abspath(argv[3])
|
||||
apprtc_dir = os.path.abspath(argv[1])
|
||||
go_root_dir = os.path.abspath(argv[2])
|
||||
golang_workspace = os.path.abspath(argv[3])
|
||||
|
||||
app_yaml_path = os.path.join(apprtc_dir, 'out', 'app_engine', 'app.yaml')
|
||||
_ConfigureApprtcServerToDeveloperMode(app_yaml_path)
|
||||
app_yaml_path = os.path.join(apprtc_dir, 'out', 'app_engine', 'app.yaml')
|
||||
_ConfigureApprtcServerToDeveloperMode(app_yaml_path)
|
||||
|
||||
utils.RemoveDirectory(golang_workspace)
|
||||
utils.RemoveDirectory(golang_workspace)
|
||||
|
||||
collider_dir = os.path.join(apprtc_dir, 'src', 'collider')
|
||||
shutil.copytree(collider_dir, os.path.join(golang_workspace, 'src'))
|
||||
collider_dir = os.path.join(apprtc_dir, 'src', 'collider')
|
||||
shutil.copytree(collider_dir, os.path.join(golang_workspace, 'src'))
|
||||
|
||||
golang_path = os.path.join(go_root_dir, 'bin',
|
||||
'go' + utils.GetExecutableExtension())
|
||||
golang_env = os.environ.copy()
|
||||
golang_env['GOROOT'] = go_root_dir
|
||||
golang_env['GOPATH'] = golang_workspace
|
||||
collider_out = os.path.join(golang_workspace,
|
||||
'collidermain' + utils.GetExecutableExtension())
|
||||
subprocess.check_call([golang_path, 'build', '-o', collider_out,
|
||||
'collidermain'], env=golang_env)
|
||||
golang_path = os.path.join(go_root_dir, 'bin',
|
||||
'go' + utils.GetExecutableExtension())
|
||||
golang_env = os.environ.copy()
|
||||
golang_env['GOROOT'] = go_root_dir
|
||||
golang_env['GOPATH'] = golang_workspace
|
||||
collider_out = os.path.join(
|
||||
golang_workspace, 'collidermain' + utils.GetExecutableExtension())
|
||||
subprocess.check_call(
|
||||
[golang_path, 'build', '-o', collider_out, 'collidermain'],
|
||||
env=golang_env)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv))
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Downloads prebuilt AppRTC and Go from WebRTC storage and unpacks it.
|
||||
|
||||
Requires that depot_tools is installed and in the PATH.
|
||||
@ -21,38 +20,37 @@ import sys
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def _GetGoArchivePathForPlatform():
|
||||
archive_extension = 'zip' if utils.GetPlatform() == 'win' else 'tar.gz'
|
||||
return os.path.join(utils.GetPlatform(), 'go.%s' % archive_extension)
|
||||
archive_extension = 'zip' if utils.GetPlatform() == 'win' else 'tar.gz'
|
||||
return os.path.join(utils.GetPlatform(), 'go.%s' % archive_extension)
|
||||
|
||||
|
||||
def main(argv):
|
||||
if len(argv) > 2:
|
||||
return 'Usage: %s [output_dir]' % argv[0]
|
||||
if len(argv) > 2:
|
||||
return 'Usage: %s [output_dir]' % argv[0]
|
||||
|
||||
output_dir = os.path.abspath(argv[1]) if len(argv) > 1 else None
|
||||
output_dir = os.path.abspath(argv[1]) if len(argv) > 1 else None
|
||||
|
||||
apprtc_zip_path = os.path.join(SCRIPT_DIR, 'prebuilt_apprtc.zip')
|
||||
if os.path.isfile(apprtc_zip_path + '.sha1'):
|
||||
utils.DownloadFilesFromGoogleStorage(SCRIPT_DIR, auto_platform=False)
|
||||
apprtc_zip_path = os.path.join(SCRIPT_DIR, 'prebuilt_apprtc.zip')
|
||||
if os.path.isfile(apprtc_zip_path + '.sha1'):
|
||||
utils.DownloadFilesFromGoogleStorage(SCRIPT_DIR, auto_platform=False)
|
||||
|
||||
if output_dir is not None:
|
||||
utils.RemoveDirectory(os.path.join(output_dir, 'apprtc'))
|
||||
utils.UnpackArchiveTo(apprtc_zip_path, output_dir)
|
||||
if output_dir is not None:
|
||||
utils.RemoveDirectory(os.path.join(output_dir, 'apprtc'))
|
||||
utils.UnpackArchiveTo(apprtc_zip_path, output_dir)
|
||||
|
||||
golang_path = os.path.join(SCRIPT_DIR, 'golang')
|
||||
golang_zip_path = os.path.join(golang_path, _GetGoArchivePathForPlatform())
|
||||
if os.path.isfile(golang_zip_path + '.sha1'):
|
||||
utils.DownloadFilesFromGoogleStorage(golang_path)
|
||||
golang_path = os.path.join(SCRIPT_DIR, 'golang')
|
||||
golang_zip_path = os.path.join(golang_path, _GetGoArchivePathForPlatform())
|
||||
if os.path.isfile(golang_zip_path + '.sha1'):
|
||||
utils.DownloadFilesFromGoogleStorage(golang_path)
|
||||
|
||||
if output_dir is not None:
|
||||
utils.RemoveDirectory(os.path.join(output_dir, 'go'))
|
||||
utils.UnpackArchiveTo(golang_zip_path, output_dir)
|
||||
if output_dir is not None:
|
||||
utils.RemoveDirectory(os.path.join(output_dir, 'go'))
|
||||
utils.UnpackArchiveTo(golang_zip_path, output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv))
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""This script sets up AppRTC and its dependencies.
|
||||
|
||||
Requires that depot_tools is installed and in the PATH.
|
||||
@ -19,27 +18,26 @@ import sys
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def main(argv):
|
||||
if len(argv) == 1:
|
||||
return 'Usage %s <output_dir>' % argv[0]
|
||||
if len(argv) == 1:
|
||||
return 'Usage %s <output_dir>' % argv[0]
|
||||
|
||||
output_dir = os.path.abspath(argv[1])
|
||||
output_dir = os.path.abspath(argv[1])
|
||||
|
||||
download_apprtc_path = os.path.join(SCRIPT_DIR, 'download_apprtc.py')
|
||||
utils.RunSubprocessWithRetry([sys.executable, download_apprtc_path,
|
||||
output_dir])
|
||||
download_apprtc_path = os.path.join(SCRIPT_DIR, 'download_apprtc.py')
|
||||
utils.RunSubprocessWithRetry(
|
||||
[sys.executable, download_apprtc_path, output_dir])
|
||||
|
||||
build_apprtc_path = os.path.join(SCRIPT_DIR, 'build_apprtc.py')
|
||||
apprtc_dir = os.path.join(output_dir, 'apprtc')
|
||||
go_dir = os.path.join(output_dir, 'go')
|
||||
collider_dir = os.path.join(output_dir, 'collider')
|
||||
utils.RunSubprocessWithRetry([sys.executable, build_apprtc_path,
|
||||
apprtc_dir, go_dir, collider_dir])
|
||||
build_apprtc_path = os.path.join(SCRIPT_DIR, 'build_apprtc.py')
|
||||
apprtc_dir = os.path.join(output_dir, 'apprtc')
|
||||
go_dir = os.path.join(output_dir, 'go')
|
||||
collider_dir = os.path.join(output_dir, 'collider')
|
||||
utils.RunSubprocessWithRetry(
|
||||
[sys.executable, build_apprtc_path, apprtc_dir, go_dir, collider_dir])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv))
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Utilities for all our deps-management stuff."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -23,36 +22,37 @@ import zipfile
|
||||
|
||||
|
||||
def RunSubprocessWithRetry(cmd):
|
||||
"""Invokes the subprocess and backs off exponentially on fail."""
|
||||
for i in range(5):
|
||||
try:
|
||||
subprocess.check_call(cmd)
|
||||
return
|
||||
except subprocess.CalledProcessError as exception:
|
||||
backoff = pow(2, i)
|
||||
print('Got %s, retrying in %d seconds...' % (exception, backoff))
|
||||
time.sleep(backoff)
|
||||
"""Invokes the subprocess and backs off exponentially on fail."""
|
||||
for i in range(5):
|
||||
try:
|
||||
subprocess.check_call(cmd)
|
||||
return
|
||||
except subprocess.CalledProcessError as exception:
|
||||
backoff = pow(2, i)
|
||||
print('Got %s, retrying in %d seconds...' % (exception, backoff))
|
||||
time.sleep(backoff)
|
||||
|
||||
print('Giving up.')
|
||||
raise exception
|
||||
print('Giving up.')
|
||||
raise exception
|
||||
|
||||
|
||||
def DownloadFilesFromGoogleStorage(path, auto_platform=True):
|
||||
print('Downloading files in %s...' % path)
|
||||
print('Downloading files in %s...' % path)
|
||||
|
||||
extension = 'bat' if 'win32' in sys.platform else 'py'
|
||||
cmd = ['download_from_google_storage.%s' % extension,
|
||||
'--bucket=chromium-webrtc-resources',
|
||||
'--directory', path]
|
||||
if auto_platform:
|
||||
cmd += ['--auto_platform', '--recursive']
|
||||
subprocess.check_call(cmd)
|
||||
extension = 'bat' if 'win32' in sys.platform else 'py'
|
||||
cmd = [
|
||||
'download_from_google_storage.%s' % extension,
|
||||
'--bucket=chromium-webrtc-resources', '--directory', path
|
||||
]
|
||||
if auto_platform:
|
||||
cmd += ['--auto_platform', '--recursive']
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
# Code partially copied from
|
||||
# https://cs.chromium.org#chromium/build/scripts/common/chromium_utils.py
|
||||
def RemoveDirectory(*path):
|
||||
"""Recursively removes a directory, even if it's marked read-only.
|
||||
"""Recursively removes a directory, even if it's marked read-only.
|
||||
|
||||
Remove the directory located at *path, if it exists.
|
||||
|
||||
@ -67,62 +67,63 @@ def RemoveDirectory(*path):
|
||||
bit and try again, so we do that too. It's hand-waving, but sometimes it
|
||||
works. :/
|
||||
"""
|
||||
file_path = os.path.join(*path)
|
||||
print('Deleting `{}`.'.format(file_path))
|
||||
if not os.path.exists(file_path):
|
||||
print('`{}` does not exist.'.format(file_path))
|
||||
return
|
||||
file_path = os.path.join(*path)
|
||||
print('Deleting `{}`.'.format(file_path))
|
||||
if not os.path.exists(file_path):
|
||||
print('`{}` does not exist.'.format(file_path))
|
||||
return
|
||||
|
||||
if sys.platform == 'win32':
|
||||
# Give up and use cmd.exe's rd command.
|
||||
file_path = os.path.normcase(file_path)
|
||||
for _ in range(3):
|
||||
print('RemoveDirectory running %s' % (' '.join(
|
||||
['cmd.exe', '/c', 'rd', '/q', '/s', file_path])))
|
||||
if not subprocess.call(['cmd.exe', '/c', 'rd', '/q', '/s', file_path]):
|
||||
break
|
||||
print(' Failed')
|
||||
time.sleep(3)
|
||||
return
|
||||
else:
|
||||
shutil.rmtree(file_path, ignore_errors=True)
|
||||
if sys.platform == 'win32':
|
||||
# Give up and use cmd.exe's rd command.
|
||||
file_path = os.path.normcase(file_path)
|
||||
for _ in range(3):
|
||||
print('RemoveDirectory running %s' %
|
||||
(' '.join(['cmd.exe', '/c', 'rd', '/q', '/s', file_path])))
|
||||
if not subprocess.call(
|
||||
['cmd.exe', '/c', 'rd', '/q', '/s', file_path]):
|
||||
break
|
||||
print(' Failed')
|
||||
time.sleep(3)
|
||||
return
|
||||
else:
|
||||
shutil.rmtree(file_path, ignore_errors=True)
|
||||
|
||||
|
||||
def UnpackArchiveTo(archive_path, output_dir):
|
||||
extension = os.path.splitext(archive_path)[1]
|
||||
if extension == '.zip':
|
||||
_UnzipArchiveTo(archive_path, output_dir)
|
||||
else:
|
||||
_UntarArchiveTo(archive_path, output_dir)
|
||||
extension = os.path.splitext(archive_path)[1]
|
||||
if extension == '.zip':
|
||||
_UnzipArchiveTo(archive_path, output_dir)
|
||||
else:
|
||||
_UntarArchiveTo(archive_path, output_dir)
|
||||
|
||||
|
||||
def _UnzipArchiveTo(archive_path, output_dir):
|
||||
print('Unzipping {} in {}.'.format(archive_path, output_dir))
|
||||
zip_file = zipfile.ZipFile(archive_path)
|
||||
try:
|
||||
zip_file.extractall(output_dir)
|
||||
finally:
|
||||
zip_file.close()
|
||||
print('Unzipping {} in {}.'.format(archive_path, output_dir))
|
||||
zip_file = zipfile.ZipFile(archive_path)
|
||||
try:
|
||||
zip_file.extractall(output_dir)
|
||||
finally:
|
||||
zip_file.close()
|
||||
|
||||
|
||||
def _UntarArchiveTo(archive_path, output_dir):
|
||||
print('Untarring {} in {}.'.format(archive_path, output_dir))
|
||||
tar_file = tarfile.open(archive_path, 'r:gz')
|
||||
try:
|
||||
tar_file.extractall(output_dir)
|
||||
finally:
|
||||
tar_file.close()
|
||||
print('Untarring {} in {}.'.format(archive_path, output_dir))
|
||||
tar_file = tarfile.open(archive_path, 'r:gz')
|
||||
try:
|
||||
tar_file.extractall(output_dir)
|
||||
finally:
|
||||
tar_file.close()
|
||||
|
||||
|
||||
def GetPlatform():
|
||||
if sys.platform.startswith('win'):
|
||||
return 'win'
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
if sys.platform.startswith('darwin'):
|
||||
return 'mac'
|
||||
raise Exception("Can't run on platform %s." % sys.platform)
|
||||
if sys.platform.startswith('win'):
|
||||
return 'win'
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
if sys.platform.startswith('darwin'):
|
||||
return 'mac'
|
||||
raise Exception("Can't run on platform %s." % sys.platform)
|
||||
|
||||
|
||||
def GetExecutableExtension():
|
||||
return '.exe' if GetPlatform() == 'win' else ''
|
||||
return '.exe' if GetPlatform() == 'win' else ''
|
||||
|
||||
@ -6,23 +6,26 @@
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
|
||||
def CheckChangeOnUpload(input_api, output_api):
|
||||
results = []
|
||||
results.extend(CheckPatchFormatted(input_api, output_api))
|
||||
return results
|
||||
results = []
|
||||
results.extend(CheckPatchFormatted(input_api, output_api))
|
||||
return results
|
||||
|
||||
|
||||
def CheckPatchFormatted(input_api, output_api):
|
||||
import git_cl
|
||||
cmd = ['cl', 'format', '--dry-run', input_api.PresubmitLocalPath()]
|
||||
code, _ = git_cl.RunGitWithCode(cmd, suppress_stderr=True)
|
||||
if code == 2:
|
||||
short_path = input_api.basename(input_api.PresubmitLocalPath())
|
||||
full_path = input_api.os_path.relpath(input_api.PresubmitLocalPath(),
|
||||
input_api.change.RepositoryRoot())
|
||||
return [output_api.PresubmitPromptWarning(
|
||||
'The %s directory requires source formatting. '
|
||||
'Please run git cl format %s' %
|
||||
(short_path, full_path))]
|
||||
# As this is just a warning, ignore all other errors if the user
|
||||
# happens to have a broken clang-format, doesn't use git, etc etc.
|
||||
return []
|
||||
import git_cl
|
||||
cmd = ['cl', 'format', '--dry-run', input_api.PresubmitLocalPath()]
|
||||
code, _ = git_cl.RunGitWithCode(cmd, suppress_stderr=True)
|
||||
if code == 2:
|
||||
short_path = input_api.basename(input_api.PresubmitLocalPath())
|
||||
full_path = input_api.os_path.relpath(
|
||||
input_api.PresubmitLocalPath(), input_api.change.RepositoryRoot())
|
||||
return [
|
||||
output_api.PresubmitPromptWarning(
|
||||
'The %s directory requires source formatting. '
|
||||
'Please run git cl format %s' % (short_path, full_path))
|
||||
]
|
||||
# As this is just a warning, ignore all other errors if the user
|
||||
# happens to have a broken clang-format, doesn't use git, etc etc.
|
||||
return []
|
||||
|
||||
@ -8,39 +8,43 @@
|
||||
|
||||
|
||||
def _LicenseHeader(input_api):
|
||||
"""Returns the license header regexp."""
|
||||
# Accept any year number from 2003 to the current year
|
||||
current_year = int(input_api.time.strftime('%Y'))
|
||||
allowed_years = (str(s) for s in reversed(xrange(2003, current_year + 1)))
|
||||
years_re = '(' + '|'.join(allowed_years) + ')'
|
||||
license_header = (
|
||||
r'.*? Copyright( \(c\))? %(year)s The WebRTC [Pp]roject [Aa]uthors\. '
|
||||
"""Returns the license header regexp."""
|
||||
# Accept any year number from 2003 to the current year
|
||||
current_year = int(input_api.time.strftime('%Y'))
|
||||
allowed_years = (str(s) for s in reversed(xrange(2003, current_year + 1)))
|
||||
years_re = '(' + '|'.join(allowed_years) + ')'
|
||||
license_header = (
|
||||
r'.*? Copyright( \(c\))? %(year)s The WebRTC [Pp]roject [Aa]uthors\. '
|
||||
r'All [Rr]ights [Rr]eserved\.\n'
|
||||
r'.*?\n'
|
||||
r'.*? Use of this source code is governed by a BSD-style license\n'
|
||||
r'.*? that can be found in the LICENSE file in the root of the source\n'
|
||||
r'.*? tree\. An additional intellectual property rights grant can be '
|
||||
r'.*?\n'
|
||||
r'.*? Use of this source code is governed by a BSD-style license\n'
|
||||
r'.*? that can be found in the LICENSE file in the root of the source\n'
|
||||
r'.*? tree\. An additional intellectual property rights grant can be '
|
||||
r'found\n'
|
||||
r'.*? in the file PATENTS\. All contributing project authors may\n'
|
||||
r'.*? be found in the AUTHORS file in the root of the source tree\.\n'
|
||||
) % {
|
||||
'year': years_re,
|
||||
}
|
||||
return license_header
|
||||
r'.*? in the file PATENTS\. All contributing project authors may\n'
|
||||
r'.*? be found in the AUTHORS file in the root of the source tree\.\n'
|
||||
) % {
|
||||
'year': years_re,
|
||||
}
|
||||
return license_header
|
||||
|
||||
|
||||
def _CommonChecks(input_api, output_api):
|
||||
"""Checks common to both upload and commit."""
|
||||
results = []
|
||||
results.extend(input_api.canned_checks.CheckLicense(
|
||||
input_api, output_api, _LicenseHeader(input_api)))
|
||||
return results
|
||||
"""Checks common to both upload and commit."""
|
||||
results = []
|
||||
results.extend(
|
||||
input_api.canned_checks.CheckLicense(input_api, output_api,
|
||||
_LicenseHeader(input_api)))
|
||||
return results
|
||||
|
||||
|
||||
def CheckChangeOnUpload(input_api, output_api):
|
||||
results = []
|
||||
results.extend(_CommonChecks(input_api, output_api))
|
||||
return results
|
||||
results = []
|
||||
results.extend(_CommonChecks(input_api, output_api))
|
||||
return results
|
||||
|
||||
|
||||
def CheckChangeOnCommit(input_api, output_api):
|
||||
results = []
|
||||
results.extend(_CommonChecks(input_api, output_api))
|
||||
return results
|
||||
results = []
|
||||
results.extend(_CommonChecks(input_api, output_api))
|
||||
return results
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Script to generate libwebrtc.aar for distribution.
|
||||
|
||||
The script has to be run from the root src folder.
|
||||
@ -33,7 +32,6 @@ import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
|
||||
DEFAULT_ARCHS = ['armeabi-v7a', 'arm64-v8a', 'x86', 'x86_64']
|
||||
@ -41,8 +39,8 @@ NEEDED_SO_FILES = ['libjingle_peerconnection_so.so']
|
||||
JAR_FILE = 'lib.java/sdk/android/libwebrtc.jar'
|
||||
MANIFEST_FILE = 'sdk/android/AndroidManifest.xml'
|
||||
TARGETS = [
|
||||
'sdk/android:libwebrtc',
|
||||
'sdk/android:libjingle_peerconnection_so',
|
||||
'sdk/android:libwebrtc',
|
||||
'sdk/android:libjingle_peerconnection_so',
|
||||
]
|
||||
|
||||
sys.path.append(os.path.join(SCRIPT_DIR, '..', 'libs'))
|
||||
@ -52,183 +50,209 @@ sys.path.append(os.path.join(SRC_DIR, 'build'))
|
||||
import find_depot_tools
|
||||
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
parser = argparse.ArgumentParser(description='libwebrtc.aar generator.')
|
||||
parser.add_argument('--build-dir',
|
||||
help='Build dir. By default will create and use temporary dir.')
|
||||
parser.add_argument('--output', default='libwebrtc.aar',
|
||||
help='Output file of the script.')
|
||||
parser.add_argument('--arch', default=DEFAULT_ARCHS, nargs='*',
|
||||
help='Architectures to build. Defaults to %(default)s.')
|
||||
parser.add_argument('--use-goma', action='store_true', default=False,
|
||||
help='Use goma.')
|
||||
parser.add_argument('--verbose', action='store_true', default=False,
|
||||
help='Debug logging.')
|
||||
parser.add_argument('--extra-gn-args', default=[], nargs='*',
|
||||
help="""Additional GN arguments to be used during Ninja generation.
|
||||
parser = argparse.ArgumentParser(description='libwebrtc.aar generator.')
|
||||
parser.add_argument(
|
||||
'--build-dir',
|
||||
help='Build dir. By default will create and use temporary dir.')
|
||||
parser.add_argument('--output',
|
||||
default='libwebrtc.aar',
|
||||
help='Output file of the script.')
|
||||
parser.add_argument(
|
||||
'--arch',
|
||||
default=DEFAULT_ARCHS,
|
||||
nargs='*',
|
||||
help='Architectures to build. Defaults to %(default)s.')
|
||||
parser.add_argument('--use-goma',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use goma.')
|
||||
parser.add_argument('--verbose',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Debug logging.')
|
||||
parser.add_argument(
|
||||
'--extra-gn-args',
|
||||
default=[],
|
||||
nargs='*',
|
||||
help="""Additional GN arguments to be used during Ninja generation.
|
||||
These are passed to gn inside `--args` switch and
|
||||
applied after any other arguments and will
|
||||
override any values defined by the script.
|
||||
Example of building debug aar file:
|
||||
build_aar.py --extra-gn-args='is_debug=true'""")
|
||||
parser.add_argument('--extra-ninja-switches', default=[], nargs='*',
|
||||
help="""Additional Ninja switches to be used during compilation.
|
||||
parser.add_argument(
|
||||
'--extra-ninja-switches',
|
||||
default=[],
|
||||
nargs='*',
|
||||
help="""Additional Ninja switches to be used during compilation.
|
||||
These are applied after any other Ninja switches.
|
||||
Example of enabling verbose Ninja output:
|
||||
build_aar.py --extra-ninja-switches='-v'""")
|
||||
parser.add_argument('--extra-gn-switches', default=[], nargs='*',
|
||||
help="""Additional GN switches to be used during compilation.
|
||||
parser.add_argument(
|
||||
'--extra-gn-switches',
|
||||
default=[],
|
||||
nargs='*',
|
||||
help="""Additional GN switches to be used during compilation.
|
||||
These are applied after any other GN switches.
|
||||
Example of enabling verbose GN output:
|
||||
build_aar.py --extra-gn-switches='-v'""")
|
||||
return parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _RunGN(args):
|
||||
cmd = [sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')]
|
||||
cmd.extend(args)
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd)
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
|
||||
]
|
||||
cmd.extend(args)
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
def _RunNinja(output_directory, args):
|
||||
cmd = [os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
|
||||
'-C', output_directory]
|
||||
cmd.extend(args)
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd)
|
||||
cmd = [
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'), '-C',
|
||||
output_directory
|
||||
]
|
||||
cmd.extend(args)
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
def _EncodeForGN(value):
|
||||
"""Encodes value as a GN literal."""
|
||||
if isinstance(value, str):
|
||||
return '"' + value + '"'
|
||||
elif isinstance(value, bool):
|
||||
return repr(value).lower()
|
||||
else:
|
||||
return repr(value)
|
||||
"""Encodes value as a GN literal."""
|
||||
if isinstance(value, str):
|
||||
return '"' + value + '"'
|
||||
elif isinstance(value, bool):
|
||||
return repr(value).lower()
|
||||
else:
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _GetOutputDirectory(build_dir, arch):
|
||||
"""Returns the GN output directory for the target architecture."""
|
||||
return os.path.join(build_dir, arch)
|
||||
"""Returns the GN output directory for the target architecture."""
|
||||
return os.path.join(build_dir, arch)
|
||||
|
||||
|
||||
def _GetTargetCpu(arch):
|
||||
"""Returns target_cpu for the GN build with the given architecture."""
|
||||
if arch in ['armeabi', 'armeabi-v7a']:
|
||||
return 'arm'
|
||||
elif arch == 'arm64-v8a':
|
||||
return 'arm64'
|
||||
elif arch == 'x86':
|
||||
return 'x86'
|
||||
elif arch == 'x86_64':
|
||||
return 'x64'
|
||||
else:
|
||||
raise Exception('Unknown arch: ' + arch)
|
||||
"""Returns target_cpu for the GN build with the given architecture."""
|
||||
if arch in ['armeabi', 'armeabi-v7a']:
|
||||
return 'arm'
|
||||
elif arch == 'arm64-v8a':
|
||||
return 'arm64'
|
||||
elif arch == 'x86':
|
||||
return 'x86'
|
||||
elif arch == 'x86_64':
|
||||
return 'x64'
|
||||
else:
|
||||
raise Exception('Unknown arch: ' + arch)
|
||||
|
||||
|
||||
def _GetArmVersion(arch):
|
||||
"""Returns arm_version for the GN build with the given architecture."""
|
||||
if arch == 'armeabi':
|
||||
return 6
|
||||
elif arch == 'armeabi-v7a':
|
||||
return 7
|
||||
elif arch in ['arm64-v8a', 'x86', 'x86_64']:
|
||||
return None
|
||||
else:
|
||||
raise Exception('Unknown arch: ' + arch)
|
||||
"""Returns arm_version for the GN build with the given architecture."""
|
||||
if arch == 'armeabi':
|
||||
return 6
|
||||
elif arch == 'armeabi-v7a':
|
||||
return 7
|
||||
elif arch in ['arm64-v8a', 'x86', 'x86_64']:
|
||||
return None
|
||||
else:
|
||||
raise Exception('Unknown arch: ' + arch)
|
||||
|
||||
|
||||
def Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
|
||||
extra_ninja_switches):
|
||||
"""Generates target architecture using GN and builds it using ninja."""
|
||||
logging.info('Building: %s', arch)
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
gn_args = {
|
||||
'target_os': 'android',
|
||||
'is_debug': False,
|
||||
'is_component_build': False,
|
||||
'rtc_include_tests': False,
|
||||
'target_cpu': _GetTargetCpu(arch),
|
||||
'use_goma': use_goma
|
||||
}
|
||||
arm_version = _GetArmVersion(arch)
|
||||
if arm_version:
|
||||
gn_args['arm_version'] = arm_version
|
||||
gn_args_str = '--args=' + ' '.join([
|
||||
k + '=' + _EncodeForGN(v) for k, v in gn_args.items()] + extra_gn_args)
|
||||
"""Generates target architecture using GN and builds it using ninja."""
|
||||
logging.info('Building: %s', arch)
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
gn_args = {
|
||||
'target_os': 'android',
|
||||
'is_debug': False,
|
||||
'is_component_build': False,
|
||||
'rtc_include_tests': False,
|
||||
'target_cpu': _GetTargetCpu(arch),
|
||||
'use_goma': use_goma
|
||||
}
|
||||
arm_version = _GetArmVersion(arch)
|
||||
if arm_version:
|
||||
gn_args['arm_version'] = arm_version
|
||||
gn_args_str = '--args=' + ' '.join(
|
||||
[k + '=' + _EncodeForGN(v)
|
||||
for k, v in gn_args.items()] + extra_gn_args)
|
||||
|
||||
gn_args_list = ['gen', output_directory, gn_args_str]
|
||||
gn_args_list.extend(extra_gn_switches)
|
||||
_RunGN(gn_args_list)
|
||||
gn_args_list = ['gen', output_directory, gn_args_str]
|
||||
gn_args_list.extend(extra_gn_switches)
|
||||
_RunGN(gn_args_list)
|
||||
|
||||
ninja_args = TARGETS[:]
|
||||
if use_goma:
|
||||
ninja_args.extend(['-j', '200'])
|
||||
ninja_args.extend(extra_ninja_switches)
|
||||
_RunNinja(output_directory, ninja_args)
|
||||
ninja_args = TARGETS[:]
|
||||
if use_goma:
|
||||
ninja_args.extend(['-j', '200'])
|
||||
ninja_args.extend(extra_ninja_switches)
|
||||
_RunNinja(output_directory, ninja_args)
|
||||
|
||||
|
||||
def CollectCommon(aar_file, build_dir, arch):
|
||||
"""Collects architecture independent files into the .aar-archive."""
|
||||
logging.info('Collecting common files.')
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
aar_file.write(MANIFEST_FILE, 'AndroidManifest.xml')
|
||||
aar_file.write(os.path.join(output_directory, JAR_FILE), 'classes.jar')
|
||||
"""Collects architecture independent files into the .aar-archive."""
|
||||
logging.info('Collecting common files.')
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
aar_file.write(MANIFEST_FILE, 'AndroidManifest.xml')
|
||||
aar_file.write(os.path.join(output_directory, JAR_FILE), 'classes.jar')
|
||||
|
||||
|
||||
def Collect(aar_file, build_dir, arch):
|
||||
"""Collects architecture specific files into the .aar-archive."""
|
||||
logging.info('Collecting: %s', arch)
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
"""Collects architecture specific files into the .aar-archive."""
|
||||
logging.info('Collecting: %s', arch)
|
||||
output_directory = _GetOutputDirectory(build_dir, arch)
|
||||
|
||||
abi_dir = os.path.join('jni', arch)
|
||||
for so_file in NEEDED_SO_FILES:
|
||||
aar_file.write(os.path.join(output_directory, so_file),
|
||||
os.path.join(abi_dir, so_file))
|
||||
abi_dir = os.path.join('jni', arch)
|
||||
for so_file in NEEDED_SO_FILES:
|
||||
aar_file.write(os.path.join(output_directory, so_file),
|
||||
os.path.join(abi_dir, so_file))
|
||||
|
||||
|
||||
def GenerateLicenses(output_dir, build_dir, archs):
|
||||
builder = LicenseBuilder(
|
||||
[_GetOutputDirectory(build_dir, arch) for arch in archs], TARGETS)
|
||||
builder.GenerateLicenseText(output_dir)
|
||||
builder = LicenseBuilder(
|
||||
[_GetOutputDirectory(build_dir, arch) for arch in archs], TARGETS)
|
||||
builder.GenerateLicenseText(output_dir)
|
||||
|
||||
|
||||
def BuildAar(archs, output_file, use_goma=False, extra_gn_args=None,
|
||||
ext_build_dir=None, extra_gn_switches=None,
|
||||
def BuildAar(archs,
|
||||
output_file,
|
||||
use_goma=False,
|
||||
extra_gn_args=None,
|
||||
ext_build_dir=None,
|
||||
extra_gn_switches=None,
|
||||
extra_ninja_switches=None):
|
||||
extra_gn_args = extra_gn_args or []
|
||||
extra_gn_switches = extra_gn_switches or []
|
||||
extra_ninja_switches = extra_ninja_switches or []
|
||||
build_dir = ext_build_dir if ext_build_dir else tempfile.mkdtemp()
|
||||
extra_gn_args = extra_gn_args or []
|
||||
extra_gn_switches = extra_gn_switches or []
|
||||
extra_ninja_switches = extra_ninja_switches or []
|
||||
build_dir = ext_build_dir if ext_build_dir else tempfile.mkdtemp()
|
||||
|
||||
for arch in archs:
|
||||
Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
|
||||
extra_ninja_switches)
|
||||
|
||||
with zipfile.ZipFile(output_file, 'w') as aar_file:
|
||||
# Architecture doesn't matter here, arbitrarily using the first one.
|
||||
CollectCommon(aar_file, build_dir, archs[0])
|
||||
for arch in archs:
|
||||
Collect(aar_file, build_dir, arch)
|
||||
Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
|
||||
extra_ninja_switches)
|
||||
|
||||
license_dir = os.path.dirname(os.path.realpath(output_file))
|
||||
GenerateLicenses(license_dir, build_dir, archs)
|
||||
with zipfile.ZipFile(output_file, 'w') as aar_file:
|
||||
# Architecture doesn't matter here, arbitrarily using the first one.
|
||||
CollectCommon(aar_file, build_dir, archs[0])
|
||||
for arch in archs:
|
||||
Collect(aar_file, build_dir, arch)
|
||||
|
||||
if not ext_build_dir:
|
||||
shutil.rmtree(build_dir, True)
|
||||
license_dir = os.path.dirname(os.path.realpath(output_file))
|
||||
GenerateLicenses(license_dir, build_dir, archs)
|
||||
|
||||
if not ext_build_dir:
|
||||
shutil.rmtree(build_dir, True)
|
||||
|
||||
|
||||
def main():
|
||||
args = _ParseArgs()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
args = _ParseArgs()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
BuildAar(args.arch, args.output, args.use_goma, args.extra_gn_args,
|
||||
args.build_dir, args.extra_gn_switches, args.extra_ninja_switches)
|
||||
BuildAar(args.arch, args.output, args.use_goma, args.extra_gn_args,
|
||||
args.build_dir, args.extra_gn_switches, args.extra_ninja_switches)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Script for publishing WebRTC AAR on Bintray.
|
||||
|
||||
Set BINTRAY_USER and BINTRAY_API_KEY environment variables before running
|
||||
@ -25,7 +24,6 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
CHECKOUT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
|
||||
|
||||
@ -36,7 +34,6 @@ import jinja2
|
||||
sys.path.append(os.path.join(CHECKOUT_ROOT, 'tools_webrtc'))
|
||||
from android.build_aar import BuildAar
|
||||
|
||||
|
||||
ARCHS = ['armeabi-v7a', 'arm64-v8a', 'x86', 'x86_64']
|
||||
MAVEN_REPOSITORY = 'https://google.bintray.com/webrtc'
|
||||
API = 'https://api.bintray.com'
|
||||
@ -62,230 +59,249 @@ AAR_PROJECT_VERSION_DEPENDENCY = "implementation 'org.webrtc:google-webrtc:%s'"
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
parser = argparse.ArgumentParser(description='Releases WebRTC on Bintray.')
|
||||
parser.add_argument('--use-goma', action='store_true', default=False,
|
||||
help='Use goma.')
|
||||
parser.add_argument('--skip-tests', action='store_true', default=False,
|
||||
help='Skips running the tests.')
|
||||
parser.add_argument('--publish', action='store_true', default=False,
|
||||
help='Automatically publishes the library if the tests pass.')
|
||||
parser.add_argument('--build-dir', default=None,
|
||||
help='Temporary directory to store the build files. If not specified, '
|
||||
'a new directory will be created.')
|
||||
parser.add_argument('--verbose', action='store_true', default=False,
|
||||
help='Debug logging.')
|
||||
return parser.parse_args()
|
||||
parser = argparse.ArgumentParser(description='Releases WebRTC on Bintray.')
|
||||
parser.add_argument('--use-goma',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use goma.')
|
||||
parser.add_argument('--skip-tests',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Skips running the tests.')
|
||||
parser.add_argument(
|
||||
'--publish',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Automatically publishes the library if the tests pass.')
|
||||
parser.add_argument(
|
||||
'--build-dir',
|
||||
default=None,
|
||||
help='Temporary directory to store the build files. If not specified, '
|
||||
'a new directory will be created.')
|
||||
parser.add_argument('--verbose',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Debug logging.')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _GetCommitHash():
|
||||
commit_hash = subprocess.check_output(
|
||||
['git', 'rev-parse', 'HEAD'], cwd=CHECKOUT_ROOT).strip()
|
||||
return commit_hash
|
||||
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
|
||||
cwd=CHECKOUT_ROOT).strip()
|
||||
return commit_hash
|
||||
|
||||
|
||||
def _GetCommitPos():
|
||||
commit_message = subprocess.check_output(
|
||||
['git', 'rev-list', '--format=%B', '--max-count=1', 'HEAD'],
|
||||
cwd=CHECKOUT_ROOT)
|
||||
commit_pos_match = re.search(
|
||||
COMMIT_POSITION_REGEX, commit_message, re.MULTILINE)
|
||||
if not commit_pos_match:
|
||||
raise Exception('Commit position not found in the commit message: %s'
|
||||
% commit_message)
|
||||
return commit_pos_match.group(1)
|
||||
commit_message = subprocess.check_output(
|
||||
['git', 'rev-list', '--format=%B', '--max-count=1', 'HEAD'],
|
||||
cwd=CHECKOUT_ROOT)
|
||||
commit_pos_match = re.search(COMMIT_POSITION_REGEX, commit_message,
|
||||
re.MULTILINE)
|
||||
if not commit_pos_match:
|
||||
raise Exception('Commit position not found in the commit message: %s' %
|
||||
commit_message)
|
||||
return commit_pos_match.group(1)
|
||||
|
||||
|
||||
def _UploadFile(user, password, filename, version, target_file):
|
||||
# URL is of format:
|
||||
# <repository_api>/<version>/<group_id>/<artifact_id>/<version>/<target_file>
|
||||
# Example:
|
||||
# https://api.bintray.com/content/google/webrtc/google-webrtc/1.0.19742/org/webrtc/google-webrtc/1.0.19742/google-webrtc-1.0.19742.aar
|
||||
# URL is of format:
|
||||
# <repository_api>/<version>/<group_id>/<artifact_id>/<version>/<target_file>
|
||||
# Example:
|
||||
# https://api.bintray.com/content/google/webrtc/google-webrtc/1.0.19742/org/webrtc/google-webrtc/1.0.19742/google-webrtc-1.0.19742.aar
|
||||
|
||||
target_dir = version + '/' + GROUP_ID + '/' + ARTIFACT_ID + '/' + version
|
||||
target_path = target_dir + '/' + target_file
|
||||
url = CONTENT_API + '/' + target_path
|
||||
target_dir = version + '/' + GROUP_ID + '/' + ARTIFACT_ID + '/' + version
|
||||
target_path = target_dir + '/' + target_file
|
||||
url = CONTENT_API + '/' + target_path
|
||||
|
||||
logging.info('Uploading %s to %s', filename, url)
|
||||
with open(filename) as fh:
|
||||
file_data = fh.read()
|
||||
logging.info('Uploading %s to %s', filename, url)
|
||||
with open(filename) as fh:
|
||||
file_data = fh.read()
|
||||
|
||||
for attempt in xrange(UPLOAD_TRIES):
|
||||
try:
|
||||
response = requests.put(url, data=file_data, auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
break
|
||||
except requests.exceptions.Timeout as e:
|
||||
logging.warning('Timeout while uploading: %s', e)
|
||||
time.sleep(UPLOAD_RETRY_BASE_SLEEP_SECONDS ** attempt)
|
||||
else:
|
||||
raise Exception('Failed to upload %s' % filename)
|
||||
for attempt in xrange(UPLOAD_TRIES):
|
||||
try:
|
||||
response = requests.put(url,
|
||||
data=file_data,
|
||||
auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
break
|
||||
except requests.exceptions.Timeout as e:
|
||||
logging.warning('Timeout while uploading: %s', e)
|
||||
time.sleep(UPLOAD_RETRY_BASE_SLEEP_SECONDS**attempt)
|
||||
else:
|
||||
raise Exception('Failed to upload %s' % filename)
|
||||
|
||||
if not response.ok:
|
||||
raise Exception('Failed to upload %s. Response: %s' % (filename, response))
|
||||
logging.info('Uploaded %s: %s', filename, response)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to upload %s. Response: %s' %
|
||||
(filename, response))
|
||||
logging.info('Uploaded %s: %s', filename, response)
|
||||
|
||||
|
||||
def _GeneratePom(target_file, version, commit):
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader('release_aar'),
|
||||
)
|
||||
template = env.get_template('pom.jinja')
|
||||
pom = template.render(version=version, commit=commit)
|
||||
with open(target_file, 'w') as fh:
|
||||
fh.write(pom)
|
||||
env = jinja2.Environment(loader=jinja2.PackageLoader('release_aar'), )
|
||||
template = env.get_template('pom.jinja')
|
||||
pom = template.render(version=version, commit=commit)
|
||||
with open(target_file, 'w') as fh:
|
||||
fh.write(pom)
|
||||
|
||||
|
||||
def _TestAAR(tmp_dir, username, password, version):
|
||||
"""Runs AppRTCMobile tests using the AAR. Returns true if the tests pass."""
|
||||
logging.info('Testing library.')
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader('release_aar'),
|
||||
)
|
||||
"""Runs AppRTCMobile tests using the AAR. Returns true if the tests pass."""
|
||||
logging.info('Testing library.')
|
||||
env = jinja2.Environment(loader=jinja2.PackageLoader('release_aar'), )
|
||||
|
||||
gradle_backup = os.path.join(tmp_dir, 'build.gradle.backup')
|
||||
app_gradle_backup = os.path.join(tmp_dir, 'app-build.gradle.backup')
|
||||
gradle_backup = os.path.join(tmp_dir, 'build.gradle.backup')
|
||||
app_gradle_backup = os.path.join(tmp_dir, 'app-build.gradle.backup')
|
||||
|
||||
# Make backup copies of the project files before modifying them.
|
||||
shutil.copy2(AAR_PROJECT_GRADLE, gradle_backup)
|
||||
shutil.copy2(AAR_PROJECT_APP_GRADLE, app_gradle_backup)
|
||||
# Make backup copies of the project files before modifying them.
|
||||
shutil.copy2(AAR_PROJECT_GRADLE, gradle_backup)
|
||||
shutil.copy2(AAR_PROJECT_APP_GRADLE, app_gradle_backup)
|
||||
|
||||
try:
|
||||
maven_repository_template = env.get_template('maven-repository.jinja')
|
||||
maven_repository = maven_repository_template.render(
|
||||
url=MAVEN_REPOSITORY, username=username, password=password)
|
||||
|
||||
# Append Maven repository to build file to download unpublished files.
|
||||
with open(AAR_PROJECT_GRADLE, 'a') as gradle_file:
|
||||
gradle_file.write(maven_repository)
|
||||
|
||||
# Read app build file.
|
||||
with open(AAR_PROJECT_APP_GRADLE, 'r') as gradle_app_file:
|
||||
gradle_app = gradle_app_file.read()
|
||||
|
||||
if AAR_PROJECT_DEPENDENCY not in gradle_app:
|
||||
raise Exception(
|
||||
'%s not found in the build file.' % AAR_PROJECT_DEPENDENCY)
|
||||
# Set version to the version to be tested.
|
||||
target_dependency = AAR_PROJECT_VERSION_DEPENDENCY % version
|
||||
gradle_app = gradle_app.replace(AAR_PROJECT_DEPENDENCY, target_dependency)
|
||||
|
||||
# Write back.
|
||||
with open(AAR_PROJECT_APP_GRADLE, 'w') as gradle_app_file:
|
||||
gradle_app_file.write(gradle_app)
|
||||
|
||||
# Uninstall any existing version of AppRTCMobile.
|
||||
logging.info('Uninstalling previous AppRTCMobile versions. It is okay for '
|
||||
'these commands to fail if AppRTCMobile is not installed.')
|
||||
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc'])
|
||||
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc.test'])
|
||||
|
||||
# Run tests.
|
||||
try:
|
||||
# First clean the project.
|
||||
subprocess.check_call([GRADLEW_BIN, 'clean'], cwd=AAR_PROJECT_DIR)
|
||||
# Then run the tests.
|
||||
subprocess.check_call([GRADLEW_BIN, 'connectedDebugAndroidTest'],
|
||||
cwd=AAR_PROJECT_DIR)
|
||||
except subprocess.CalledProcessError:
|
||||
logging.exception('Test failure.')
|
||||
return False # Clean or tests failed
|
||||
maven_repository_template = env.get_template('maven-repository.jinja')
|
||||
maven_repository = maven_repository_template.render(
|
||||
url=MAVEN_REPOSITORY, username=username, password=password)
|
||||
|
||||
return True # Tests pass
|
||||
finally:
|
||||
# Restore backups.
|
||||
shutil.copy2(gradle_backup, AAR_PROJECT_GRADLE)
|
||||
shutil.copy2(app_gradle_backup, AAR_PROJECT_APP_GRADLE)
|
||||
# Append Maven repository to build file to download unpublished files.
|
||||
with open(AAR_PROJECT_GRADLE, 'a') as gradle_file:
|
||||
gradle_file.write(maven_repository)
|
||||
|
||||
# Read app build file.
|
||||
with open(AAR_PROJECT_APP_GRADLE, 'r') as gradle_app_file:
|
||||
gradle_app = gradle_app_file.read()
|
||||
|
||||
if AAR_PROJECT_DEPENDENCY not in gradle_app:
|
||||
raise Exception('%s not found in the build file.' %
|
||||
AAR_PROJECT_DEPENDENCY)
|
||||
# Set version to the version to be tested.
|
||||
target_dependency = AAR_PROJECT_VERSION_DEPENDENCY % version
|
||||
gradle_app = gradle_app.replace(AAR_PROJECT_DEPENDENCY,
|
||||
target_dependency)
|
||||
|
||||
# Write back.
|
||||
with open(AAR_PROJECT_APP_GRADLE, 'w') as gradle_app_file:
|
||||
gradle_app_file.write(gradle_app)
|
||||
|
||||
# Uninstall any existing version of AppRTCMobile.
|
||||
logging.info(
|
||||
'Uninstalling previous AppRTCMobile versions. It is okay for '
|
||||
'these commands to fail if AppRTCMobile is not installed.')
|
||||
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc'])
|
||||
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc.test'])
|
||||
|
||||
# Run tests.
|
||||
try:
|
||||
# First clean the project.
|
||||
subprocess.check_call([GRADLEW_BIN, 'clean'], cwd=AAR_PROJECT_DIR)
|
||||
# Then run the tests.
|
||||
subprocess.check_call([GRADLEW_BIN, 'connectedDebugAndroidTest'],
|
||||
cwd=AAR_PROJECT_DIR)
|
||||
except subprocess.CalledProcessError:
|
||||
logging.exception('Test failure.')
|
||||
return False # Clean or tests failed
|
||||
|
||||
return True # Tests pass
|
||||
finally:
|
||||
# Restore backups.
|
||||
shutil.copy2(gradle_backup, AAR_PROJECT_GRADLE)
|
||||
shutil.copy2(app_gradle_backup, AAR_PROJECT_APP_GRADLE)
|
||||
|
||||
|
||||
def _PublishAAR(user, password, version, additional_args):
|
||||
args = {
|
||||
'publish_wait_for_secs': 0 # Publish asynchronously.
|
||||
}
|
||||
args.update(additional_args)
|
||||
args = {
|
||||
'publish_wait_for_secs': 0 # Publish asynchronously.
|
||||
}
|
||||
args.update(additional_args)
|
||||
|
||||
url = CONTENT_API + '/' + version + '/publish'
|
||||
response = requests.post(url, data=json.dumps(args), auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
url = CONTENT_API + '/' + version + '/publish'
|
||||
response = requests.post(url,
|
||||
data=json.dumps(args),
|
||||
auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
|
||||
if not response.ok:
|
||||
raise Exception('Failed to publish. Response: %s' % response)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to publish. Response: %s' % response)
|
||||
|
||||
|
||||
def _DeleteUnpublishedVersion(user, password, version):
|
||||
url = PACKAGES_API + '/versions/' + version
|
||||
response = requests.get(url, auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to get version info. Response: %s' % response)
|
||||
url = PACKAGES_API + '/versions/' + version
|
||||
response = requests.get(url,
|
||||
auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to get version info. Response: %s' % response)
|
||||
|
||||
version_info = json.loads(response.content)
|
||||
if version_info['published']:
|
||||
logging.info('Version has already been published, not deleting.')
|
||||
return
|
||||
version_info = json.loads(response.content)
|
||||
if version_info['published']:
|
||||
logging.info('Version has already been published, not deleting.')
|
||||
return
|
||||
|
||||
logging.info('Deleting unpublished version.')
|
||||
response = requests.delete(url, auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to delete version. Response: %s' % response)
|
||||
logging.info('Deleting unpublished version.')
|
||||
response = requests.delete(url,
|
||||
auth=(user, password),
|
||||
timeout=API_TIMEOUT_SECONDS)
|
||||
if not response.ok:
|
||||
raise Exception('Failed to delete version. Response: %s' % response)
|
||||
|
||||
|
||||
def ReleaseAar(use_goma, skip_tests, publish, build_dir):
|
||||
version = '1.0.' + _GetCommitPos()
|
||||
commit = _GetCommitHash()
|
||||
logging.info('Releasing AAR version %s with hash %s', version, commit)
|
||||
version = '1.0.' + _GetCommitPos()
|
||||
commit = _GetCommitHash()
|
||||
logging.info('Releasing AAR version %s with hash %s', version, commit)
|
||||
|
||||
user = os.environ.get('BINTRAY_USER', None)
|
||||
api_key = os.environ.get('BINTRAY_API_KEY', None)
|
||||
if not user or not api_key:
|
||||
raise Exception('Environment variables BINTRAY_USER and BINTRAY_API_KEY '
|
||||
'must be defined.')
|
||||
user = os.environ.get('BINTRAY_USER', None)
|
||||
api_key = os.environ.get('BINTRAY_API_KEY', None)
|
||||
if not user or not api_key:
|
||||
raise Exception(
|
||||
'Environment variables BINTRAY_USER and BINTRAY_API_KEY '
|
||||
'must be defined.')
|
||||
|
||||
# If build directory is not specified, create a temporary directory.
|
||||
use_tmp_dir = not build_dir
|
||||
if use_tmp_dir:
|
||||
build_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
base_name = ARTIFACT_ID + '-' + version
|
||||
aar_file = os.path.join(build_dir, base_name + '.aar')
|
||||
third_party_licenses_file = os.path.join(build_dir, 'LICENSE.md')
|
||||
pom_file = os.path.join(build_dir, base_name + '.pom')
|
||||
|
||||
logging.info('Building at %s', build_dir)
|
||||
BuildAar(ARCHS, aar_file,
|
||||
use_goma=use_goma,
|
||||
ext_build_dir=os.path.join(build_dir, 'aar-build'))
|
||||
_GeneratePom(pom_file, version, commit)
|
||||
|
||||
_UploadFile(user, api_key, aar_file, version, base_name + '.aar')
|
||||
_UploadFile(user, api_key, third_party_licenses_file, version,
|
||||
'THIRD_PARTY_LICENSES.md')
|
||||
_UploadFile(user, api_key, pom_file, version, base_name + '.pom')
|
||||
|
||||
tests_pass = skip_tests or _TestAAR(build_dir, user, api_key, version)
|
||||
if not tests_pass:
|
||||
logging.info('Discarding library.')
|
||||
_PublishAAR(user, api_key, version, {'discard': True})
|
||||
_DeleteUnpublishedVersion(user, api_key, version)
|
||||
raise Exception('Test failure. Discarded library.')
|
||||
|
||||
if publish:
|
||||
logging.info('Publishing library.')
|
||||
_PublishAAR(user, api_key, version, {})
|
||||
else:
|
||||
logging.info('Note: The library has not not been published automatically.'
|
||||
' Please do so manually if desired.')
|
||||
finally:
|
||||
# If build directory is not specified, create a temporary directory.
|
||||
use_tmp_dir = not build_dir
|
||||
if use_tmp_dir:
|
||||
shutil.rmtree(build_dir, True)
|
||||
build_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
base_name = ARTIFACT_ID + '-' + version
|
||||
aar_file = os.path.join(build_dir, base_name + '.aar')
|
||||
third_party_licenses_file = os.path.join(build_dir, 'LICENSE.md')
|
||||
pom_file = os.path.join(build_dir, base_name + '.pom')
|
||||
|
||||
logging.info('Building at %s', build_dir)
|
||||
BuildAar(ARCHS,
|
||||
aar_file,
|
||||
use_goma=use_goma,
|
||||
ext_build_dir=os.path.join(build_dir, 'aar-build'))
|
||||
_GeneratePom(pom_file, version, commit)
|
||||
|
||||
_UploadFile(user, api_key, aar_file, version, base_name + '.aar')
|
||||
_UploadFile(user, api_key, third_party_licenses_file, version,
|
||||
'THIRD_PARTY_LICENSES.md')
|
||||
_UploadFile(user, api_key, pom_file, version, base_name + '.pom')
|
||||
|
||||
tests_pass = skip_tests or _TestAAR(build_dir, user, api_key, version)
|
||||
if not tests_pass:
|
||||
logging.info('Discarding library.')
|
||||
_PublishAAR(user, api_key, version, {'discard': True})
|
||||
_DeleteUnpublishedVersion(user, api_key, version)
|
||||
raise Exception('Test failure. Discarded library.')
|
||||
|
||||
if publish:
|
||||
logging.info('Publishing library.')
|
||||
_PublishAAR(user, api_key, version, {})
|
||||
else:
|
||||
logging.info(
|
||||
'Note: The library has not not been published automatically.'
|
||||
' Please do so manually if desired.')
|
||||
finally:
|
||||
if use_tmp_dir:
|
||||
shutil.rmtree(build_dir, True)
|
||||
|
||||
|
||||
def main():
|
||||
args = _ParseArgs()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
ReleaseAar(args.use_goma, args.skip_tests, args.publish, args.build_dir)
|
||||
args = _ParseArgs()
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
ReleaseAar(args.use_goma, args.skip_tests, args.publish, args.build_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -14,7 +14,6 @@ import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
PARENT_DIR = os.path.join(SCRIPT_DIR, os.pardir)
|
||||
sys.path.append(PARENT_DIR)
|
||||
@ -27,15 +26,15 @@ from roll_deps import CalculateChangedDeps, FindAddedDeps, \
|
||||
import mock
|
||||
|
||||
TEST_DATA_VARS = {
|
||||
'chromium_git': 'https://chromium.googlesource.com',
|
||||
'chromium_revision': '1b9c098a08e40114e44b6c1ec33ddf95c40b901d',
|
||||
'chromium_git': 'https://chromium.googlesource.com',
|
||||
'chromium_revision': '1b9c098a08e40114e44b6c1ec33ddf95c40b901d',
|
||||
}
|
||||
|
||||
DEPS_ENTRIES = {
|
||||
'src/build': 'https://build.com',
|
||||
'src/third_party/depot_tools': 'https://depottools.com',
|
||||
'src/testing/gtest': 'https://gtest.com',
|
||||
'src/testing/gmock': 'https://gmock.com',
|
||||
'src/build': 'https://build.com',
|
||||
'src/third_party/depot_tools': 'https://depottools.com',
|
||||
'src/testing/gtest': 'https://gtest.com',
|
||||
'src/testing/gmock': 'https://gmock.com',
|
||||
}
|
||||
|
||||
BUILD_OLD_REV = '52f7afeca991d96d68cf0507e20dbdd5b845691f'
|
||||
@ -47,291 +46,298 @@ NO_CHROMIUM_REVISION_UPDATE = ChromiumRevisionUpdate('cafe', 'cafe')
|
||||
|
||||
|
||||
class TestError(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class FakeCmd(object):
|
||||
def __init__(self):
|
||||
self.expectations = []
|
||||
def __init__(self):
|
||||
self.expectations = []
|
||||
|
||||
def AddExpectation(self, *args, **kwargs):
|
||||
returns = kwargs.pop('_returns', None)
|
||||
ignores = kwargs.pop('_ignores', [])
|
||||
self.expectations.append((args, kwargs, returns, ignores))
|
||||
def AddExpectation(self, *args, **kwargs):
|
||||
returns = kwargs.pop('_returns', None)
|
||||
ignores = kwargs.pop('_ignores', [])
|
||||
self.expectations.append((args, kwargs, returns, ignores))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.expectations:
|
||||
raise TestError('Got unexpected\n%s\n%s' % (args, kwargs))
|
||||
exp_args, exp_kwargs, exp_returns, ignores = self.expectations.pop(0)
|
||||
for item in ignores:
|
||||
kwargs.pop(item, None)
|
||||
if args != exp_args or kwargs != exp_kwargs:
|
||||
message = 'Expected:\n args: %s\n kwargs: %s\n' % (exp_args, exp_kwargs)
|
||||
message += 'Got:\n args: %s\n kwargs: %s\n' % (args, kwargs)
|
||||
raise TestError(message)
|
||||
return exp_returns
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.expectations:
|
||||
raise TestError('Got unexpected\n%s\n%s' % (args, kwargs))
|
||||
exp_args, exp_kwargs, exp_returns, ignores = self.expectations.pop(0)
|
||||
for item in ignores:
|
||||
kwargs.pop(item, None)
|
||||
if args != exp_args or kwargs != exp_kwargs:
|
||||
message = 'Expected:\n args: %s\n kwargs: %s\n' % (exp_args,
|
||||
exp_kwargs)
|
||||
message += 'Got:\n args: %s\n kwargs: %s\n' % (args, kwargs)
|
||||
raise TestError(message)
|
||||
return exp_returns
|
||||
|
||||
|
||||
class NullCmd(object):
|
||||
"""No-op mock when calls mustn't be checked. """
|
||||
"""No-op mock when calls mustn't be checked. """
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Empty stdout and stderr.
|
||||
return None, None
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Empty stdout and stderr.
|
||||
return None, None
|
||||
|
||||
|
||||
class TestRollChromiumRevision(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._output_dir = tempfile.mkdtemp()
|
||||
test_data_dir = os.path.join(SCRIPT_DIR, 'testdata', 'roll_deps')
|
||||
for test_file in glob.glob(os.path.join(test_data_dir, '*')):
|
||||
shutil.copy(test_file, self._output_dir)
|
||||
join = lambda f: os.path.join(self._output_dir, f)
|
||||
self._webrtc_depsfile = join('DEPS')
|
||||
self._new_cr_depsfile = join('DEPS.chromium.new')
|
||||
self._webrtc_depsfile_android = join('DEPS.with_android_deps')
|
||||
self._new_cr_depsfile_android = join('DEPS.chromium.with_android_deps')
|
||||
self.fake = FakeCmd()
|
||||
def setUp(self):
|
||||
self._output_dir = tempfile.mkdtemp()
|
||||
test_data_dir = os.path.join(SCRIPT_DIR, 'testdata', 'roll_deps')
|
||||
for test_file in glob.glob(os.path.join(test_data_dir, '*')):
|
||||
shutil.copy(test_file, self._output_dir)
|
||||
join = lambda f: os.path.join(self._output_dir, f)
|
||||
self._webrtc_depsfile = join('DEPS')
|
||||
self._new_cr_depsfile = join('DEPS.chromium.new')
|
||||
self._webrtc_depsfile_android = join('DEPS.with_android_deps')
|
||||
self._new_cr_depsfile_android = join('DEPS.chromium.with_android_deps')
|
||||
self.fake = FakeCmd()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._output_dir, ignore_errors=True)
|
||||
self.assertEqual(self.fake.expectations, [])
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._output_dir, ignore_errors=True)
|
||||
self.assertEqual(self.fake.expectations, [])
|
||||
|
||||
def testVarLookup(self):
|
||||
local_scope = {'foo': 'wrong', 'vars': {'foo': 'bar'}}
|
||||
lookup = roll_deps.VarLookup(local_scope)
|
||||
self.assertEquals(lookup('foo'), 'bar')
|
||||
def testVarLookup(self):
|
||||
local_scope = {'foo': 'wrong', 'vars': {'foo': 'bar'}}
|
||||
lookup = roll_deps.VarLookup(local_scope)
|
||||
self.assertEquals(lookup('foo'), 'bar')
|
||||
|
||||
def testUpdateDepsFile(self):
|
||||
new_rev = 'aaaaabbbbbcccccdddddeeeeefffff0000011111'
|
||||
current_rev = TEST_DATA_VARS['chromium_revision']
|
||||
def testUpdateDepsFile(self):
|
||||
new_rev = 'aaaaabbbbbcccccdddddeeeeefffff0000011111'
|
||||
current_rev = TEST_DATA_VARS['chromium_revision']
|
||||
|
||||
with open(self._new_cr_depsfile_android) as deps_file:
|
||||
new_cr_contents = deps_file.read()
|
||||
with open(self._new_cr_depsfile_android) as deps_file:
|
||||
new_cr_contents = deps_file.read()
|
||||
|
||||
UpdateDepsFile(self._webrtc_depsfile,
|
||||
ChromiumRevisionUpdate(current_rev, new_rev),
|
||||
[],
|
||||
new_cr_contents)
|
||||
with open(self._webrtc_depsfile) as deps_file:
|
||||
deps_contents = deps_file.read()
|
||||
self.assertTrue(new_rev in deps_contents,
|
||||
'Failed to find %s in\n%s' % (new_rev, deps_contents))
|
||||
UpdateDepsFile(self._webrtc_depsfile,
|
||||
ChromiumRevisionUpdate(current_rev, new_rev), [],
|
||||
new_cr_contents)
|
||||
with open(self._webrtc_depsfile) as deps_file:
|
||||
deps_contents = deps_file.read()
|
||||
self.assertTrue(
|
||||
new_rev in deps_contents,
|
||||
'Failed to find %s in\n%s' % (new_rev, deps_contents))
|
||||
|
||||
def _UpdateDepsSetup(self):
|
||||
with open(self._webrtc_depsfile_android) as deps_file:
|
||||
webrtc_contents = deps_file.read()
|
||||
with open(self._new_cr_depsfile_android) as deps_file:
|
||||
new_cr_contents = deps_file.read()
|
||||
webrtc_deps = ParseDepsDict(webrtc_contents)
|
||||
new_cr_deps = ParseDepsDict(new_cr_contents)
|
||||
def _UpdateDepsSetup(self):
|
||||
with open(self._webrtc_depsfile_android) as deps_file:
|
||||
webrtc_contents = deps_file.read()
|
||||
with open(self._new_cr_depsfile_android) as deps_file:
|
||||
new_cr_contents = deps_file.read()
|
||||
webrtc_deps = ParseDepsDict(webrtc_contents)
|
||||
new_cr_deps = ParseDepsDict(new_cr_contents)
|
||||
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
with mock.patch('roll_deps._RunCommand', NullCmd()):
|
||||
UpdateDepsFile(self._webrtc_depsfile_android,
|
||||
NO_CHROMIUM_REVISION_UPDATE,
|
||||
changed_deps,
|
||||
new_cr_contents)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
with mock.patch('roll_deps._RunCommand', NullCmd()):
|
||||
UpdateDepsFile(self._webrtc_depsfile_android,
|
||||
NO_CHROMIUM_REVISION_UPDATE, changed_deps,
|
||||
new_cr_contents)
|
||||
|
||||
with open(self._webrtc_depsfile_android) as deps_file:
|
||||
updated_contents = deps_file.read()
|
||||
with open(self._webrtc_depsfile_android) as deps_file:
|
||||
updated_contents = deps_file.read()
|
||||
|
||||
return webrtc_contents, updated_contents
|
||||
return webrtc_contents, updated_contents
|
||||
|
||||
def testUpdateAndroidGeneratedDeps(self):
|
||||
_, updated_contents = self._UpdateDepsSetup()
|
||||
def testUpdateAndroidGeneratedDeps(self):
|
||||
_, updated_contents = self._UpdateDepsSetup()
|
||||
|
||||
changed = 'third_party/android_deps/libs/android_arch_core_common'
|
||||
changed_version = '1.0.0-cr0'
|
||||
self.assertTrue(changed in updated_contents)
|
||||
self.assertTrue(changed_version in updated_contents)
|
||||
changed = 'third_party/android_deps/libs/android_arch_core_common'
|
||||
changed_version = '1.0.0-cr0'
|
||||
self.assertTrue(changed in updated_contents)
|
||||
self.assertTrue(changed_version in updated_contents)
|
||||
|
||||
def testAddAndroidGeneratedDeps(self):
|
||||
webrtc_contents, updated_contents = self._UpdateDepsSetup()
|
||||
def testAddAndroidGeneratedDeps(self):
|
||||
webrtc_contents, updated_contents = self._UpdateDepsSetup()
|
||||
|
||||
added = 'third_party/android_deps/libs/android_arch_lifecycle_common'
|
||||
self.assertFalse(added in webrtc_contents)
|
||||
self.assertTrue(added in updated_contents)
|
||||
added = 'third_party/android_deps/libs/android_arch_lifecycle_common'
|
||||
self.assertFalse(added in webrtc_contents)
|
||||
self.assertTrue(added in updated_contents)
|
||||
|
||||
def testRemoveAndroidGeneratedDeps(self):
|
||||
webrtc_contents, updated_contents = self._UpdateDepsSetup()
|
||||
def testRemoveAndroidGeneratedDeps(self):
|
||||
webrtc_contents, updated_contents = self._UpdateDepsSetup()
|
||||
|
||||
removed = 'third_party/android_deps/libs/android_arch_lifecycle_runtime'
|
||||
self.assertTrue(removed in webrtc_contents)
|
||||
self.assertFalse(removed in updated_contents)
|
||||
removed = 'third_party/android_deps/libs/android_arch_lifecycle_runtime'
|
||||
self.assertTrue(removed in webrtc_contents)
|
||||
self.assertFalse(removed in updated_contents)
|
||||
|
||||
def testParseDepsDict(self):
|
||||
with open(self._webrtc_depsfile) as deps_file:
|
||||
deps_contents = deps_file.read()
|
||||
local_scope = ParseDepsDict(deps_contents)
|
||||
vars_dict = local_scope['vars']
|
||||
def testParseDepsDict(self):
|
||||
with open(self._webrtc_depsfile) as deps_file:
|
||||
deps_contents = deps_file.read()
|
||||
local_scope = ParseDepsDict(deps_contents)
|
||||
vars_dict = local_scope['vars']
|
||||
|
||||
def AssertVar(variable_name):
|
||||
self.assertEquals(vars_dict[variable_name], TEST_DATA_VARS[variable_name])
|
||||
AssertVar('chromium_git')
|
||||
AssertVar('chromium_revision')
|
||||
self.assertEquals(len(local_scope['deps']), 3)
|
||||
self.assertEquals(len(local_scope['deps_os']), 1)
|
||||
def AssertVar(variable_name):
|
||||
self.assertEquals(vars_dict[variable_name],
|
||||
TEST_DATA_VARS[variable_name])
|
||||
|
||||
def testGetMatchingDepsEntriesReturnsPathInSimpleCase(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing/gtest')
|
||||
self.assertEquals(len(entries), 1)
|
||||
self.assertEquals(entries[0], DEPS_ENTRIES['src/testing/gtest'])
|
||||
AssertVar('chromium_git')
|
||||
AssertVar('chromium_revision')
|
||||
self.assertEquals(len(local_scope['deps']), 3)
|
||||
self.assertEquals(len(local_scope['deps_os']), 1)
|
||||
|
||||
def testGetMatchingDepsEntriesHandlesSimilarStartingPaths(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing')
|
||||
self.assertEquals(len(entries), 2)
|
||||
def testGetMatchingDepsEntriesReturnsPathInSimpleCase(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing/gtest')
|
||||
self.assertEquals(len(entries), 1)
|
||||
self.assertEquals(entries[0], DEPS_ENTRIES['src/testing/gtest'])
|
||||
|
||||
def testGetMatchingDepsEntriesHandlesTwoPathsWithIdenticalFirstParts(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/build')
|
||||
self.assertEquals(len(entries), 1)
|
||||
def testGetMatchingDepsEntriesHandlesSimilarStartingPaths(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing')
|
||||
self.assertEquals(len(entries), 2)
|
||||
|
||||
def testGetMatchingDepsEntriesHandlesTwoPathsWithIdenticalFirstParts(self):
|
||||
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/build')
|
||||
self.assertEquals(len(entries), 1)
|
||||
|
||||
def testCalculateChangedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile)
|
||||
with mock.patch('roll_deps._RunCommand', self.fake):
|
||||
_SetupGitLsRemoteCall(
|
||||
self.fake, 'https://chromium.googlesource.com/chromium/src/build',
|
||||
BUILD_NEW_REV)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
def testCalculateChangedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile)
|
||||
with mock.patch('roll_deps._RunCommand', self.fake):
|
||||
_SetupGitLsRemoteCall(
|
||||
self.fake,
|
||||
'https://chromium.googlesource.com/chromium/src/build',
|
||||
BUILD_NEW_REV)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
|
||||
self.assertEquals(len(changed_deps), 3)
|
||||
self.assertEquals(changed_deps[0].path, 'src/build')
|
||||
self.assertEquals(changed_deps[0].current_rev, BUILD_OLD_REV)
|
||||
self.assertEquals(changed_deps[0].new_rev, BUILD_NEW_REV)
|
||||
self.assertEquals(len(changed_deps), 3)
|
||||
self.assertEquals(changed_deps[0].path, 'src/build')
|
||||
self.assertEquals(changed_deps[0].current_rev, BUILD_OLD_REV)
|
||||
self.assertEquals(changed_deps[0].new_rev, BUILD_NEW_REV)
|
||||
|
||||
self.assertEquals(changed_deps[1].path, 'src/third_party/depot_tools')
|
||||
self.assertEquals(changed_deps[1].current_rev, DEPOTTOOLS_OLD_REV)
|
||||
self.assertEquals(changed_deps[1].new_rev, DEPOTTOOLS_NEW_REV)
|
||||
self.assertEquals(changed_deps[1].path, 'src/third_party/depot_tools')
|
||||
self.assertEquals(changed_deps[1].current_rev, DEPOTTOOLS_OLD_REV)
|
||||
self.assertEquals(changed_deps[1].new_rev, DEPOTTOOLS_NEW_REV)
|
||||
|
||||
self.assertEquals(changed_deps[2].path, 'src/third_party/xstream')
|
||||
self.assertEquals(changed_deps[2].package, 'chromium/third_party/xstream')
|
||||
self.assertEquals(changed_deps[2].current_version, 'version:1.4.8-cr0')
|
||||
self.assertEquals(changed_deps[2].new_version, 'version:1.10.0-cr0')
|
||||
self.assertEquals(changed_deps[2].path, 'src/third_party/xstream')
|
||||
self.assertEquals(changed_deps[2].package,
|
||||
'chromium/third_party/xstream')
|
||||
self.assertEquals(changed_deps[2].current_version, 'version:1.4.8-cr0')
|
||||
self.assertEquals(changed_deps[2].new_version, 'version:1.10.0-cr0')
|
||||
|
||||
def testWithDistinctDeps(self):
|
||||
"""Check CalculateChangedDeps still works when deps are added/removed. """
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(len(changed_deps), 1)
|
||||
self.assertEquals(
|
||||
changed_deps[0].path,
|
||||
'src/third_party/android_deps/libs/android_arch_core_common')
|
||||
self.assertEquals(
|
||||
changed_deps[0].package,
|
||||
'chromium/third_party/android_deps/libs/android_arch_core_common')
|
||||
self.assertEquals(changed_deps[0].current_version, 'version:0.9.0')
|
||||
self.assertEquals(changed_deps[0].new_version, 'version:1.0.0-cr0')
|
||||
def testWithDistinctDeps(self):
|
||||
"""Check CalculateChangedDeps still works when deps are added/removed. """
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(len(changed_deps), 1)
|
||||
self.assertEquals(
|
||||
changed_deps[0].path,
|
||||
'src/third_party/android_deps/libs/android_arch_core_common')
|
||||
self.assertEquals(
|
||||
changed_deps[0].package,
|
||||
'chromium/third_party/android_deps/libs/android_arch_core_common')
|
||||
self.assertEquals(changed_deps[0].current_version, 'version:0.9.0')
|
||||
self.assertEquals(changed_deps[0].new_version, 'version:1.0.0-cr0')
|
||||
|
||||
def testFindAddedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
added_android_paths, other_paths = FindAddedDeps(webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(
|
||||
added_android_paths,
|
||||
['src/third_party/android_deps/libs/android_arch_lifecycle_common'])
|
||||
self.assertEquals(other_paths, [])
|
||||
def testFindAddedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
added_android_paths, other_paths = FindAddedDeps(
|
||||
webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(added_android_paths, [
|
||||
'src/third_party/android_deps/libs/android_arch_lifecycle_common'
|
||||
])
|
||||
self.assertEquals(other_paths, [])
|
||||
|
||||
def testFindRemovedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
removed_android_paths, other_paths = FindRemovedDeps(webrtc_deps,
|
||||
new_cr_deps)
|
||||
self.assertEquals(removed_android_paths,
|
||||
['src/third_party/android_deps/libs/android_arch_lifecycle_runtime'])
|
||||
self.assertEquals(other_paths, [])
|
||||
def testFindRemovedDeps(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
removed_android_paths, other_paths = FindRemovedDeps(
|
||||
webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(removed_android_paths, [
|
||||
'src/third_party/android_deps/libs/android_arch_lifecycle_runtime'
|
||||
])
|
||||
self.assertEquals(other_paths, [])
|
||||
|
||||
def testMissingDepsIsDetected(self):
|
||||
"""Check an error is reported when deps cannot be automatically removed."""
|
||||
# The situation at test is the following:
|
||||
# * A WebRTC DEPS entry is missing from Chromium.
|
||||
# * The dependency isn't an android_deps (those are supported).
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
_, other_paths = FindRemovedDeps(webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(other_paths, ['src/third_party/xstream',
|
||||
'src/third_party/depot_tools'])
|
||||
def testMissingDepsIsDetected(self):
|
||||
"""Check an error is reported when deps cannot be automatically removed."""
|
||||
# The situation at test is the following:
|
||||
# * A WebRTC DEPS entry is missing from Chromium.
|
||||
# * The dependency isn't an android_deps (those are supported).
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
_, other_paths = FindRemovedDeps(webrtc_deps, new_cr_deps)
|
||||
self.assertEquals(
|
||||
other_paths,
|
||||
['src/third_party/xstream', 'src/third_party/depot_tools'])
|
||||
|
||||
def testExpectedDepsIsNotReportedMissing(self):
|
||||
"""Some deps musn't be seen as missing, even if absent from Chromium."""
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
removed_android_paths, other_paths = FindRemovedDeps(webrtc_deps,
|
||||
new_cr_deps)
|
||||
self.assertTrue('src/build' not in removed_android_paths)
|
||||
self.assertTrue('src/build' not in other_paths)
|
||||
def testExpectedDepsIsNotReportedMissing(self):
|
||||
"""Some deps musn't be seen as missing, even if absent from Chromium."""
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
removed_android_paths, other_paths = FindRemovedDeps(
|
||||
webrtc_deps, new_cr_deps)
|
||||
self.assertTrue('src/build' not in removed_android_paths)
|
||||
self.assertTrue('src/build' not in other_paths)
|
||||
|
||||
def _CommitMessageSetup(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
def _CommitMessageSetup(self):
|
||||
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
|
||||
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
|
||||
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
added_paths, _ = FindAddedDeps(webrtc_deps, new_cr_deps)
|
||||
removed_paths, _ = FindRemovedDeps(webrtc_deps, new_cr_deps)
|
||||
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
|
||||
added_paths, _ = FindAddedDeps(webrtc_deps, new_cr_deps)
|
||||
removed_paths, _ = FindRemovedDeps(webrtc_deps, new_cr_deps)
|
||||
|
||||
current_commit_pos = 'cafe'
|
||||
new_commit_pos = 'f00d'
|
||||
current_commit_pos = 'cafe'
|
||||
new_commit_pos = 'f00d'
|
||||
|
||||
with mock.patch('roll_deps._RunCommand', self.fake):
|
||||
# We don't really care, but it's needed to construct the message.
|
||||
self.fake.AddExpectation(['git', 'config', 'user.email'],
|
||||
_returns=('nobody@nowhere.no', None),
|
||||
_ignores=['working_dir'])
|
||||
with mock.patch('roll_deps._RunCommand', self.fake):
|
||||
# We don't really care, but it's needed to construct the message.
|
||||
self.fake.AddExpectation(['git', 'config', 'user.email'],
|
||||
_returns=('nobody@nowhere.no', None),
|
||||
_ignores=['working_dir'])
|
||||
|
||||
commit_msg = GenerateCommitMessage(
|
||||
NO_CHROMIUM_REVISION_UPDATE, current_commit_pos, new_commit_pos,
|
||||
changed_deps, added_paths, removed_paths)
|
||||
commit_msg = GenerateCommitMessage(NO_CHROMIUM_REVISION_UPDATE,
|
||||
current_commit_pos,
|
||||
new_commit_pos, changed_deps,
|
||||
added_paths, removed_paths)
|
||||
|
||||
return [l.strip() for l in commit_msg.split('\n')]
|
||||
return [l.strip() for l in commit_msg.split('\n')]
|
||||
|
||||
def testChangedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
def testChangedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
|
||||
changed = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_core_common: version:0.9.0..version:1.0.0-cr0'
|
||||
self.assertTrue(changed in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
changed_line = commit_lines.index(changed)
|
||||
self.assertTrue('Changed' in commit_lines[changed_line-1])
|
||||
changed = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_core_common: version:0.9.0..version:1.0.0-cr0'
|
||||
self.assertTrue(changed in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
changed_line = commit_lines.index(changed)
|
||||
self.assertTrue('Changed' in commit_lines[changed_line - 1])
|
||||
|
||||
def testAddedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
def testAddedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
|
||||
added = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_lifecycle_common'
|
||||
self.assertTrue(added in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
added_line = commit_lines.index(added)
|
||||
self.assertTrue('Added' in commit_lines[added_line-1])
|
||||
added = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_lifecycle_common'
|
||||
self.assertTrue(added in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
added_line = commit_lines.index(added)
|
||||
self.assertTrue('Added' in commit_lines[added_line - 1])
|
||||
|
||||
def testRemovedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
def testRemovedDepsInCommitMessage(self):
|
||||
commit_lines = self._CommitMessageSetup()
|
||||
|
||||
removed = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_lifecycle_runtime'
|
||||
self.assertTrue(removed in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
removed_line = commit_lines.index(removed)
|
||||
self.assertTrue('Removed' in commit_lines[removed_line-1])
|
||||
removed = '* src/third_party/android_deps/libs/' \
|
||||
'android_arch_lifecycle_runtime'
|
||||
self.assertTrue(removed in commit_lines)
|
||||
# Check it is in adequate section.
|
||||
removed_line = commit_lines.index(removed)
|
||||
self.assertTrue('Removed' in commit_lines[removed_line - 1])
|
||||
|
||||
|
||||
class TestChooseCQMode(unittest.TestCase):
|
||||
def testSkip(self):
|
||||
self.assertEquals(ChooseCQMode(True, 99, 500000, 500100), 0)
|
||||
def testSkip(self):
|
||||
self.assertEquals(ChooseCQMode(True, 99, 500000, 500100), 0)
|
||||
|
||||
def testDryRun(self):
|
||||
self.assertEquals(ChooseCQMode(False, 101, 500000, 500100), 1)
|
||||
def testDryRun(self):
|
||||
self.assertEquals(ChooseCQMode(False, 101, 500000, 500100), 1)
|
||||
|
||||
def testSubmit(self):
|
||||
self.assertEquals(ChooseCQMode(False, 100, 500000, 500100), 2)
|
||||
def testSubmit(self):
|
||||
self.assertEquals(ChooseCQMode(False, 100, 500000, 500100), 2)
|
||||
|
||||
|
||||
def _SetupGitLsRemoteCall(cmd_fake, url, revision):
|
||||
cmd = ['git', 'ls-remote', url, revision]
|
||||
cmd_fake.AddExpectation(cmd, _returns=(revision, None))
|
||||
cmd = ['git', 'ls-remote', url, revision]
|
||||
cmd_fake.AddExpectation(cmd, _returns=(revision, None))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Invoke clang-tidy tool.
|
||||
|
||||
Usage: clang_tidy.py file.cc [clang-tidy-args...]
|
||||
@ -25,7 +24,6 @@ import tempfile
|
||||
from presubmit_checks_lib.build_helpers import GetClangTidyPath, \
|
||||
GetCompilationCommand
|
||||
|
||||
|
||||
# We enable all checkers by default for investigation purpose.
|
||||
# This includes clang-analyzer-* checks.
|
||||
# Individual checkers can be disabled via command line options.
|
||||
@ -34,63 +32,66 @@ CHECKER_OPTION = '-checks=*'
|
||||
|
||||
|
||||
def Process(filepath, args):
|
||||
# Build directory is needed to gather compilation flags.
|
||||
# Create a temporary one (instead of reusing an existing one)
|
||||
# to keep the CLI simple and unencumbered.
|
||||
out_dir = tempfile.mkdtemp('clang_tidy')
|
||||
# Build directory is needed to gather compilation flags.
|
||||
# Create a temporary one (instead of reusing an existing one)
|
||||
# to keep the CLI simple and unencumbered.
|
||||
out_dir = tempfile.mkdtemp('clang_tidy')
|
||||
|
||||
try:
|
||||
gn_args = [] # Use default build.
|
||||
command = GetCompilationCommand(filepath, gn_args, out_dir)
|
||||
try:
|
||||
gn_args = [] # Use default build.
|
||||
command = GetCompilationCommand(filepath, gn_args, out_dir)
|
||||
|
||||
# Remove warning flags. They aren't needed and they cause trouble
|
||||
# when clang-tidy doesn't match most recent clang.
|
||||
# Same battle for -f (e.g. -fcomplete-member-pointers).
|
||||
command = [arg for arg in command if not (arg.startswith('-W') or
|
||||
arg.startswith('-f'))]
|
||||
# Remove warning flags. They aren't needed and they cause trouble
|
||||
# when clang-tidy doesn't match most recent clang.
|
||||
# Same battle for -f (e.g. -fcomplete-member-pointers).
|
||||
command = [
|
||||
arg for arg in command
|
||||
if not (arg.startswith('-W') or arg.startswith('-f'))
|
||||
]
|
||||
|
||||
# Path from build dir.
|
||||
rel_path = os.path.relpath(os.path.abspath(filepath), out_dir)
|
||||
# Path from build dir.
|
||||
rel_path = os.path.relpath(os.path.abspath(filepath), out_dir)
|
||||
|
||||
# Replace clang++ by clang-tidy
|
||||
command[0:1] = [GetClangTidyPath(),
|
||||
CHECKER_OPTION,
|
||||
rel_path] + args + ['--'] # Separator for clang flags.
|
||||
print "Running: %s" % ' '.join(command)
|
||||
# Run from build dir so that relative paths are correct.
|
||||
p = subprocess.Popen(command, cwd=out_dir,
|
||||
stdout=sys.stdout, stderr=sys.stderr)
|
||||
p.communicate()
|
||||
return p.returncode
|
||||
finally:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
# Replace clang++ by clang-tidy
|
||||
command[0:1] = [GetClangTidyPath(), CHECKER_OPTION, rel_path
|
||||
] + args + ['--'] # Separator for clang flags.
|
||||
print "Running: %s" % ' '.join(command)
|
||||
# Run from build dir so that relative paths are correct.
|
||||
p = subprocess.Popen(command,
|
||||
cwd=out_dir,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr)
|
||||
p.communicate()
|
||||
return p.returncode
|
||||
finally:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def ValidateCC(filepath):
|
||||
"""We can only analyze .cc files. Provide explicit message about that."""
|
||||
if filepath.endswith('.cc'):
|
||||
return filepath
|
||||
msg = ('%s not supported.\n'
|
||||
'For now, we can only analyze translation units (.cc files).' %
|
||||
filepath)
|
||||
raise argparse.ArgumentTypeError(msg)
|
||||
"""We can only analyze .cc files. Provide explicit message about that."""
|
||||
if filepath.endswith('.cc'):
|
||||
return filepath
|
||||
msg = ('%s not supported.\n'
|
||||
'For now, we can only analyze translation units (.cc files).' %
|
||||
filepath)
|
||||
raise argparse.ArgumentTypeError(msg)
|
||||
|
||||
|
||||
def Main():
|
||||
description = (
|
||||
"Run clang-tidy on single cc file.\n"
|
||||
"Use flags, defines and include paths as in default debug build.\n"
|
||||
"WARNING, this is a POC version with rough edges.")
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument('filepath',
|
||||
help='Specifies the path of the .cc file to analyze.',
|
||||
type=ValidateCC)
|
||||
parser.add_argument('args',
|
||||
nargs=argparse.REMAINDER,
|
||||
help='Arguments passed to clang-tidy')
|
||||
parsed_args = parser.parse_args()
|
||||
return Process(parsed_args.filepath, parsed_args.args)
|
||||
description = (
|
||||
"Run clang-tidy on single cc file.\n"
|
||||
"Use flags, defines and include paths as in default debug build.\n"
|
||||
"WARNING, this is a POC version with rough edges.")
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument('filepath',
|
||||
help='Specifies the path of the .cc file to analyze.',
|
||||
type=ValidateCC)
|
||||
parser.add_argument('args',
|
||||
nargs=argparse.REMAINDER,
|
||||
help='Arguments passed to clang-tidy')
|
||||
parsed_args = parser.parse_args()
|
||||
return Process(parsed_args.filepath, parsed_args.args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(Main())
|
||||
sys.exit(Main())
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Generates a command-line for coverage.py. Useful for manual coverage runs.
|
||||
|
||||
Before running the generated command line, do this:
|
||||
@ -17,39 +16,32 @@ gn gen out/coverage --args='use_clang_coverage=true is_component_build=false'
|
||||
import sys
|
||||
|
||||
TESTS = [
|
||||
'video_capture_tests',
|
||||
'webrtc_nonparallel_tests',
|
||||
'video_engine_tests',
|
||||
'tools_unittests',
|
||||
'test_support_unittests',
|
||||
'slow_tests',
|
||||
'system_wrappers_unittests',
|
||||
'rtc_unittests',
|
||||
'rtc_stats_unittests',
|
||||
'rtc_pc_unittests',
|
||||
'rtc_media_unittests',
|
||||
'peerconnection_unittests',
|
||||
'modules_unittests',
|
||||
'modules_tests',
|
||||
'low_bandwidth_audio_test',
|
||||
'common_video_unittests',
|
||||
'common_audio_unittests',
|
||||
'audio_decoder_unittests'
|
||||
'video_capture_tests', 'webrtc_nonparallel_tests', 'video_engine_tests',
|
||||
'tools_unittests', 'test_support_unittests', 'slow_tests',
|
||||
'system_wrappers_unittests', 'rtc_unittests', 'rtc_stats_unittests',
|
||||
'rtc_pc_unittests', 'rtc_media_unittests', 'peerconnection_unittests',
|
||||
'modules_unittests', 'modules_tests', 'low_bandwidth_audio_test',
|
||||
'common_video_unittests', 'common_audio_unittests',
|
||||
'audio_decoder_unittests'
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] + TESTS +
|
||||
['-b out/coverage', '-o out/report'] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
|
||||
['-c \'out/coverage/%s\'' % t for t in TESTS])
|
||||
cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] + TESTS +
|
||||
['-b out/coverage', '-o out/report'] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
|
||||
['-c \'out/coverage/%s\'' % t for t in TESTS])
|
||||
|
||||
def WithXvfb(binary):
|
||||
return '-c \'%s testing/xvfb.py %s\'' % (sys.executable, binary)
|
||||
modules_unittests = 'out/coverage/modules_unittests'
|
||||
cmd[cmd.index('-c \'%s\'' % modules_unittests)] = WithXvfb(modules_unittests)
|
||||
def WithXvfb(binary):
|
||||
return '-c \'%s testing/xvfb.py %s\'' % (sys.executable, binary)
|
||||
|
||||
modules_unittests = 'out/coverage/modules_unittests'
|
||||
cmd[cmd.index('-c \'%s\'' %
|
||||
modules_unittests)] = WithXvfb(modules_unittests)
|
||||
|
||||
print ' '.join(cmd)
|
||||
return 0
|
||||
|
||||
print ' '.join(cmd)
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Generates command-line instructions to produce one-time iOS coverage using
|
||||
coverage.py.
|
||||
|
||||
@ -53,122 +52,115 @@ import sys
|
||||
DIRECTORY = 'out/coverage'
|
||||
|
||||
TESTS = [
|
||||
'audio_decoder_unittests',
|
||||
'common_audio_unittests',
|
||||
'common_video_unittests',
|
||||
'modules_tests',
|
||||
'modules_unittests',
|
||||
'rtc_media_unittests',
|
||||
'rtc_pc_unittests',
|
||||
'rtc_stats_unittests',
|
||||
'rtc_unittests',
|
||||
'slow_tests',
|
||||
'system_wrappers_unittests',
|
||||
'test_support_unittests',
|
||||
'tools_unittests',
|
||||
'video_capture_tests',
|
||||
'video_engine_tests',
|
||||
'webrtc_nonparallel_tests',
|
||||
'audio_decoder_unittests',
|
||||
'common_audio_unittests',
|
||||
'common_video_unittests',
|
||||
'modules_tests',
|
||||
'modules_unittests',
|
||||
'rtc_media_unittests',
|
||||
'rtc_pc_unittests',
|
||||
'rtc_stats_unittests',
|
||||
'rtc_unittests',
|
||||
'slow_tests',
|
||||
'system_wrappers_unittests',
|
||||
'test_support_unittests',
|
||||
'tools_unittests',
|
||||
'video_capture_tests',
|
||||
'video_engine_tests',
|
||||
'webrtc_nonparallel_tests',
|
||||
]
|
||||
|
||||
XC_TESTS = [
|
||||
'apprtcmobile_tests',
|
||||
'sdk_framework_unittests',
|
||||
'sdk_unittests',
|
||||
'apprtcmobile_tests',
|
||||
'sdk_framework_unittests',
|
||||
'sdk_unittests',
|
||||
]
|
||||
|
||||
|
||||
def FormatIossimTest(test_name, is_xctest=False):
|
||||
args = ['%s/%s.app' % (DIRECTORY, test_name)]
|
||||
if is_xctest:
|
||||
args += ['%s/%s_module.xctest' % (DIRECTORY, test_name)]
|
||||
args = ['%s/%s.app' % (DIRECTORY, test_name)]
|
||||
if is_xctest:
|
||||
args += ['%s/%s_module.xctest' % (DIRECTORY, test_name)]
|
||||
|
||||
return '-c \'%s/iossim %s\'' % (DIRECTORY, ' '.join(args))
|
||||
return '-c \'%s/iossim %s\'' % (DIRECTORY, ' '.join(args))
|
||||
|
||||
|
||||
def GetGNArgs(is_simulator):
|
||||
target_cpu = 'x64' if is_simulator else 'arm64'
|
||||
return ([] +
|
||||
['target_os="ios"'] +
|
||||
['target_cpu="%s"' % target_cpu] +
|
||||
['use_clang_coverage=true'] +
|
||||
['is_component_build=false'] +
|
||||
['dcheck_always_on=true'])
|
||||
target_cpu = 'x64' if is_simulator else 'arm64'
|
||||
return ([] + ['target_os="ios"'] + ['target_cpu="%s"' % target_cpu] +
|
||||
['use_clang_coverage=true'] + ['is_component_build=false'] +
|
||||
['dcheck_always_on=true'])
|
||||
|
||||
|
||||
def GenerateIOSSimulatorCommand():
|
||||
gn_args_string = ' '.join(GetGNArgs(is_simulator=True))
|
||||
gn_cmd = ['gn', 'gen', DIRECTORY, '--args=\'%s\'' % gn_args_string]
|
||||
gn_args_string = ' '.join(GetGNArgs(is_simulator=True))
|
||||
gn_cmd = ['gn', 'gen', DIRECTORY, '--args=\'%s\'' % gn_args_string]
|
||||
|
||||
coverage_cmd = (
|
||||
[sys.executable, 'tools/code_coverage/coverage.py'] +
|
||||
["%s.app" % t for t in XC_TESTS + TESTS] +
|
||||
['-b %s' % DIRECTORY, '-o out/report'] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
|
||||
[FormatIossimTest(t, is_xctest=True) for t in XC_TESTS] +
|
||||
[FormatIossimTest(t, is_xctest=False) for t in TESTS]
|
||||
)
|
||||
coverage_cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] +
|
||||
["%s.app" % t for t in XC_TESTS + TESTS] +
|
||||
['-b %s' % DIRECTORY, '-o out/report'] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
|
||||
[FormatIossimTest(t, is_xctest=True) for t in XC_TESTS] +
|
||||
[FormatIossimTest(t, is_xctest=False) for t in TESTS])
|
||||
|
||||
print 'To get code coverage using iOS simulator just run following commands:'
|
||||
print ''
|
||||
print ' '.join(gn_cmd)
|
||||
print ''
|
||||
print ' '.join(coverage_cmd)
|
||||
return 0
|
||||
print 'To get code coverage using iOS simulator just run following commands:'
|
||||
print ''
|
||||
print ' '.join(gn_cmd)
|
||||
print ''
|
||||
print ' '.join(coverage_cmd)
|
||||
return 0
|
||||
|
||||
|
||||
def GenerateIOSDeviceCommand():
|
||||
gn_args_string = ' '.join(GetGNArgs(is_simulator=False))
|
||||
gn_args_string = ' '.join(GetGNArgs(is_simulator=False))
|
||||
|
||||
coverage_report_cmd = (
|
||||
[sys.executable, 'tools/code_coverage/coverage.py'] +
|
||||
['%s.app' % t for t in TESTS] +
|
||||
['-b %s' % DIRECTORY] +
|
||||
['-o out/report'] +
|
||||
['-p %s/merged.profdata' % DIRECTORY] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\'']
|
||||
)
|
||||
coverage_report_cmd = (
|
||||
[sys.executable, 'tools/code_coverage/coverage.py'] +
|
||||
['%s.app' % t for t in TESTS] + ['-b %s' % DIRECTORY] +
|
||||
['-o out/report'] + ['-p %s/merged.profdata' % DIRECTORY] +
|
||||
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''])
|
||||
|
||||
print 'Computing code coverage for real iOS device is a little bit tedious.'
|
||||
print ''
|
||||
print 'You will need:'
|
||||
print ''
|
||||
print '1. Generate xcode project and open it with Xcode 10+:'
|
||||
print ' gn gen %s --ide=xcode --args=\'%s\'' % (DIRECTORY, gn_args_string)
|
||||
print ' open %s/all.xcworkspace' % DIRECTORY
|
||||
print ''
|
||||
print '2. Execute these Run targets manually with Xcode Run button and '
|
||||
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
|
||||
print '\n'.join('- %s' % t for t in TESTS)
|
||||
print ''
|
||||
print '3. Execute these Test targets manually with Xcode Test button and '
|
||||
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
|
||||
print '\n'.join('- %s' % t for t in XC_TESTS)
|
||||
print ''
|
||||
print '4. Merge *.profraw files to *.profdata using llvm-profdata tool:'
|
||||
print (' build/mac_files/Xcode.app/Contents/Developer/Toolchains/' +
|
||||
'XcodeDefault.xctoolchain/usr/bin/llvm-profdata merge ' +
|
||||
'-o %s/merged.profdata ' % DIRECTORY +
|
||||
'-sparse=true %s/*.profraw' % DIRECTORY)
|
||||
print ''
|
||||
print '5. Generate coverage report:'
|
||||
print ' ' + ' '.join(coverage_report_cmd)
|
||||
return 0
|
||||
print 'Computing code coverage for real iOS device is a little bit tedious.'
|
||||
print ''
|
||||
print 'You will need:'
|
||||
print ''
|
||||
print '1. Generate xcode project and open it with Xcode 10+:'
|
||||
print ' gn gen %s --ide=xcode --args=\'%s\'' % (DIRECTORY, gn_args_string)
|
||||
print ' open %s/all.xcworkspace' % DIRECTORY
|
||||
print ''
|
||||
print '2. Execute these Run targets manually with Xcode Run button and '
|
||||
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
|
||||
print '\n'.join('- %s' % t for t in TESTS)
|
||||
print ''
|
||||
print '3. Execute these Test targets manually with Xcode Test button and '
|
||||
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
|
||||
print '\n'.join('- %s' % t for t in XC_TESTS)
|
||||
print ''
|
||||
print '4. Merge *.profraw files to *.profdata using llvm-profdata tool:'
|
||||
print(' build/mac_files/Xcode.app/Contents/Developer/Toolchains/' +
|
||||
'XcodeDefault.xctoolchain/usr/bin/llvm-profdata merge ' +
|
||||
'-o %s/merged.profdata ' % DIRECTORY +
|
||||
'-sparse=true %s/*.profraw' % DIRECTORY)
|
||||
print ''
|
||||
print '5. Generate coverage report:'
|
||||
print ' ' + ' '.join(coverage_report_cmd)
|
||||
return 0
|
||||
|
||||
|
||||
def Main():
|
||||
if len(sys.argv) < 2:
|
||||
print 'Please specify type of coverage:'
|
||||
print ' %s simulator' % sys.argv[0]
|
||||
print ' %s device' % sys.argv[0]
|
||||
elif sys.argv[1] == 'simulator':
|
||||
GenerateIOSSimulatorCommand()
|
||||
elif sys.argv[1] == 'device':
|
||||
GenerateIOSDeviceCommand()
|
||||
else:
|
||||
print 'Unsupported type of coverage'
|
||||
if len(sys.argv) < 2:
|
||||
print 'Please specify type of coverage:'
|
||||
print ' %s simulator' % sys.argv[0]
|
||||
print ' %s device' % sys.argv[0]
|
||||
elif sys.argv[1] == 'simulator':
|
||||
GenerateIOSSimulatorCommand()
|
||||
elif sys.argv[1] == 'device':
|
||||
GenerateIOSDeviceCommand()
|
||||
else:
|
||||
print 'Unsupported type of coverage'
|
||||
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(Main())
|
||||
sys.exit(Main())
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
|
||||
import psutil
|
||||
import sys
|
||||
|
||||
@ -17,67 +16,68 @@ from matplotlib import pyplot
|
||||
|
||||
|
||||
class CpuSnapshot(object):
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
self.samples = []
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
self.samples = []
|
||||
|
||||
def Capture(self, sample_count):
|
||||
print ('Capturing %d CPU samples for %s...' %
|
||||
((sample_count - len(self.samples)), self.label))
|
||||
while len(self.samples) < sample_count:
|
||||
self.samples.append(psutil.cpu_percent(1.0, False))
|
||||
def Capture(self, sample_count):
|
||||
print('Capturing %d CPU samples for %s...' %
|
||||
((sample_count - len(self.samples)), self.label))
|
||||
while len(self.samples) < sample_count:
|
||||
self.samples.append(psutil.cpu_percent(1.0, False))
|
||||
|
||||
def Text(self):
|
||||
return ('%s: avg=%s, median=%s, min=%s, max=%s' %
|
||||
(self.label, numpy.average(self.samples),
|
||||
numpy.median(self.samples),
|
||||
numpy.min(self.samples), numpy.max(self.samples)))
|
||||
def Text(self):
|
||||
return ('%s: avg=%s, median=%s, min=%s, max=%s' %
|
||||
(self.label, numpy.average(self.samples),
|
||||
numpy.median(self.samples), numpy.min(
|
||||
self.samples), numpy.max(self.samples)))
|
||||
|
||||
def Max(self):
|
||||
return numpy.max(self.samples)
|
||||
def Max(self):
|
||||
return numpy.max(self.samples)
|
||||
|
||||
|
||||
def GrabCpuSamples(sample_count):
|
||||
print 'Label for snapshot (enter to quit): '
|
||||
label = raw_input().strip()
|
||||
if len(label) == 0:
|
||||
return None
|
||||
print 'Label for snapshot (enter to quit): '
|
||||
label = raw_input().strip()
|
||||
if len(label) == 0:
|
||||
return None
|
||||
|
||||
snapshot = CpuSnapshot(label)
|
||||
snapshot.Capture(sample_count)
|
||||
snapshot = CpuSnapshot(label)
|
||||
snapshot.Capture(sample_count)
|
||||
|
||||
return snapshot
|
||||
return snapshot
|
||||
|
||||
|
||||
def main():
|
||||
print 'How many seconds to capture per snapshot (enter for 60)?'
|
||||
sample_count = raw_input().strip()
|
||||
if len(sample_count) > 0 and int(sample_count) > 0:
|
||||
sample_count = int(sample_count)
|
||||
else:
|
||||
print 'Defaulting to 60 samples.'
|
||||
sample_count = 60
|
||||
print 'How many seconds to capture per snapshot (enter for 60)?'
|
||||
sample_count = raw_input().strip()
|
||||
if len(sample_count) > 0 and int(sample_count) > 0:
|
||||
sample_count = int(sample_count)
|
||||
else:
|
||||
print 'Defaulting to 60 samples.'
|
||||
sample_count = 60
|
||||
|
||||
snapshots = []
|
||||
while True:
|
||||
snapshot = GrabCpuSamples(sample_count)
|
||||
if snapshot is None:
|
||||
break
|
||||
snapshots.append(snapshot)
|
||||
snapshots = []
|
||||
while True:
|
||||
snapshot = GrabCpuSamples(sample_count)
|
||||
if snapshot is None:
|
||||
break
|
||||
snapshots.append(snapshot)
|
||||
|
||||
if len(snapshots) == 0:
|
||||
print 'no samples captured'
|
||||
return -1
|
||||
if len(snapshots) == 0:
|
||||
print 'no samples captured'
|
||||
return -1
|
||||
|
||||
pyplot.title('CPU usage')
|
||||
pyplot.title('CPU usage')
|
||||
|
||||
for s in snapshots:
|
||||
pyplot.plot(s.samples, label=s.Text(), linewidth=2)
|
||||
for s in snapshots:
|
||||
pyplot.plot(s.samples, label=s.Text(), linewidth=2)
|
||||
|
||||
pyplot.legend()
|
||||
pyplot.legend()
|
||||
|
||||
pyplot.show()
|
||||
return 0
|
||||
|
||||
pyplot.show()
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Downloads precompiled tools.
|
||||
|
||||
These are checked into the repository as SHA-1 hashes (see *.sha1 files in
|
||||
@ -17,12 +16,10 @@ so please download and compile these tools manually if this script fails.
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SRC_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
|
||||
sys.path.append(os.path.join(SRC_DIR, 'build'))
|
||||
|
||||
|
||||
import find_depot_tools
|
||||
find_depot_tools.add_depot_tools_to_path()
|
||||
import gclient_utils
|
||||
@ -30,32 +27,34 @@ import subprocess2
|
||||
|
||||
|
||||
def main(directories):
|
||||
if not directories:
|
||||
directories = [SCRIPT_DIR]
|
||||
if not directories:
|
||||
directories = [SCRIPT_DIR]
|
||||
|
||||
for path in directories:
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH,
|
||||
'download_from_google_storage.py'),
|
||||
'--directory',
|
||||
'--num_threads=10',
|
||||
'--bucket', 'chrome-webrtc-resources',
|
||||
'--auto_platform',
|
||||
'--recursive',
|
||||
path,
|
||||
]
|
||||
print 'Downloading precompiled tools...'
|
||||
for path in directories:
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH,
|
||||
'download_from_google_storage.py'),
|
||||
'--directory',
|
||||
'--num_threads=10',
|
||||
'--bucket',
|
||||
'chrome-webrtc-resources',
|
||||
'--auto_platform',
|
||||
'--recursive',
|
||||
path,
|
||||
]
|
||||
print 'Downloading precompiled tools...'
|
||||
|
||||
# Perform download similar to how gclient hooks execute.
|
||||
try:
|
||||
gclient_utils.CheckCallAndFilter(
|
||||
cmd, cwd=SRC_DIR, always_show_header=True)
|
||||
except (gclient_utils.Error, subprocess2.CalledProcessError) as e:
|
||||
print 'Error: %s' % str(e)
|
||||
return 2
|
||||
return 0
|
||||
# Perform download similar to how gclient hooks execute.
|
||||
try:
|
||||
gclient_utils.CheckCallAndFilter(cmd,
|
||||
cwd=SRC_DIR,
|
||||
always_show_header=True)
|
||||
except (gclient_utils.Error, subprocess2.CalledProcessError) as e:
|
||||
print 'Error: %s' % str(e)
|
||||
return 2
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Checks if a virtual webcam is running and starts it if not.
|
||||
|
||||
Returns a non-zero return code if the webcam could not be started.
|
||||
@ -32,74 +31,73 @@ import psutil # pylint: disable=F0401
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
WEBCAM_WIN = ('schtasks', '/run', '/tn', 'ManyCam')
|
||||
WEBCAM_MAC = ('open', '/Applications/ManyCam/ManyCam.app')
|
||||
|
||||
|
||||
def IsWebCamRunning():
|
||||
if sys.platform == 'win32':
|
||||
process_name = 'ManyCam.exe'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
process_name = 'ManyCam'
|
||||
elif sys.platform.startswith('linux'):
|
||||
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
|
||||
# longer in use.
|
||||
print 'Virtual webcam: no-op on Linux'
|
||||
return True
|
||||
else:
|
||||
raise Exception('Unsupported platform: %s' % sys.platform)
|
||||
for p in psutil.process_iter():
|
||||
try:
|
||||
if process_name == p.name:
|
||||
print 'Found a running virtual webcam (%s with PID %s)' % (p.name,
|
||||
p.pid)
|
||||
if sys.platform == 'win32':
|
||||
process_name = 'ManyCam.exe'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
process_name = 'ManyCam'
|
||||
elif sys.platform.startswith('linux'):
|
||||
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
|
||||
# longer in use.
|
||||
print 'Virtual webcam: no-op on Linux'
|
||||
return True
|
||||
except psutil.AccessDenied:
|
||||
pass # This is normal if we query sys processes, etc.
|
||||
return False
|
||||
else:
|
||||
raise Exception('Unsupported platform: %s' % sys.platform)
|
||||
for p in psutil.process_iter():
|
||||
try:
|
||||
if process_name == p.name:
|
||||
print 'Found a running virtual webcam (%s with PID %s)' % (
|
||||
p.name, p.pid)
|
||||
return True
|
||||
except psutil.AccessDenied:
|
||||
pass # This is normal if we query sys processes, etc.
|
||||
return False
|
||||
|
||||
|
||||
def StartWebCam():
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
subprocess.check_call(WEBCAM_WIN)
|
||||
print 'Successfully launched virtual webcam.'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
subprocess.check_call(WEBCAM_MAC)
|
||||
print 'Successfully launched virtual webcam.'
|
||||
elif sys.platform.startswith('linux'):
|
||||
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
|
||||
# longer in use.
|
||||
print 'Not implemented on Linux'
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
subprocess.check_call(WEBCAM_WIN)
|
||||
print 'Successfully launched virtual webcam.'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
subprocess.check_call(WEBCAM_MAC)
|
||||
print 'Successfully launched virtual webcam.'
|
||||
elif sys.platform.startswith('linux'):
|
||||
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
|
||||
# longer in use.
|
||||
print 'Not implemented on Linux'
|
||||
|
||||
except Exception as e:
|
||||
print 'Failed to launch virtual webcam: %s' % e
|
||||
return False
|
||||
except Exception as e:
|
||||
print 'Failed to launch virtual webcam: %s' % e
|
||||
return False
|
||||
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def _ForcePythonInterpreter(cmd):
|
||||
"""Returns the fixed command line to call the right python executable."""
|
||||
out = cmd[:]
|
||||
if out[0] == 'python':
|
||||
out[0] = sys.executable
|
||||
elif out[0].endswith('.py'):
|
||||
out.insert(0, sys.executable)
|
||||
return out
|
||||
"""Returns the fixed command line to call the right python executable."""
|
||||
out = cmd[:]
|
||||
if out[0] == 'python':
|
||||
out[0] = sys.executable
|
||||
elif out[0].endswith('.py'):
|
||||
out.insert(0, sys.executable)
|
||||
return out
|
||||
|
||||
|
||||
def Main(argv):
|
||||
if not IsWebCamRunning():
|
||||
if not StartWebCam():
|
||||
return 1
|
||||
if not IsWebCamRunning():
|
||||
if not StartWebCam():
|
||||
return 1
|
||||
|
||||
if argv:
|
||||
return subprocess.call(_ForcePythonInterpreter(argv))
|
||||
else:
|
||||
return 0
|
||||
if argv:
|
||||
return subprocess.call(_ForcePythonInterpreter(argv))
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(Main(sys.argv[1:]))
|
||||
sys.exit(Main(sys.argv[1:]))
|
||||
|
||||
@ -55,7 +55,6 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir))
|
||||
sys.path.append(os.path.join(SRC_DIR, 'build'))
|
||||
@ -63,39 +62,40 @@ import find_depot_tools
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
desc = 'Generates a GN executable targeting the host machine.'
|
||||
parser = argparse.ArgumentParser(description=desc)
|
||||
parser.add_argument('--executable_name',
|
||||
required=True,
|
||||
help='Name of the executable to build')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
desc = 'Generates a GN executable targeting the host machine.'
|
||||
parser = argparse.ArgumentParser(description=desc)
|
||||
parser.add_argument('--executable_name',
|
||||
required=True,
|
||||
help='Name of the executable to build')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@contextmanager
|
||||
def HostBuildDir():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
yield temp_dir
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
yield temp_dir
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def _RunCommand(argv, cwd=SRC_DIR, **kwargs):
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
subprocess.check_call(argv, cwd=cwd, stdout=devnull, **kwargs)
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
subprocess.check_call(argv, cwd=cwd, stdout=devnull, **kwargs)
|
||||
|
||||
|
||||
def DepotToolPath(*args):
|
||||
return os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, *args)
|
||||
return os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, *args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ARGS = _ParseArgs()
|
||||
EXECUTABLE_TO_BUILD = ARGS.executable_name
|
||||
EXECUTABLE_FINAL_NAME = ARGS.executable_name + '_host'
|
||||
with HostBuildDir() as build_dir:
|
||||
_RunCommand([sys.executable, DepotToolPath('gn.py'), 'gen', build_dir])
|
||||
_RunCommand([DepotToolPath('ninja'), '-C', build_dir, EXECUTABLE_TO_BUILD])
|
||||
shutil.copy(os.path.join(build_dir, EXECUTABLE_TO_BUILD),
|
||||
EXECUTABLE_FINAL_NAME)
|
||||
ARGS = _ParseArgs()
|
||||
EXECUTABLE_TO_BUILD = ARGS.executable_name
|
||||
EXECUTABLE_FINAL_NAME = ARGS.executable_name + '_host'
|
||||
with HostBuildDir() as build_dir:
|
||||
_RunCommand([sys.executable, DepotToolPath('gn.py'), 'gen', build_dir])
|
||||
_RunCommand(
|
||||
[DepotToolPath('ninja'), '-C', build_dir, EXECUTABLE_TO_BUILD])
|
||||
shutil.copy(os.path.join(build_dir, EXECUTABLE_TO_BUILD),
|
||||
EXECUTABLE_FINAL_NAME)
|
||||
|
||||
@ -15,30 +15,32 @@ import sys
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--isolated-script-test-perf-output')
|
||||
args, unrecognized_args = parser.parse_known_args()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--isolated-script-test-perf-output')
|
||||
args, unrecognized_args = parser.parse_known_args()
|
||||
|
||||
test_command = _ForcePythonInterpreter(unrecognized_args)
|
||||
if args.isolated_script_test_perf_output:
|
||||
test_command += ['--isolated_script_test_perf_output=' +
|
||||
args.isolated_script_test_perf_output]
|
||||
logging.info('Running %r', test_command)
|
||||
test_command = _ForcePythonInterpreter(unrecognized_args)
|
||||
if args.isolated_script_test_perf_output:
|
||||
test_command += [
|
||||
'--isolated_script_test_perf_output=' +
|
||||
args.isolated_script_test_perf_output
|
||||
]
|
||||
logging.info('Running %r', test_command)
|
||||
|
||||
return subprocess.call(test_command)
|
||||
return subprocess.call(test_command)
|
||||
|
||||
|
||||
def _ForcePythonInterpreter(cmd):
|
||||
"""Returns the fixed command line to call the right python executable."""
|
||||
out = cmd[:]
|
||||
if out[0] == 'python':
|
||||
out[0] = sys.executable
|
||||
elif out[0].endswith('.py'):
|
||||
out.insert(0, sys.executable)
|
||||
return out
|
||||
"""Returns the fixed command line to call the right python executable."""
|
||||
out = cmd[:]
|
||||
if out[0] == 'python':
|
||||
out[0] = sys.executable
|
||||
elif out[0].endswith('.py'):
|
||||
out.insert(0, sys.executable)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# pylint: disable=W0101
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
sys.exit(main())
|
||||
# pylint: disable=W0101
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
sys.exit(main())
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""
|
||||
This file emits the list of reasons why a particular build needs to be clobbered
|
||||
(or a list of 'landmines').
|
||||
@ -20,49 +19,48 @@ CHECKOUT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
|
||||
sys.path.insert(0, os.path.join(CHECKOUT_ROOT, 'build'))
|
||||
import landmine_utils
|
||||
|
||||
|
||||
host_os = landmine_utils.host_os # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def print_landmines(): # pylint: disable=invalid-name
|
||||
"""
|
||||
"""
|
||||
ALL LANDMINES ARE EMITTED FROM HERE.
|
||||
"""
|
||||
# DO NOT add landmines as part of a regular CL. Landmines are a last-effort
|
||||
# bandaid fix if a CL that got landed has a build dependency bug and all bots
|
||||
# need to be cleaned up. If you're writing a new CL that causes build
|
||||
# dependency problems, fix the dependency problems instead of adding a
|
||||
# landmine.
|
||||
# See the Chromium version in src/build/get_landmines.py for usage examples.
|
||||
print 'Clobber to remove out/{Debug,Release}/args.gn (webrtc:5070)'
|
||||
if host_os() == 'win':
|
||||
print 'Clobber to resolve some issues with corrupt .pdb files on bots.'
|
||||
print 'Clobber due to corrupt .pdb files (after #14623)'
|
||||
print 'Clobber due to Win 64-bit Debug linking error (crbug.com/668961)'
|
||||
print ('Clobber due to Win Clang Debug linking errors in '
|
||||
'https://codereview.webrtc.org/2786603002')
|
||||
print ('Clobber due to Win Debug linking errors in '
|
||||
'https://codereview.webrtc.org/2832063003/')
|
||||
print 'Clobber win x86 bots (issues with isolated files).'
|
||||
if host_os() == 'mac':
|
||||
print 'Clobber due to iOS compile errors (crbug.com/694721)'
|
||||
print 'Clobber to unblock https://codereview.webrtc.org/2709573003'
|
||||
print ('Clobber to fix https://codereview.webrtc.org/2709573003 after '
|
||||
'landing')
|
||||
print ('Clobber to fix https://codereview.webrtc.org/2767383005 before'
|
||||
'landing (changing rtc_executable -> rtc_test on iOS)')
|
||||
print ('Clobber to fix https://codereview.webrtc.org/2767383005 before'
|
||||
'landing (changing rtc_executable -> rtc_test on iOS)')
|
||||
print 'Another landmine for low_bandwidth_audio_test (webrtc:7430)'
|
||||
print 'Clobber to change neteq_rtpplay type to executable'
|
||||
print 'Clobber to remove .xctest files.'
|
||||
print 'Clobber to remove .xctest files (take 2).'
|
||||
# DO NOT add landmines as part of a regular CL. Landmines are a last-effort
|
||||
# bandaid fix if a CL that got landed has a build dependency bug and all bots
|
||||
# need to be cleaned up. If you're writing a new CL that causes build
|
||||
# dependency problems, fix the dependency problems instead of adding a
|
||||
# landmine.
|
||||
# See the Chromium version in src/build/get_landmines.py for usage examples.
|
||||
print 'Clobber to remove out/{Debug,Release}/args.gn (webrtc:5070)'
|
||||
if host_os() == 'win':
|
||||
print 'Clobber to resolve some issues with corrupt .pdb files on bots.'
|
||||
print 'Clobber due to corrupt .pdb files (after #14623)'
|
||||
print 'Clobber due to Win 64-bit Debug linking error (crbug.com/668961)'
|
||||
print('Clobber due to Win Clang Debug linking errors in '
|
||||
'https://codereview.webrtc.org/2786603002')
|
||||
print('Clobber due to Win Debug linking errors in '
|
||||
'https://codereview.webrtc.org/2832063003/')
|
||||
print 'Clobber win x86 bots (issues with isolated files).'
|
||||
if host_os() == 'mac':
|
||||
print 'Clobber due to iOS compile errors (crbug.com/694721)'
|
||||
print 'Clobber to unblock https://codereview.webrtc.org/2709573003'
|
||||
print('Clobber to fix https://codereview.webrtc.org/2709573003 after '
|
||||
'landing')
|
||||
print('Clobber to fix https://codereview.webrtc.org/2767383005 before'
|
||||
'landing (changing rtc_executable -> rtc_test on iOS)')
|
||||
print('Clobber to fix https://codereview.webrtc.org/2767383005 before'
|
||||
'landing (changing rtc_executable -> rtc_test on iOS)')
|
||||
print 'Another landmine for low_bandwidth_audio_test (webrtc:7430)'
|
||||
print 'Clobber to change neteq_rtpplay type to executable'
|
||||
print 'Clobber to remove .xctest files.'
|
||||
print 'Clobber to remove .xctest files (take 2).'
|
||||
|
||||
|
||||
def main():
|
||||
print_landmines()
|
||||
return 0
|
||||
print_landmines()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""
|
||||
This tool tries to fix (some) errors reported by `gn gen --check` or
|
||||
`gn check`.
|
||||
@ -31,72 +30,78 @@ from collections import defaultdict
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
CHROMIUM_DIRS = ['base', 'build', 'buildtools',
|
||||
'testing', 'third_party', 'tools']
|
||||
CHROMIUM_DIRS = [
|
||||
'base', 'build', 'buildtools', 'testing', 'third_party', 'tools'
|
||||
]
|
||||
|
||||
TARGET_RE = re.compile(
|
||||
r'(?P<indentation_level>\s*)\w*\("(?P<target_name>\w*)"\) {$')
|
||||
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
def __init__(self):
|
||||
self._closed = False
|
||||
self._name = None
|
||||
self._name = tempfile.mkdtemp()
|
||||
def __init__(self):
|
||||
self._closed = False
|
||||
self._name = None
|
||||
self._name = tempfile.mkdtemp()
|
||||
|
||||
def __enter__(self):
|
||||
return self._name
|
||||
def __enter__(self):
|
||||
return self._name
|
||||
|
||||
def __exit__(self, exc, value, _tb):
|
||||
if self._name and not self._closed:
|
||||
shutil.rmtree(self._name)
|
||||
self._closed = True
|
||||
def __exit__(self, exc, value, _tb):
|
||||
if self._name and not self._closed:
|
||||
shutil.rmtree(self._name)
|
||||
self._closed = True
|
||||
|
||||
|
||||
def Run(cmd):
|
||||
print 'Running:', ' '.join(cmd)
|
||||
sub = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
return sub.communicate()
|
||||
print 'Running:', ' '.join(cmd)
|
||||
sub = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
return sub.communicate()
|
||||
|
||||
|
||||
def FixErrors(filename, missing_deps, deleted_sources):
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
fixed_file = ''
|
||||
indentation_level = None
|
||||
for line in lines:
|
||||
match = TARGET_RE.match(line)
|
||||
if match:
|
||||
target = match.group('target_name')
|
||||
if target in missing_deps:
|
||||
indentation_level = match.group('indentation_level')
|
||||
elif indentation_level is not None:
|
||||
match = re.match(indentation_level + '}$', line)
|
||||
if match:
|
||||
line = ('deps = [\n' +
|
||||
''.join(' "' + dep + '",\n' for dep in missing_deps[target]) +
|
||||
']\n') + line
|
||||
indentation_level = None
|
||||
elif line.strip().startswith('deps'):
|
||||
is_empty_deps = line.strip() == 'deps = []'
|
||||
line = 'deps = [\n' if is_empty_deps else line
|
||||
line += ''.join(' "' + dep + '",\n' for dep in missing_deps[target])
|
||||
line += ']\n' if is_empty_deps else ''
|
||||
indentation_level = None
|
||||
fixed_file = ''
|
||||
indentation_level = None
|
||||
for line in lines:
|
||||
match = TARGET_RE.match(line)
|
||||
if match:
|
||||
target = match.group('target_name')
|
||||
if target in missing_deps:
|
||||
indentation_level = match.group('indentation_level')
|
||||
elif indentation_level is not None:
|
||||
match = re.match(indentation_level + '}$', line)
|
||||
if match:
|
||||
line = ('deps = [\n' + ''.join(' "' + dep + '",\n'
|
||||
for dep in missing_deps[target])
|
||||
+ ']\n') + line
|
||||
indentation_level = None
|
||||
elif line.strip().startswith('deps'):
|
||||
is_empty_deps = line.strip() == 'deps = []'
|
||||
line = 'deps = [\n' if is_empty_deps else line
|
||||
line += ''.join(' "' + dep + '",\n'
|
||||
for dep in missing_deps[target])
|
||||
line += ']\n' if is_empty_deps else ''
|
||||
indentation_level = None
|
||||
|
||||
if line.strip() not in deleted_sources:
|
||||
fixed_file += line
|
||||
if line.strip() not in deleted_sources:
|
||||
fixed_file += line
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
f.write(fixed_file)
|
||||
with open(filename, 'w') as f:
|
||||
f.write(fixed_file)
|
||||
|
||||
Run(['gn', 'format', filename])
|
||||
|
||||
Run(['gn', 'format', filename])
|
||||
|
||||
def FirstNonEmpty(iterable):
|
||||
"""Return first item which evaluates to True, or fallback to None."""
|
||||
return next((x for x in iterable if x), None)
|
||||
"""Return first item which evaluates to True, or fallback to None."""
|
||||
return next((x for x in iterable if x), None)
|
||||
|
||||
|
||||
def Rebase(base_path, dependency_path, dependency):
|
||||
"""Adapt paths so they work both in stand-alone WebRTC and Chromium tree.
|
||||
"""Adapt paths so they work both in stand-alone WebRTC and Chromium tree.
|
||||
|
||||
To cope with varying top-level directory (WebRTC VS Chromium), we use:
|
||||
* relative paths for WebRTC modules.
|
||||
@ -113,77 +118,82 @@ def Rebase(base_path, dependency_path, dependency):
|
||||
Full target path (E.g. '../rtc_base/time:timestamp_extrapolator').
|
||||
"""
|
||||
|
||||
root = FirstNonEmpty(dependency_path.split('/'))
|
||||
if root in CHROMIUM_DIRS:
|
||||
# Chromium paths must remain absolute. E.g. //third_party//abseil-cpp...
|
||||
rebased = dependency_path
|
||||
else:
|
||||
base_path = base_path.split(os.path.sep)
|
||||
dependency_path = dependency_path.split(os.path.sep)
|
||||
root = FirstNonEmpty(dependency_path.split('/'))
|
||||
if root in CHROMIUM_DIRS:
|
||||
# Chromium paths must remain absolute. E.g. //third_party//abseil-cpp...
|
||||
rebased = dependency_path
|
||||
else:
|
||||
base_path = base_path.split(os.path.sep)
|
||||
dependency_path = dependency_path.split(os.path.sep)
|
||||
|
||||
first_difference = None
|
||||
shortest_length = min(len(dependency_path), len(base_path))
|
||||
for i in range(shortest_length):
|
||||
if dependency_path[i] != base_path[i]:
|
||||
first_difference = i
|
||||
break
|
||||
first_difference = None
|
||||
shortest_length = min(len(dependency_path), len(base_path))
|
||||
for i in range(shortest_length):
|
||||
if dependency_path[i] != base_path[i]:
|
||||
first_difference = i
|
||||
break
|
||||
|
||||
first_difference = first_difference or shortest_length
|
||||
base_path = base_path[first_difference:]
|
||||
dependency_path = dependency_path[first_difference:]
|
||||
rebased = os.path.sep.join((['..'] * len(base_path)) + dependency_path)
|
||||
return rebased + ':' + dependency
|
||||
|
||||
first_difference = first_difference or shortest_length
|
||||
base_path = base_path[first_difference:]
|
||||
dependency_path = dependency_path[first_difference:]
|
||||
rebased = os.path.sep.join((['..'] * len(base_path)) + dependency_path)
|
||||
return rebased + ':' + dependency
|
||||
|
||||
def main():
|
||||
deleted_sources = set()
|
||||
errors_by_file = defaultdict(lambda: defaultdict(set))
|
||||
deleted_sources = set()
|
||||
errors_by_file = defaultdict(lambda: defaultdict(set))
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
mb_script_path = os.path.join(SCRIPT_DIR, 'mb', 'mb.py')
|
||||
mb_config_file_path = os.path.join(SCRIPT_DIR, 'mb', 'mb_config.pyl')
|
||||
mb_gen_command = ([
|
||||
mb_script_path, 'gen',
|
||||
tmp_dir,
|
||||
'--config-file', mb_config_file_path,
|
||||
] + sys.argv[1:])
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
mb_script_path = os.path.join(SCRIPT_DIR, 'mb', 'mb.py')
|
||||
mb_config_file_path = os.path.join(SCRIPT_DIR, 'mb', 'mb_config.pyl')
|
||||
mb_gen_command = ([
|
||||
mb_script_path,
|
||||
'gen',
|
||||
tmp_dir,
|
||||
'--config-file',
|
||||
mb_config_file_path,
|
||||
] + sys.argv[1:])
|
||||
|
||||
mb_output = Run(mb_gen_command)
|
||||
errors = mb_output[0].split('ERROR')[1:]
|
||||
mb_output = Run(mb_gen_command)
|
||||
errors = mb_output[0].split('ERROR')[1:]
|
||||
|
||||
if mb_output[1]:
|
||||
print mb_output[1]
|
||||
return 1
|
||||
if mb_output[1]:
|
||||
print mb_output[1]
|
||||
return 1
|
||||
|
||||
for error in errors:
|
||||
error = error.splitlines()
|
||||
target_msg = 'The target:'
|
||||
if target_msg not in error:
|
||||
target_msg = 'It is not in any dependency of'
|
||||
if target_msg not in error:
|
||||
print '\n'.join(error)
|
||||
continue
|
||||
index = error.index(target_msg) + 1
|
||||
path, target = error[index].strip().split(':')
|
||||
if error[index+1] in ('is including a file from the target:',
|
||||
'The include file is in the target(s):'):
|
||||
dep = error[index+2].strip()
|
||||
dep_path, dep = dep.split(':')
|
||||
dep = Rebase(path, dep_path, dep)
|
||||
# Replacing /target:target with /target
|
||||
dep = re.sub(r'/(\w+):(\1)$', r'/\1', dep)
|
||||
path = os.path.join(path[2:], 'BUILD.gn')
|
||||
errors_by_file[path][target].add(dep)
|
||||
elif error[index+1] == 'has a source file:':
|
||||
deleted_file = '"' + os.path.basename(error[index+2].strip()) + '",'
|
||||
deleted_sources.add(deleted_file)
|
||||
else:
|
||||
print '\n'.join(error)
|
||||
continue
|
||||
for error in errors:
|
||||
error = error.splitlines()
|
||||
target_msg = 'The target:'
|
||||
if target_msg not in error:
|
||||
target_msg = 'It is not in any dependency of'
|
||||
if target_msg not in error:
|
||||
print '\n'.join(error)
|
||||
continue
|
||||
index = error.index(target_msg) + 1
|
||||
path, target = error[index].strip().split(':')
|
||||
if error[index + 1] in ('is including a file from the target:',
|
||||
'The include file is in the target(s):'):
|
||||
dep = error[index + 2].strip()
|
||||
dep_path, dep = dep.split(':')
|
||||
dep = Rebase(path, dep_path, dep)
|
||||
# Replacing /target:target with /target
|
||||
dep = re.sub(r'/(\w+):(\1)$', r'/\1', dep)
|
||||
path = os.path.join(path[2:], 'BUILD.gn')
|
||||
errors_by_file[path][target].add(dep)
|
||||
elif error[index + 1] == 'has a source file:':
|
||||
deleted_file = '"' + os.path.basename(
|
||||
error[index + 2].strip()) + '",'
|
||||
deleted_sources.add(deleted_file)
|
||||
else:
|
||||
print '\n'.join(error)
|
||||
continue
|
||||
|
||||
for path, missing_deps in errors_by_file.items():
|
||||
FixErrors(path, missing_deps, deleted_sources)
|
||||
for path, missing_deps in errors_by_file.items():
|
||||
FixErrors(path, missing_deps, deleted_sources)
|
||||
|
||||
return 0
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -75,165 +75,174 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
Args = collections.namedtuple('Args',
|
||||
['gtest_parallel_args', 'test_env', 'output_dir',
|
||||
'test_artifacts_dir'])
|
||||
Args = collections.namedtuple(
|
||||
'Args',
|
||||
['gtest_parallel_args', 'test_env', 'output_dir', 'test_artifacts_dir'])
|
||||
|
||||
|
||||
def _CatFiles(file_list, output_file):
|
||||
with open(output_file, 'w') as output_file:
|
||||
for filename in file_list:
|
||||
with open(filename) as input_file:
|
||||
output_file.write(input_file.read())
|
||||
os.remove(filename)
|
||||
with open(output_file, 'w') as output_file:
|
||||
for filename in file_list:
|
||||
with open(filename) as input_file:
|
||||
output_file.write(input_file.read())
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def _ParseWorkersOption(workers):
|
||||
"""Interpret Nx syntax as N * cpu_count. Int value is left as is."""
|
||||
base = float(workers.rstrip('x'))
|
||||
if workers.endswith('x'):
|
||||
result = int(base * multiprocessing.cpu_count())
|
||||
else:
|
||||
result = int(base)
|
||||
return max(result, 1) # Sanitize when using e.g. '0.5x'.
|
||||
"""Interpret Nx syntax as N * cpu_count. Int value is left as is."""
|
||||
base = float(workers.rstrip('x'))
|
||||
if workers.endswith('x'):
|
||||
result = int(base * multiprocessing.cpu_count())
|
||||
else:
|
||||
result = int(base)
|
||||
return max(result, 1) # Sanitize when using e.g. '0.5x'.
|
||||
|
||||
|
||||
class ReconstructibleArgumentGroup(object):
|
||||
"""An argument group that can be converted back into a command line.
|
||||
"""An argument group that can be converted back into a command line.
|
||||
|
||||
This acts like ArgumentParser.add_argument_group, but names of arguments added
|
||||
to it are also kept in a list, so that parsed options from
|
||||
ArgumentParser.parse_args can be reconstructed back into a command line (list
|
||||
of args) based on the list of wanted keys."""
|
||||
def __init__(self, parser, *args, **kwargs):
|
||||
self._group = parser.add_argument_group(*args, **kwargs)
|
||||
self._keys = []
|
||||
|
||||
def AddArgument(self, *args, **kwargs):
|
||||
arg = self._group.add_argument(*args, **kwargs)
|
||||
self._keys.append(arg.dest)
|
||||
def __init__(self, parser, *args, **kwargs):
|
||||
self._group = parser.add_argument_group(*args, **kwargs)
|
||||
self._keys = []
|
||||
|
||||
def RemakeCommandLine(self, options):
|
||||
result = []
|
||||
for key in self._keys:
|
||||
value = getattr(options, key)
|
||||
if value is True:
|
||||
result.append('--%s' % key)
|
||||
elif value is not None:
|
||||
result.append('--%s=%s' % (key, value))
|
||||
return result
|
||||
def AddArgument(self, *args, **kwargs):
|
||||
arg = self._group.add_argument(*args, **kwargs)
|
||||
self._keys.append(arg.dest)
|
||||
|
||||
def RemakeCommandLine(self, options):
|
||||
result = []
|
||||
for key in self._keys:
|
||||
value = getattr(options, key)
|
||||
if value is True:
|
||||
result.append('--%s' % key)
|
||||
elif value is not None:
|
||||
result.append('--%s=%s' % (key, value))
|
||||
return result
|
||||
|
||||
|
||||
def ParseArgs(argv=None):
|
||||
parser = argparse.ArgumentParser(argv)
|
||||
parser = argparse.ArgumentParser(argv)
|
||||
|
||||
gtest_group = ReconstructibleArgumentGroup(parser,
|
||||
'Arguments to gtest-parallel')
|
||||
# These options will be passed unchanged to gtest-parallel.
|
||||
gtest_group.AddArgument('-d', '--output_dir')
|
||||
gtest_group.AddArgument('-r', '--repeat')
|
||||
gtest_group.AddArgument('--retry_failed')
|
||||
gtest_group.AddArgument('--gtest_color')
|
||||
gtest_group.AddArgument('--gtest_filter')
|
||||
gtest_group.AddArgument('--gtest_also_run_disabled_tests',
|
||||
action='store_true', default=None)
|
||||
gtest_group.AddArgument('--timeout')
|
||||
gtest_group = ReconstructibleArgumentGroup(parser,
|
||||
'Arguments to gtest-parallel')
|
||||
# These options will be passed unchanged to gtest-parallel.
|
||||
gtest_group.AddArgument('-d', '--output_dir')
|
||||
gtest_group.AddArgument('-r', '--repeat')
|
||||
gtest_group.AddArgument('--retry_failed')
|
||||
gtest_group.AddArgument('--gtest_color')
|
||||
gtest_group.AddArgument('--gtest_filter')
|
||||
gtest_group.AddArgument('--gtest_also_run_disabled_tests',
|
||||
action='store_true',
|
||||
default=None)
|
||||
gtest_group.AddArgument('--timeout')
|
||||
|
||||
# Syntax 'Nx' will be interpreted as N * number of cpu cores.
|
||||
gtest_group.AddArgument('-w', '--workers', type=_ParseWorkersOption)
|
||||
# Syntax 'Nx' will be interpreted as N * number of cpu cores.
|
||||
gtest_group.AddArgument('-w', '--workers', type=_ParseWorkersOption)
|
||||
|
||||
# Needed when the test wants to store test artifacts, because it doesn't know
|
||||
# what will be the swarming output dir.
|
||||
parser.add_argument('--store-test-artifacts', action='store_true')
|
||||
# Needed when the test wants to store test artifacts, because it doesn't know
|
||||
# what will be the swarming output dir.
|
||||
parser.add_argument('--store-test-artifacts', action='store_true')
|
||||
|
||||
# No-sandbox is a Chromium-specific flag, ignore it.
|
||||
# TODO(oprypin): Remove (bugs.webrtc.org/8115)
|
||||
parser.add_argument('--no-sandbox', action='store_true',
|
||||
help=argparse.SUPPRESS)
|
||||
# No-sandbox is a Chromium-specific flag, ignore it.
|
||||
# TODO(oprypin): Remove (bugs.webrtc.org/8115)
|
||||
parser.add_argument('--no-sandbox',
|
||||
action='store_true',
|
||||
help=argparse.SUPPRESS)
|
||||
|
||||
parser.add_argument('executable')
|
||||
parser.add_argument('executable_args', nargs='*')
|
||||
parser.add_argument('executable')
|
||||
parser.add_argument('executable_args', nargs='*')
|
||||
|
||||
options, unrecognized_args = parser.parse_known_args(argv)
|
||||
options, unrecognized_args = parser.parse_known_args(argv)
|
||||
|
||||
args_to_pass = []
|
||||
for arg in unrecognized_args:
|
||||
if arg.startswith('--isolated-script-test-perf-output'):
|
||||
arg_split = arg.split('=')
|
||||
assert len(arg_split) == 2, 'You must use the = syntax for this flag.'
|
||||
args_to_pass.append('--isolated_script_test_perf_output=' + arg_split[1])
|
||||
args_to_pass = []
|
||||
for arg in unrecognized_args:
|
||||
if arg.startswith('--isolated-script-test-perf-output'):
|
||||
arg_split = arg.split('=')
|
||||
assert len(
|
||||
arg_split) == 2, 'You must use the = syntax for this flag.'
|
||||
args_to_pass.append('--isolated_script_test_perf_output=' +
|
||||
arg_split[1])
|
||||
else:
|
||||
args_to_pass.append(arg)
|
||||
|
||||
executable_args = options.executable_args + args_to_pass
|
||||
|
||||
if options.store_test_artifacts:
|
||||
assert options.output_dir, (
|
||||
'--output_dir must be specified for storing test artifacts.')
|
||||
test_artifacts_dir = os.path.join(options.output_dir, 'test_artifacts')
|
||||
|
||||
executable_args.insert(0,
|
||||
'--test_artifacts_dir=%s' % test_artifacts_dir)
|
||||
else:
|
||||
args_to_pass.append(arg)
|
||||
test_artifacts_dir = None
|
||||
|
||||
executable_args = options.executable_args + args_to_pass
|
||||
gtest_parallel_args = gtest_group.RemakeCommandLine(options)
|
||||
|
||||
if options.store_test_artifacts:
|
||||
assert options.output_dir, (
|
||||
'--output_dir must be specified for storing test artifacts.')
|
||||
test_artifacts_dir = os.path.join(options.output_dir, 'test_artifacts')
|
||||
# GTEST_SHARD_INDEX and GTEST_TOTAL_SHARDS must be removed from the
|
||||
# environment. Otherwise it will be picked up by the binary, causing a bug
|
||||
# where only tests in the first shard are executed.
|
||||
test_env = os.environ.copy()
|
||||
gtest_shard_index = test_env.pop('GTEST_SHARD_INDEX', '0')
|
||||
gtest_total_shards = test_env.pop('GTEST_TOTAL_SHARDS', '1')
|
||||
|
||||
executable_args.insert(0, '--test_artifacts_dir=%s' % test_artifacts_dir)
|
||||
else:
|
||||
test_artifacts_dir = None
|
||||
gtest_parallel_args.insert(0, '--shard_index=%s' % gtest_shard_index)
|
||||
gtest_parallel_args.insert(1, '--shard_count=%s' % gtest_total_shards)
|
||||
|
||||
gtest_parallel_args = gtest_group.RemakeCommandLine(options)
|
||||
gtest_parallel_args.append(options.executable)
|
||||
if executable_args:
|
||||
gtest_parallel_args += ['--'] + executable_args
|
||||
|
||||
# GTEST_SHARD_INDEX and GTEST_TOTAL_SHARDS must be removed from the
|
||||
# environment. Otherwise it will be picked up by the binary, causing a bug
|
||||
# where only tests in the first shard are executed.
|
||||
test_env = os.environ.copy()
|
||||
gtest_shard_index = test_env.pop('GTEST_SHARD_INDEX', '0')
|
||||
gtest_total_shards = test_env.pop('GTEST_TOTAL_SHARDS', '1')
|
||||
|
||||
gtest_parallel_args.insert(0, '--shard_index=%s' % gtest_shard_index)
|
||||
gtest_parallel_args.insert(1, '--shard_count=%s' % gtest_total_shards)
|
||||
|
||||
gtest_parallel_args.append(options.executable)
|
||||
if executable_args:
|
||||
gtest_parallel_args += ['--'] + executable_args
|
||||
|
||||
return Args(gtest_parallel_args, test_env, options.output_dir,
|
||||
test_artifacts_dir)
|
||||
return Args(gtest_parallel_args, test_env, options.output_dir,
|
||||
test_artifacts_dir)
|
||||
|
||||
|
||||
def main():
|
||||
webrtc_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
gtest_parallel_path = os.path.join(
|
||||
webrtc_root, 'third_party', 'gtest-parallel', 'gtest-parallel')
|
||||
webrtc_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
gtest_parallel_path = os.path.join(webrtc_root, 'third_party',
|
||||
'gtest-parallel', 'gtest-parallel')
|
||||
|
||||
gtest_parallel_args, test_env, output_dir, test_artifacts_dir = ParseArgs()
|
||||
gtest_parallel_args, test_env, output_dir, test_artifacts_dir = ParseArgs()
|
||||
|
||||
command = [
|
||||
sys.executable,
|
||||
gtest_parallel_path,
|
||||
] + gtest_parallel_args
|
||||
command = [
|
||||
sys.executable,
|
||||
gtest_parallel_path,
|
||||
] + gtest_parallel_args
|
||||
|
||||
if output_dir and not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if test_artifacts_dir and not os.path.isdir(test_artifacts_dir):
|
||||
os.makedirs(test_artifacts_dir)
|
||||
if output_dir and not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if test_artifacts_dir and not os.path.isdir(test_artifacts_dir):
|
||||
os.makedirs(test_artifacts_dir)
|
||||
|
||||
print 'gtest-parallel-wrapper: Executing command %s' % ' '.join(command)
|
||||
sys.stdout.flush()
|
||||
print 'gtest-parallel-wrapper: Executing command %s' % ' '.join(command)
|
||||
sys.stdout.flush()
|
||||
|
||||
exit_code = subprocess.call(command, env=test_env, cwd=os.getcwd())
|
||||
exit_code = subprocess.call(command, env=test_env, cwd=os.getcwd())
|
||||
|
||||
if output_dir:
|
||||
for test_status in 'passed', 'failed', 'interrupted':
|
||||
logs_dir = os.path.join(output_dir, 'gtest-parallel-logs', test_status)
|
||||
if not os.path.isdir(logs_dir):
|
||||
continue
|
||||
logs = [os.path.join(logs_dir, log) for log in os.listdir(logs_dir)]
|
||||
log_file = os.path.join(output_dir, '%s-tests.log' % test_status)
|
||||
_CatFiles(logs, log_file)
|
||||
os.rmdir(logs_dir)
|
||||
if output_dir:
|
||||
for test_status in 'passed', 'failed', 'interrupted':
|
||||
logs_dir = os.path.join(output_dir, 'gtest-parallel-logs',
|
||||
test_status)
|
||||
if not os.path.isdir(logs_dir):
|
||||
continue
|
||||
logs = [
|
||||
os.path.join(logs_dir, log) for log in os.listdir(logs_dir)
|
||||
]
|
||||
log_file = os.path.join(output_dir, '%s-tests.log' % test_status)
|
||||
_CatFiles(logs, log_file)
|
||||
os.rmdir(logs_dir)
|
||||
|
||||
if test_artifacts_dir:
|
||||
shutil.make_archive(test_artifacts_dir, 'zip', test_artifacts_dir)
|
||||
shutil.rmtree(test_artifacts_dir)
|
||||
if test_artifacts_dir:
|
||||
shutil.make_archive(test_artifacts_dir, 'zip', test_artifacts_dir)
|
||||
shutil.rmtree(test_artifacts_dir)
|
||||
|
||||
return exit_code
|
||||
return exit_code
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -21,149 +21,152 @@ gtest_parallel_wrapper = __import__('gtest-parallel-wrapper')
|
||||
|
||||
@contextmanager
|
||||
def TemporaryDirectory():
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
yield tmp_dir
|
||||
os.rmdir(tmp_dir)
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
yield tmp_dir
|
||||
os.rmdir(tmp_dir)
|
||||
|
||||
|
||||
class GtestParallelWrapperHelpersTest(unittest.TestCase):
|
||||
def testGetWorkersAsIs(self):
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('12'), 12)
|
||||
|
||||
def testGetWorkersAsIs(self):
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('12'), 12)
|
||||
def testGetTwiceWorkers(self):
|
||||
expected = 2 * multiprocessing.cpu_count()
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('2x'),
|
||||
expected)
|
||||
|
||||
def testGetTwiceWorkers(self):
|
||||
expected = 2 * multiprocessing.cpu_count()
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('2x'), expected)
|
||||
|
||||
def testGetHalfWorkers(self):
|
||||
expected = max(multiprocessing.cpu_count() // 2, 1)
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(
|
||||
gtest_parallel_wrapper._ParseWorkersOption('0.5x'), expected)
|
||||
def testGetHalfWorkers(self):
|
||||
expected = max(multiprocessing.cpu_count() // 2, 1)
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('0.5x'),
|
||||
expected)
|
||||
|
||||
|
||||
class GtestParallelWrapperTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def _Expected(cls, gtest_parallel_args):
|
||||
return ['--shard_index=0', '--shard_count=1'] + gtest_parallel_args
|
||||
|
||||
@classmethod
|
||||
def _Expected(cls, gtest_parallel_args):
|
||||
return ['--shard_index=0', '--shard_count=1'] + gtest_parallel_args
|
||||
def testOverwrite(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout=123', 'exec', '--timeout', '124'])
|
||||
expected = self._Expected(['--timeout=124', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testOverwrite(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout=123', 'exec', '--timeout', '124'])
|
||||
expected = self._Expected(['--timeout=124', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testMixing(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs([
|
||||
'--timeout=123', '--param1', 'exec', '--param2', '--timeout', '124'
|
||||
])
|
||||
expected = self._Expected(
|
||||
['--timeout=124', 'exec', '--', '--param1', '--param2'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testMixing(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout=123', '--param1', 'exec', '--param2', '--timeout', '124'])
|
||||
expected = self._Expected(
|
||||
['--timeout=124', 'exec', '--', '--param1', '--param2'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testMixingPositional(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs([
|
||||
'--timeout=123', 'exec', '--foo1', 'bar1', '--timeout', '124',
|
||||
'--foo2', 'bar2'
|
||||
])
|
||||
expected = self._Expected([
|
||||
'--timeout=124', 'exec', '--', '--foo1', 'bar1', '--foo2', 'bar2'
|
||||
])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testMixingPositional(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs([
|
||||
'--timeout=123', 'exec', '--foo1', 'bar1', '--timeout', '124', '--foo2',
|
||||
'bar2'
|
||||
])
|
||||
expected = self._Expected(
|
||||
['--timeout=124', 'exec', '--', '--foo1', 'bar1', '--foo2', 'bar2'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testDoubleDash1(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout', '123', 'exec', '--', '--timeout', '124'])
|
||||
expected = self._Expected(
|
||||
['--timeout=123', 'exec', '--', '--timeout', '124'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testDoubleDash1(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout', '123', 'exec', '--', '--timeout', '124'])
|
||||
expected = self._Expected(
|
||||
['--timeout=123', 'exec', '--', '--timeout', '124'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testDoubleDash2(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout=123', '--', 'exec', '--timeout=124'])
|
||||
expected = self._Expected(
|
||||
['--timeout=123', 'exec', '--', '--timeout=124'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testDoubleDash2(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--timeout=123', '--', 'exec', '--timeout=124'])
|
||||
expected = self._Expected(['--timeout=123', 'exec', '--', '--timeout=124'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testArtifacts(self):
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
output_dir = os.path.join(tmp_dir, 'foo')
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['exec', '--store-test-artifacts', '--output_dir', output_dir])
|
||||
exp_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
|
||||
exp = self._Expected([
|
||||
'--output_dir=' + output_dir, 'exec', '--',
|
||||
'--test_artifacts_dir=' + exp_artifacts_dir
|
||||
])
|
||||
self.assertEqual(result.gtest_parallel_args, exp)
|
||||
self.assertEqual(result.output_dir, output_dir)
|
||||
self.assertEqual(result.test_artifacts_dir, exp_artifacts_dir)
|
||||
|
||||
def testArtifacts(self):
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
output_dir = os.path.join(tmp_dir, 'foo')
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['exec', '--store-test-artifacts', '--output_dir', output_dir])
|
||||
exp_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
|
||||
exp = self._Expected([
|
||||
'--output_dir=' + output_dir, 'exec', '--',
|
||||
'--test_artifacts_dir=' + exp_artifacts_dir
|
||||
])
|
||||
self.assertEqual(result.gtest_parallel_args, exp)
|
||||
self.assertEqual(result.output_dir, output_dir)
|
||||
self.assertEqual(result.test_artifacts_dir, exp_artifacts_dir)
|
||||
def testNoDirsSpecified(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['exec'])
|
||||
self.assertEqual(result.output_dir, None)
|
||||
self.assertEqual(result.test_artifacts_dir, None)
|
||||
|
||||
def testNoDirsSpecified(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['exec'])
|
||||
self.assertEqual(result.output_dir, None)
|
||||
self.assertEqual(result.test_artifacts_dir, None)
|
||||
def testOutputDirSpecified(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['exec', '--output_dir', '/tmp/foo'])
|
||||
self.assertEqual(result.output_dir, '/tmp/foo')
|
||||
self.assertEqual(result.test_artifacts_dir, None)
|
||||
|
||||
def testOutputDirSpecified(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['exec', '--output_dir', '/tmp/foo'])
|
||||
self.assertEqual(result.output_dir, '/tmp/foo')
|
||||
self.assertEqual(result.test_artifacts_dir, None)
|
||||
def testShortArg(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['-d', '/tmp/foo', 'exec'])
|
||||
expected = self._Expected(['--output_dir=/tmp/foo', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
self.assertEqual(result.output_dir, '/tmp/foo')
|
||||
|
||||
def testShortArg(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['-d', '/tmp/foo', 'exec'])
|
||||
expected = self._Expected(['--output_dir=/tmp/foo', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
self.assertEqual(result.output_dir, '/tmp/foo')
|
||||
def testBoolArg(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--gtest_also_run_disabled_tests', 'exec'])
|
||||
expected = self._Expected(['--gtest_also_run_disabled_tests', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testBoolArg(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--gtest_also_run_disabled_tests', 'exec'])
|
||||
expected = self._Expected(['--gtest_also_run_disabled_tests', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testNoArgs(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['exec'])
|
||||
expected = self._Expected(['exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testNoArgs(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['exec'])
|
||||
expected = self._Expected(['exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testDocExample(self):
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
output_dir = os.path.join(tmp_dir, 'foo')
|
||||
result = gtest_parallel_wrapper.ParseArgs([
|
||||
'some_test', '--some_flag=some_value', '--another_flag',
|
||||
'--output_dir=' + output_dir, '--store-test-artifacts',
|
||||
'--isolated-script-test-perf-output=SOME_OTHER_DIR',
|
||||
'--foo=bar', '--baz'
|
||||
])
|
||||
expected_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
|
||||
expected = self._Expected([
|
||||
'--output_dir=' + output_dir, 'some_test', '--',
|
||||
'--test_artifacts_dir=' + expected_artifacts_dir,
|
||||
'--some_flag=some_value', '--another_flag',
|
||||
'--isolated_script_test_perf_output=SOME_OTHER_DIR',
|
||||
'--foo=bar', '--baz'
|
||||
])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testDocExample(self):
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
output_dir = os.path.join(tmp_dir, 'foo')
|
||||
result = gtest_parallel_wrapper.ParseArgs([
|
||||
'some_test', '--some_flag=some_value', '--another_flag',
|
||||
'--output_dir=' + output_dir, '--store-test-artifacts',
|
||||
'--isolated-script-test-perf-output=SOME_OTHER_DIR', '--foo=bar',
|
||||
'--baz'
|
||||
])
|
||||
expected_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
|
||||
expected = self._Expected([
|
||||
'--output_dir=' + output_dir,
|
||||
'some_test', '--', '--test_artifacts_dir=' + expected_artifacts_dir,
|
||||
'--some_flag=some_value', '--another_flag',
|
||||
'--isolated_script_test_perf_output=SOME_OTHER_DIR', '--foo=bar',
|
||||
'--baz'
|
||||
])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testStandardWorkers(self):
|
||||
"""Check integer value is passed as-is."""
|
||||
result = gtest_parallel_wrapper.ParseArgs(['--workers', '17', 'exec'])
|
||||
expected = self._Expected(['--workers=17', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testStandardWorkers(self):
|
||||
"""Check integer value is passed as-is."""
|
||||
result = gtest_parallel_wrapper.ParseArgs(['--workers', '17', 'exec'])
|
||||
expected = self._Expected(['--workers=17', 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testTwoWorkersPerCpuCore(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['--workers', '2x', 'exec'])
|
||||
workers = 2 * multiprocessing.cpu_count()
|
||||
expected = self._Expected(['--workers=%s' % workers, 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testTwoWorkersPerCpuCore(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['--workers', '2x', 'exec'])
|
||||
workers = 2 * multiprocessing.cpu_count()
|
||||
expected = self._Expected(['--workers=%s' % workers, 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
def testUseHalfTheCpuCores(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(['--workers', '0.5x', 'exec'])
|
||||
workers = max(multiprocessing.cpu_count() // 2, 1)
|
||||
expected = self._Expected(['--workers=%s' % workers, 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
def testUseHalfTheCpuCores(self):
|
||||
result = gtest_parallel_wrapper.ParseArgs(
|
||||
['--workers', '0.5x', 'exec'])
|
||||
workers = max(multiprocessing.cpu_count() // 2, 1)
|
||||
expected = self._Expected(['--workers=%s' % workers, 'exec'])
|
||||
self.assertEqual(result.gtest_parallel_args, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""WebRTC iOS FAT libraries build script.
|
||||
Each architecture is compiled separately before being merged together.
|
||||
By default, the library is created in out_ios_libs/. (Change with -o.)
|
||||
@ -21,7 +20,6 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
os.environ['PATH'] = '/usr/libexec' + os.pathsep + os.environ['PATH']
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@ -41,198 +39,235 @@ from generate_licenses import LicenseBuilder
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument('--build_config', default='release',
|
||||
choices=['debug', 'release'],
|
||||
help='The build config. Can be "debug" or "release". '
|
||||
'Defaults to "release".')
|
||||
parser.add_argument('--arch', nargs='+', default=DEFAULT_ARCHS,
|
||||
choices=ENABLED_ARCHS,
|
||||
help='Architectures to build. Defaults to %(default)s.')
|
||||
parser.add_argument('-c', '--clean', action='store_true', default=False,
|
||||
help='Removes the previously generated build output, if any.')
|
||||
parser.add_argument('-p', '--purify', action='store_true', default=False,
|
||||
help='Purifies the previously generated build output by '
|
||||
'removing the temporary results used when (re)building.')
|
||||
parser.add_argument('-o', '--output-dir', default=SDK_OUTPUT_DIR,
|
||||
help='Specifies a directory to output the build artifacts to. '
|
||||
'If specified together with -c, deletes the dir.')
|
||||
parser.add_argument('-r', '--revision', type=int, default=0,
|
||||
help='Specifies a revision number to embed if building the framework.')
|
||||
parser.add_argument('-e', '--bitcode', action='store_true', default=False,
|
||||
help='Compile with bitcode.')
|
||||
parser.add_argument('--verbose', action='store_true', default=False,
|
||||
help='Debug logging.')
|
||||
parser.add_argument('--use-goma', action='store_true', default=False,
|
||||
help='Use goma to build.')
|
||||
parser.add_argument('--extra-gn-args', default=[], nargs='*',
|
||||
help='Additional GN args to be used during Ninja generation.')
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument('--build_config',
|
||||
default='release',
|
||||
choices=['debug', 'release'],
|
||||
help='The build config. Can be "debug" or "release". '
|
||||
'Defaults to "release".')
|
||||
parser.add_argument(
|
||||
'--arch',
|
||||
nargs='+',
|
||||
default=DEFAULT_ARCHS,
|
||||
choices=ENABLED_ARCHS,
|
||||
help='Architectures to build. Defaults to %(default)s.')
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--clean',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Removes the previously generated build output, if any.')
|
||||
parser.add_argument(
|
||||
'-p',
|
||||
'--purify',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Purifies the previously generated build output by '
|
||||
'removing the temporary results used when (re)building.')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--output-dir',
|
||||
default=SDK_OUTPUT_DIR,
|
||||
help='Specifies a directory to output the build artifacts to. '
|
||||
'If specified together with -c, deletes the dir.')
|
||||
parser.add_argument(
|
||||
'-r',
|
||||
'--revision',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Specifies a revision number to embed if building the framework.')
|
||||
parser.add_argument('-e',
|
||||
'--bitcode',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Compile with bitcode.')
|
||||
parser.add_argument('--verbose',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Debug logging.')
|
||||
parser.add_argument('--use-goma',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use goma to build.')
|
||||
parser.add_argument(
|
||||
'--extra-gn-args',
|
||||
default=[],
|
||||
nargs='*',
|
||||
help='Additional GN args to be used during Ninja generation.')
|
||||
|
||||
return parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _RunCommand(cmd):
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd, cwd=SRC_DIR)
|
||||
logging.debug('Running: %r', cmd)
|
||||
subprocess.check_call(cmd, cwd=SRC_DIR)
|
||||
|
||||
|
||||
def _CleanArtifacts(output_dir):
|
||||
if os.path.isdir(output_dir):
|
||||
logging.info('Deleting %s', output_dir)
|
||||
shutil.rmtree(output_dir)
|
||||
if os.path.isdir(output_dir):
|
||||
logging.info('Deleting %s', output_dir)
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
def _CleanTemporary(output_dir, architectures):
|
||||
if os.path.isdir(output_dir):
|
||||
logging.info('Removing temporary build files.')
|
||||
for arch in architectures:
|
||||
arch_lib_path = os.path.join(output_dir, arch + '_libs')
|
||||
if os.path.isdir(arch_lib_path):
|
||||
shutil.rmtree(arch_lib_path)
|
||||
if os.path.isdir(output_dir):
|
||||
logging.info('Removing temporary build files.')
|
||||
for arch in architectures:
|
||||
arch_lib_path = os.path.join(output_dir, arch + '_libs')
|
||||
if os.path.isdir(arch_lib_path):
|
||||
shutil.rmtree(arch_lib_path)
|
||||
|
||||
|
||||
def BuildWebRTC(output_dir, target_arch, flavor, gn_target_name,
|
||||
ios_deployment_target, libvpx_build_vp9, use_bitcode,
|
||||
use_goma, extra_gn_args):
|
||||
output_dir = os.path.join(output_dir, target_arch + '_libs')
|
||||
gn_args = ['target_os="ios"', 'ios_enable_code_signing=false',
|
||||
'use_xcode_clang=true', 'is_component_build=false']
|
||||
ios_deployment_target, libvpx_build_vp9, use_bitcode, use_goma,
|
||||
extra_gn_args):
|
||||
output_dir = os.path.join(output_dir, target_arch + '_libs')
|
||||
gn_args = [
|
||||
'target_os="ios"', 'ios_enable_code_signing=false',
|
||||
'use_xcode_clang=true', 'is_component_build=false'
|
||||
]
|
||||
|
||||
# Add flavor option.
|
||||
if flavor == 'debug':
|
||||
gn_args.append('is_debug=true')
|
||||
elif flavor == 'release':
|
||||
gn_args.append('is_debug=false')
|
||||
else:
|
||||
raise ValueError('Unexpected flavor type: %s' % flavor)
|
||||
# Add flavor option.
|
||||
if flavor == 'debug':
|
||||
gn_args.append('is_debug=true')
|
||||
elif flavor == 'release':
|
||||
gn_args.append('is_debug=false')
|
||||
else:
|
||||
raise ValueError('Unexpected flavor type: %s' % flavor)
|
||||
|
||||
gn_args.append('target_cpu="%s"' % target_arch)
|
||||
gn_args.append('target_cpu="%s"' % target_arch)
|
||||
|
||||
gn_args.append('ios_deployment_target="%s"' % ios_deployment_target)
|
||||
gn_args.append('ios_deployment_target="%s"' % ios_deployment_target)
|
||||
|
||||
gn_args.append('rtc_libvpx_build_vp9=' +
|
||||
('true' if libvpx_build_vp9 else 'false'))
|
||||
gn_args.append('rtc_libvpx_build_vp9=' +
|
||||
('true' if libvpx_build_vp9 else 'false'))
|
||||
|
||||
gn_args.append('enable_ios_bitcode=' +
|
||||
('true' if use_bitcode else 'false'))
|
||||
gn_args.append('use_goma=' + ('true' if use_goma else 'false'))
|
||||
gn_args.append('enable_ios_bitcode=' +
|
||||
('true' if use_bitcode else 'false'))
|
||||
gn_args.append('use_goma=' + ('true' if use_goma else 'false'))
|
||||
|
||||
args_string = ' '.join(gn_args + extra_gn_args)
|
||||
logging.info('Building WebRTC with args: %s', args_string)
|
||||
args_string = ' '.join(gn_args + extra_gn_args)
|
||||
logging.info('Building WebRTC with args: %s', args_string)
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
|
||||
'gen',
|
||||
output_dir,
|
||||
'--args=' + args_string,
|
||||
]
|
||||
_RunCommand(cmd)
|
||||
logging.info('Building target: %s', gn_target_name)
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
|
||||
'gen',
|
||||
output_dir,
|
||||
'--args=' + args_string,
|
||||
]
|
||||
_RunCommand(cmd)
|
||||
logging.info('Building target: %s', gn_target_name)
|
||||
|
||||
cmd = [
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
|
||||
'-C',
|
||||
output_dir,
|
||||
gn_target_name,
|
||||
]
|
||||
if use_goma:
|
||||
cmd.extend(['-j', '200'])
|
||||
_RunCommand(cmd)
|
||||
|
||||
cmd = [
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
|
||||
'-C',
|
||||
output_dir,
|
||||
gn_target_name,
|
||||
]
|
||||
if use_goma:
|
||||
cmd.extend(['-j', '200'])
|
||||
_RunCommand(cmd)
|
||||
|
||||
def main():
|
||||
args = _ParseArgs()
|
||||
args = _ParseArgs()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
if args.clean:
|
||||
_CleanArtifacts(args.output_dir)
|
||||
return 0
|
||||
if args.clean:
|
||||
_CleanArtifacts(args.output_dir)
|
||||
return 0
|
||||
|
||||
architectures = list(args.arch)
|
||||
gn_args = args.extra_gn_args
|
||||
architectures = list(args.arch)
|
||||
gn_args = args.extra_gn_args
|
||||
|
||||
if args.purify:
|
||||
_CleanTemporary(args.output_dir, architectures)
|
||||
return 0
|
||||
if args.purify:
|
||||
_CleanTemporary(args.output_dir, architectures)
|
||||
return 0
|
||||
|
||||
gn_target_name = 'framework_objc'
|
||||
if not args.bitcode:
|
||||
gn_args.append('enable_dsyms=true')
|
||||
gn_args.append('enable_stripping=true')
|
||||
gn_target_name = 'framework_objc'
|
||||
if not args.bitcode:
|
||||
gn_args.append('enable_dsyms=true')
|
||||
gn_args.append('enable_stripping=true')
|
||||
|
||||
# Build all architectures.
|
||||
for arch in architectures:
|
||||
BuildWebRTC(args.output_dir, arch, args.build_config, gn_target_name,
|
||||
IOS_DEPLOYMENT_TARGET, LIBVPX_BUILD_VP9, args.bitcode,
|
||||
args.use_goma, gn_args)
|
||||
|
||||
# Build all architectures.
|
||||
for arch in architectures:
|
||||
BuildWebRTC(args.output_dir, arch, args.build_config, gn_target_name,
|
||||
IOS_DEPLOYMENT_TARGET, LIBVPX_BUILD_VP9, args.bitcode,
|
||||
args.use_goma, gn_args)
|
||||
# Create FAT archive.
|
||||
lib_paths = [
|
||||
os.path.join(args.output_dir, arch + '_libs') for arch in architectures
|
||||
]
|
||||
|
||||
# Create FAT archive.
|
||||
lib_paths = [os.path.join(args.output_dir, arch + '_libs')
|
||||
for arch in architectures]
|
||||
|
||||
# Combine the slices.
|
||||
dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC')
|
||||
# Dylibs will be combined, all other files are the same across archs.
|
||||
# Use distutils instead of shutil to support merging folders.
|
||||
distutils.dir_util.copy_tree(
|
||||
os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME),
|
||||
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
|
||||
logging.info('Merging framework slices.')
|
||||
dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths]
|
||||
out_dylib_path = os.path.join(args.output_dir, dylib_path)
|
||||
try:
|
||||
os.remove(out_dylib_path)
|
||||
except OSError:
|
||||
pass
|
||||
cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path]
|
||||
_RunCommand(cmd)
|
||||
|
||||
# Merge the dSYM slices.
|
||||
lib_dsym_dir_path = os.path.join(lib_paths[0], 'WebRTC.dSYM')
|
||||
if os.path.isdir(lib_dsym_dir_path):
|
||||
distutils.dir_util.copy_tree(lib_dsym_dir_path,
|
||||
os.path.join(args.output_dir, 'WebRTC.dSYM'))
|
||||
logging.info('Merging dSYM slices.')
|
||||
dsym_path = os.path.join('WebRTC.dSYM', 'Contents', 'Resources', 'DWARF',
|
||||
'WebRTC')
|
||||
lib_dsym_paths = [os.path.join(path, dsym_path) for path in lib_paths]
|
||||
out_dsym_path = os.path.join(args.output_dir, dsym_path)
|
||||
try:
|
||||
os.remove(out_dsym_path)
|
||||
except OSError:
|
||||
pass
|
||||
cmd = ['lipo'] + lib_dsym_paths + ['-create', '-output', out_dsym_path]
|
||||
_RunCommand(cmd)
|
||||
|
||||
# Generate the license file.
|
||||
ninja_dirs = [os.path.join(args.output_dir, arch + '_libs')
|
||||
for arch in architectures]
|
||||
gn_target_full_name = '//sdk:' + gn_target_name
|
||||
builder = LicenseBuilder(ninja_dirs, [gn_target_full_name])
|
||||
builder.GenerateLicenseText(
|
||||
# Combine the slices.
|
||||
dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC')
|
||||
# Dylibs will be combined, all other files are the same across archs.
|
||||
# Use distutils instead of shutil to support merging folders.
|
||||
distutils.dir_util.copy_tree(
|
||||
os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME),
|
||||
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
|
||||
|
||||
|
||||
# Modify the version number.
|
||||
# Format should be <Branch cut MXX>.<Hotfix #>.<Rev #>.
|
||||
# e.g. 55.0.14986 means branch cut 55, no hotfixes, and revision 14986.
|
||||
infoplist_path = os.path.join(args.output_dir, SDK_FRAMEWORK_NAME,
|
||||
'Info.plist')
|
||||
cmd = ['PlistBuddy', '-c',
|
||||
'Print :CFBundleShortVersionString', infoplist_path]
|
||||
major_minor = subprocess.check_output(cmd).strip()
|
||||
version_number = '%s.%s' % (major_minor, args.revision)
|
||||
logging.info('Substituting revision number: %s', version_number)
|
||||
cmd = ['PlistBuddy', '-c',
|
||||
'Set :CFBundleVersion ' + version_number, infoplist_path]
|
||||
logging.info('Merging framework slices.')
|
||||
dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths]
|
||||
out_dylib_path = os.path.join(args.output_dir, dylib_path)
|
||||
try:
|
||||
os.remove(out_dylib_path)
|
||||
except OSError:
|
||||
pass
|
||||
cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path]
|
||||
_RunCommand(cmd)
|
||||
_RunCommand(['plutil', '-convert', 'binary1', infoplist_path])
|
||||
|
||||
logging.info('Done.')
|
||||
return 0
|
||||
# Merge the dSYM slices.
|
||||
lib_dsym_dir_path = os.path.join(lib_paths[0], 'WebRTC.dSYM')
|
||||
if os.path.isdir(lib_dsym_dir_path):
|
||||
distutils.dir_util.copy_tree(
|
||||
lib_dsym_dir_path, os.path.join(args.output_dir, 'WebRTC.dSYM'))
|
||||
logging.info('Merging dSYM slices.')
|
||||
dsym_path = os.path.join('WebRTC.dSYM', 'Contents', 'Resources',
|
||||
'DWARF', 'WebRTC')
|
||||
lib_dsym_paths = [os.path.join(path, dsym_path) for path in lib_paths]
|
||||
out_dsym_path = os.path.join(args.output_dir, dsym_path)
|
||||
try:
|
||||
os.remove(out_dsym_path)
|
||||
except OSError:
|
||||
pass
|
||||
cmd = ['lipo'] + lib_dsym_paths + ['-create', '-output', out_dsym_path]
|
||||
_RunCommand(cmd)
|
||||
|
||||
# Generate the license file.
|
||||
ninja_dirs = [
|
||||
os.path.join(args.output_dir, arch + '_libs')
|
||||
for arch in architectures
|
||||
]
|
||||
gn_target_full_name = '//sdk:' + gn_target_name
|
||||
builder = LicenseBuilder(ninja_dirs, [gn_target_full_name])
|
||||
builder.GenerateLicenseText(
|
||||
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
|
||||
|
||||
# Modify the version number.
|
||||
# Format should be <Branch cut MXX>.<Hotfix #>.<Rev #>.
|
||||
# e.g. 55.0.14986 means branch cut 55, no hotfixes, and revision 14986.
|
||||
infoplist_path = os.path.join(args.output_dir, SDK_FRAMEWORK_NAME,
|
||||
'Info.plist')
|
||||
cmd = [
|
||||
'PlistBuddy', '-c', 'Print :CFBundleShortVersionString',
|
||||
infoplist_path
|
||||
]
|
||||
major_minor = subprocess.check_output(cmd).strip()
|
||||
version_number = '%s.%s' % (major_minor, args.revision)
|
||||
logging.info('Substituting revision number: %s', version_number)
|
||||
cmd = [
|
||||
'PlistBuddy', '-c', 'Set :CFBundleVersion ' + version_number,
|
||||
infoplist_path
|
||||
]
|
||||
_RunCommand(cmd)
|
||||
_RunCommand(['plutil', '-convert', 'binary1', infoplist_path])
|
||||
|
||||
logging.info('Done.')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -9,24 +9,24 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
def GenerateModulemap():
|
||||
parser = argparse.ArgumentParser(description='Generate modulemap')
|
||||
parser.add_argument("-o", "--out", type=str, help="Output file.")
|
||||
parser.add_argument("-n", "--name", type=str, help="Name of binary.")
|
||||
parser = argparse.ArgumentParser(description='Generate modulemap')
|
||||
parser.add_argument("-o", "--out", type=str, help="Output file.")
|
||||
parser.add_argument("-n", "--name", type=str, help="Name of binary.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.out, "w") as outfile:
|
||||
module_template = 'framework module %s {\n' \
|
||||
' umbrella header "%s.h"\n' \
|
||||
'\n' \
|
||||
' export *\n' \
|
||||
' module * { export * }\n' \
|
||||
'}\n' % (args.name, args.name)
|
||||
outfile.write(module_template)
|
||||
return 0
|
||||
with open(args.out, "w") as outfile:
|
||||
module_template = 'framework module %s {\n' \
|
||||
' umbrella header "%s.h"\n' \
|
||||
'\n' \
|
||||
' export *\n' \
|
||||
' module * { export * }\n' \
|
||||
'}\n' % (args.name, args.name)
|
||||
outfile.write(module_template)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(GenerateModulemap())
|
||||
|
||||
sys.exit(GenerateModulemap())
|
||||
|
||||
@ -14,15 +14,20 @@ import textwrap
|
||||
|
||||
|
||||
def GenerateUmbrellaHeader():
|
||||
parser = argparse.ArgumentParser(description='Generate umbrella header')
|
||||
parser.add_argument("-o", "--out", type=str, help="Output file.")
|
||||
parser.add_argument("-s", "--sources", default=[], type=str, nargs='+',
|
||||
help="Headers to include.")
|
||||
parser = argparse.ArgumentParser(description='Generate umbrella header')
|
||||
parser.add_argument("-o", "--out", type=str, help="Output file.")
|
||||
parser.add_argument("-s",
|
||||
"--sources",
|
||||
default=[],
|
||||
type=str,
|
||||
nargs='+',
|
||||
help="Headers to include.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.out, "w") as outfile:
|
||||
outfile.write(textwrap.dedent("""\
|
||||
with open(args.out, "w") as outfile:
|
||||
outfile.write(
|
||||
textwrap.dedent("""\
|
||||
/*
|
||||
* Copyright %d The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
@ -33,11 +38,11 @@ def GenerateUmbrellaHeader():
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/\n\n""" % datetime.datetime.now().year))
|
||||
|
||||
for s in args.sources:
|
||||
outfile.write("#import <WebRTC/{}>\n".format(os.path.basename(s)))
|
||||
for s in args.sources:
|
||||
outfile.write("#import <WebRTC/{}>\n".format(os.path.basename(s)))
|
||||
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(GenerateUmbrellaHeader())
|
||||
sys.exit(GenerateUmbrellaHeader())
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Script for merging generated iOS libraries."""
|
||||
|
||||
import sys
|
||||
@ -22,7 +21,7 @@ VALID_ARCHS = ['arm_libs', 'arm64_libs', 'ia32_libs', 'x64_libs']
|
||||
|
||||
|
||||
def MergeLibs(lib_base_dir):
|
||||
"""Merges generated iOS libraries for different archs.
|
||||
"""Merges generated iOS libraries for different archs.
|
||||
|
||||
Uses libtool to generate FAT archive files for each generated library.
|
||||
|
||||
@ -33,92 +32,96 @@ def MergeLibs(lib_base_dir):
|
||||
Returns:
|
||||
Exit code of libtool.
|
||||
"""
|
||||
output_dir_name = 'fat_libs'
|
||||
archs = [arch for arch in os.listdir(lib_base_dir)
|
||||
if arch in VALID_ARCHS]
|
||||
# For each arch, find (library name, libary path) for arch. We will merge
|
||||
# all libraries with the same name.
|
||||
libs = {}
|
||||
for lib_dir in [os.path.join(lib_base_dir, arch) for arch in VALID_ARCHS]:
|
||||
if not os.path.exists(lib_dir):
|
||||
continue
|
||||
for dirpath, _, filenames in os.walk(lib_dir):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.a'):
|
||||
continue
|
||||
entry = libs.get(filename, [])
|
||||
entry.append(os.path.join(dirpath, filename))
|
||||
libs[filename] = entry
|
||||
orphaned_libs = {}
|
||||
valid_libs = {}
|
||||
for library, paths in libs.items():
|
||||
if len(paths) < len(archs):
|
||||
orphaned_libs[library] = paths
|
||||
else:
|
||||
valid_libs[library] = paths
|
||||
for library, paths in orphaned_libs.items():
|
||||
components = library[:-2].split('_')[:-1]
|
||||
found = False
|
||||
# Find directly matching parent libs by stripping suffix.
|
||||
while components and not found:
|
||||
parent_library = '_'.join(components) + '.a'
|
||||
if parent_library in valid_libs:
|
||||
valid_libs[parent_library].extend(paths)
|
||||
found = True
|
||||
break
|
||||
components = components[:-1]
|
||||
# Find next best match by finding parent libs with the same prefix.
|
||||
if not found:
|
||||
base_prefix = library[:-2].split('_')[0]
|
||||
for valid_lib, valid_paths in valid_libs.items():
|
||||
if valid_lib[:len(base_prefix)] == base_prefix:
|
||||
valid_paths.extend(paths)
|
||||
found = True
|
||||
break
|
||||
assert found
|
||||
output_dir_name = 'fat_libs'
|
||||
archs = [arch for arch in os.listdir(lib_base_dir) if arch in VALID_ARCHS]
|
||||
# For each arch, find (library name, libary path) for arch. We will merge
|
||||
# all libraries with the same name.
|
||||
libs = {}
|
||||
for lib_dir in [os.path.join(lib_base_dir, arch) for arch in VALID_ARCHS]:
|
||||
if not os.path.exists(lib_dir):
|
||||
continue
|
||||
for dirpath, _, filenames in os.walk(lib_dir):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.a'):
|
||||
continue
|
||||
entry = libs.get(filename, [])
|
||||
entry.append(os.path.join(dirpath, filename))
|
||||
libs[filename] = entry
|
||||
orphaned_libs = {}
|
||||
valid_libs = {}
|
||||
for library, paths in libs.items():
|
||||
if len(paths) < len(archs):
|
||||
orphaned_libs[library] = paths
|
||||
else:
|
||||
valid_libs[library] = paths
|
||||
for library, paths in orphaned_libs.items():
|
||||
components = library[:-2].split('_')[:-1]
|
||||
found = False
|
||||
# Find directly matching parent libs by stripping suffix.
|
||||
while components and not found:
|
||||
parent_library = '_'.join(components) + '.a'
|
||||
if parent_library in valid_libs:
|
||||
valid_libs[parent_library].extend(paths)
|
||||
found = True
|
||||
break
|
||||
components = components[:-1]
|
||||
# Find next best match by finding parent libs with the same prefix.
|
||||
if not found:
|
||||
base_prefix = library[:-2].split('_')[0]
|
||||
for valid_lib, valid_paths in valid_libs.items():
|
||||
if valid_lib[:len(base_prefix)] == base_prefix:
|
||||
valid_paths.extend(paths)
|
||||
found = True
|
||||
break
|
||||
assert found
|
||||
|
||||
# Create output directory.
|
||||
output_dir_path = os.path.join(lib_base_dir, output_dir_name)
|
||||
if not os.path.exists(output_dir_path):
|
||||
os.mkdir(output_dir_path)
|
||||
# Create output directory.
|
||||
output_dir_path = os.path.join(lib_base_dir, output_dir_name)
|
||||
if not os.path.exists(output_dir_path):
|
||||
os.mkdir(output_dir_path)
|
||||
|
||||
# Use this so libtool merged binaries are always the same.
|
||||
env = os.environ.copy()
|
||||
env['ZERO_AR_DATE'] = '1'
|
||||
# Use this so libtool merged binaries are always the same.
|
||||
env = os.environ.copy()
|
||||
env['ZERO_AR_DATE'] = '1'
|
||||
|
||||
# Ignore certain errors.
|
||||
libtool_re = re.compile(r'^.*libtool:.*file: .* has no symbols$')
|
||||
# Ignore certain errors.
|
||||
libtool_re = re.compile(r'^.*libtool:.*file: .* has no symbols$')
|
||||
|
||||
# Merge libraries using libtool.
|
||||
libtool_returncode = 0
|
||||
for library, paths in valid_libs.items():
|
||||
cmd_list = ['libtool', '-static', '-v', '-o',
|
||||
os.path.join(output_dir_path, library)] + paths
|
||||
libtoolout = subprocess.Popen(cmd_list, stderr=subprocess.PIPE, env=env)
|
||||
_, err = libtoolout.communicate()
|
||||
for line in err.splitlines():
|
||||
if not libtool_re.match(line):
|
||||
print >>sys.stderr, line
|
||||
# Unconditionally touch the output .a file on the command line if present
|
||||
# and the command succeeded. A bit hacky.
|
||||
libtool_returncode = libtoolout.returncode
|
||||
if not libtool_returncode:
|
||||
for i in range(len(cmd_list) - 1):
|
||||
if cmd_list[i] == '-o' and cmd_list[i+1].endswith('.a'):
|
||||
os.utime(cmd_list[i+1], None)
|
||||
break
|
||||
return libtool_returncode
|
||||
# Merge libraries using libtool.
|
||||
libtool_returncode = 0
|
||||
for library, paths in valid_libs.items():
|
||||
cmd_list = [
|
||||
'libtool', '-static', '-v', '-o',
|
||||
os.path.join(output_dir_path, library)
|
||||
] + paths
|
||||
libtoolout = subprocess.Popen(cmd_list,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env)
|
||||
_, err = libtoolout.communicate()
|
||||
for line in err.splitlines():
|
||||
if not libtool_re.match(line):
|
||||
print >> sys.stderr, line
|
||||
# Unconditionally touch the output .a file on the command line if present
|
||||
# and the command succeeded. A bit hacky.
|
||||
libtool_returncode = libtoolout.returncode
|
||||
if not libtool_returncode:
|
||||
for i in range(len(cmd_list) - 1):
|
||||
if cmd_list[i] == '-o' and cmd_list[i + 1].endswith('.a'):
|
||||
os.utime(cmd_list[i + 1], None)
|
||||
break
|
||||
return libtool_returncode
|
||||
|
||||
|
||||
def Main():
|
||||
parser_description = 'Merge WebRTC libraries.'
|
||||
parser = argparse.ArgumentParser(description=parser_description)
|
||||
parser.add_argument('lib_base_dir',
|
||||
help='Directory with built libraries. ',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
lib_base_dir = args.lib_base_dir
|
||||
MergeLibs(lib_base_dir)
|
||||
parser_description = 'Merge WebRTC libraries.'
|
||||
parser = argparse.ArgumentParser(description=parser_description)
|
||||
parser.add_argument('lib_base_dir',
|
||||
help='Directory with built libraries. ',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
lib_base_dir = args.lib_base_dir
|
||||
MergeLibs(lib_base_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(Main())
|
||||
sys.exit(Main())
|
||||
|
||||
@ -36,12 +36,16 @@ LIB_TO_LICENSES_DICT = {
|
||||
'abseil-cpp': ['third_party/abseil-cpp/LICENSE'],
|
||||
'android_ndk': ['third_party/android_ndk/NOTICE'],
|
||||
'android_sdk': ['third_party/android_sdk/LICENSE'],
|
||||
'auto': ['third_party/android_deps/libs/'
|
||||
'com_google_auto_service_auto_service/LICENSE'],
|
||||
'auto': [
|
||||
'third_party/android_deps/libs/'
|
||||
'com_google_auto_service_auto_service/LICENSE'
|
||||
],
|
||||
'bazel': ['third_party/bazel/LICENSE'],
|
||||
'boringssl': ['third_party/boringssl/src/LICENSE'],
|
||||
'errorprone': ['third_party/android_deps/libs/'
|
||||
'com_google_errorprone_error_prone_core/LICENSE'],
|
||||
'errorprone': [
|
||||
'third_party/android_deps/libs/'
|
||||
'com_google_errorprone_error_prone_core/LICENSE'
|
||||
],
|
||||
'fiat': ['third_party/boringssl/src/third_party/fiat/LICENSE'],
|
||||
'guava': ['third_party/guava/LICENSE'],
|
||||
'ijar': ['third_party/ijar/LICENSE'],
|
||||
@ -95,11 +99,11 @@ LIB_REGEX_TO_LICENSES_DICT = {
|
||||
|
||||
|
||||
def FindSrcDirPath():
|
||||
"""Returns the abs path to the src/ dir of the project."""
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
while os.path.basename(src_dir) != 'src':
|
||||
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
|
||||
return src_dir
|
||||
"""Returns the abs path to the src/ dir of the project."""
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
while os.path.basename(src_dir) != 'src':
|
||||
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
|
||||
return src_dir
|
||||
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
@ -113,29 +117,28 @@ THIRD_PARTY_LIB_REGEX_TEMPLATE = r'^.*/third_party/%s$'
|
||||
|
||||
|
||||
class LicenseBuilder(object):
|
||||
def __init__(self,
|
||||
buildfile_dirs,
|
||||
targets,
|
||||
lib_to_licenses_dict=None,
|
||||
lib_regex_to_licenses_dict=None):
|
||||
if lib_to_licenses_dict is None:
|
||||
lib_to_licenses_dict = LIB_TO_LICENSES_DICT
|
||||
|
||||
def __init__(self,
|
||||
buildfile_dirs,
|
||||
targets,
|
||||
lib_to_licenses_dict=None,
|
||||
lib_regex_to_licenses_dict=None):
|
||||
if lib_to_licenses_dict is None:
|
||||
lib_to_licenses_dict = LIB_TO_LICENSES_DICT
|
||||
if lib_regex_to_licenses_dict is None:
|
||||
lib_regex_to_licenses_dict = LIB_REGEX_TO_LICENSES_DICT
|
||||
|
||||
if lib_regex_to_licenses_dict is None:
|
||||
lib_regex_to_licenses_dict = LIB_REGEX_TO_LICENSES_DICT
|
||||
self.buildfile_dirs = buildfile_dirs
|
||||
self.targets = targets
|
||||
self.lib_to_licenses_dict = lib_to_licenses_dict
|
||||
self.lib_regex_to_licenses_dict = lib_regex_to_licenses_dict
|
||||
|
||||
self.buildfile_dirs = buildfile_dirs
|
||||
self.targets = targets
|
||||
self.lib_to_licenses_dict = lib_to_licenses_dict
|
||||
self.lib_regex_to_licenses_dict = lib_regex_to_licenses_dict
|
||||
self.common_licenses_dict = self.lib_to_licenses_dict.copy()
|
||||
self.common_licenses_dict.update(self.lib_regex_to_licenses_dict)
|
||||
|
||||
self.common_licenses_dict = self.lib_to_licenses_dict.copy()
|
||||
self.common_licenses_dict.update(self.lib_regex_to_licenses_dict)
|
||||
|
||||
@staticmethod
|
||||
def _ParseLibraryName(dep):
|
||||
"""Returns library name after third_party
|
||||
@staticmethod
|
||||
def _ParseLibraryName(dep):
|
||||
"""Returns library name after third_party
|
||||
|
||||
Input one of:
|
||||
//a/b/third_party/libname:c
|
||||
@ -144,11 +147,11 @@ class LicenseBuilder(object):
|
||||
|
||||
Outputs libname or None if this is not a third_party dependency.
|
||||
"""
|
||||
groups = re.match(THIRD_PARTY_LIB_SIMPLE_NAME_REGEX, dep)
|
||||
return groups.group(1) if groups else None
|
||||
groups = re.match(THIRD_PARTY_LIB_SIMPLE_NAME_REGEX, dep)
|
||||
return groups.group(1) if groups else None
|
||||
|
||||
def _ParseLibrary(self, dep):
|
||||
"""Returns library simple or regex name that matches `dep` after third_party
|
||||
def _ParseLibrary(self, dep):
|
||||
"""Returns library simple or regex name that matches `dep` after third_party
|
||||
|
||||
This method matches `dep` dependency against simple names in
|
||||
LIB_TO_LICENSES_DICT and regular expression names in
|
||||
@ -156,104 +159,109 @@ class LicenseBuilder(object):
|
||||
|
||||
Outputs matched dict key or None if this is not a third_party dependency.
|
||||
"""
|
||||
libname = LicenseBuilder._ParseLibraryName(dep)
|
||||
libname = LicenseBuilder._ParseLibraryName(dep)
|
||||
|
||||
for lib_regex in self.lib_regex_to_licenses_dict:
|
||||
if re.match(THIRD_PARTY_LIB_REGEX_TEMPLATE % lib_regex, dep):
|
||||
return lib_regex
|
||||
for lib_regex in self.lib_regex_to_licenses_dict:
|
||||
if re.match(THIRD_PARTY_LIB_REGEX_TEMPLATE % lib_regex, dep):
|
||||
return lib_regex
|
||||
|
||||
return libname
|
||||
return libname
|
||||
|
||||
@staticmethod
|
||||
def _RunGN(buildfile_dir, target):
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
|
||||
'desc',
|
||||
'--all',
|
||||
'--format=json',
|
||||
os.path.abspath(buildfile_dir),
|
||||
target,
|
||||
]
|
||||
logging.debug('Running: %r', cmd)
|
||||
output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT)
|
||||
logging.debug('Output: %s', output_json)
|
||||
return output_json
|
||||
@staticmethod
|
||||
def _RunGN(buildfile_dir, target):
|
||||
cmd = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
|
||||
'desc',
|
||||
'--all',
|
||||
'--format=json',
|
||||
os.path.abspath(buildfile_dir),
|
||||
target,
|
||||
]
|
||||
logging.debug('Running: %r', cmd)
|
||||
output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT)
|
||||
logging.debug('Output: %s', output_json)
|
||||
return output_json
|
||||
|
||||
def _GetThirdPartyLibraries(self, buildfile_dir, target):
|
||||
output = json.loads(LicenseBuilder._RunGN(buildfile_dir, target))
|
||||
libraries = set()
|
||||
for described_target in output.values():
|
||||
third_party_libs = (
|
||||
self._ParseLibrary(dep) for dep in described_target['deps'])
|
||||
libraries |= set(lib for lib in third_party_libs if lib)
|
||||
return libraries
|
||||
def _GetThirdPartyLibraries(self, buildfile_dir, target):
|
||||
output = json.loads(LicenseBuilder._RunGN(buildfile_dir, target))
|
||||
libraries = set()
|
||||
for described_target in output.values():
|
||||
third_party_libs = (self._ParseLibrary(dep)
|
||||
for dep in described_target['deps'])
|
||||
libraries |= set(lib for lib in third_party_libs if lib)
|
||||
return libraries
|
||||
|
||||
def GenerateLicenseText(self, output_dir):
|
||||
# Get a list of third_party libs from gn. For fat libraries we must consider
|
||||
# all architectures, hence the multiple buildfile directories.
|
||||
third_party_libs = set()
|
||||
for buildfile in self.buildfile_dirs:
|
||||
for target in self.targets:
|
||||
third_party_libs |= self._GetThirdPartyLibraries(buildfile, target)
|
||||
assert len(third_party_libs) > 0
|
||||
def GenerateLicenseText(self, output_dir):
|
||||
# Get a list of third_party libs from gn. For fat libraries we must consider
|
||||
# all architectures, hence the multiple buildfile directories.
|
||||
third_party_libs = set()
|
||||
for buildfile in self.buildfile_dirs:
|
||||
for target in self.targets:
|
||||
third_party_libs |= self._GetThirdPartyLibraries(
|
||||
buildfile, target)
|
||||
assert len(third_party_libs) > 0
|
||||
|
||||
missing_licenses = third_party_libs - set(self.common_licenses_dict.keys())
|
||||
if missing_licenses:
|
||||
error_msg = 'Missing licenses for following third_party targets: %s' % \
|
||||
', '.join(missing_licenses)
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
missing_licenses = third_party_libs - set(
|
||||
self.common_licenses_dict.keys())
|
||||
if missing_licenses:
|
||||
error_msg = 'Missing licenses for following third_party targets: %s' % \
|
||||
', '.join(missing_licenses)
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# Put webrtc at the front of the list.
|
||||
license_libs = sorted(third_party_libs)
|
||||
license_libs.insert(0, 'webrtc')
|
||||
# Put webrtc at the front of the list.
|
||||
license_libs = sorted(third_party_libs)
|
||||
license_libs.insert(0, 'webrtc')
|
||||
|
||||
logging.info('List of licenses: %s', ', '.join(license_libs))
|
||||
logging.info('List of licenses: %s', ', '.join(license_libs))
|
||||
|
||||
# Generate markdown.
|
||||
output_license_file = open(os.path.join(output_dir, 'LICENSE.md'), 'w+')
|
||||
for license_lib in license_libs:
|
||||
if len(self.common_licenses_dict[license_lib]) == 0:
|
||||
logging.info('Skipping compile time or internal dependency: %s',
|
||||
license_lib)
|
||||
continue # Compile time dependency
|
||||
# Generate markdown.
|
||||
output_license_file = open(os.path.join(output_dir, 'LICENSE.md'),
|
||||
'w+')
|
||||
for license_lib in license_libs:
|
||||
if len(self.common_licenses_dict[license_lib]) == 0:
|
||||
logging.info(
|
||||
'Skipping compile time or internal dependency: %s',
|
||||
license_lib)
|
||||
continue # Compile time dependency
|
||||
|
||||
output_license_file.write('# %s\n' % license_lib)
|
||||
output_license_file.write('```\n')
|
||||
for path in self.common_licenses_dict[license_lib]:
|
||||
license_path = os.path.join(WEBRTC_ROOT, path)
|
||||
with open(license_path, 'r') as license_file:
|
||||
license_text = cgi.escape(license_file.read(), quote=True)
|
||||
output_license_file.write(license_text)
|
||||
output_license_file.write('\n')
|
||||
output_license_file.write('```\n\n')
|
||||
output_license_file.write('# %s\n' % license_lib)
|
||||
output_license_file.write('```\n')
|
||||
for path in self.common_licenses_dict[license_lib]:
|
||||
license_path = os.path.join(WEBRTC_ROOT, path)
|
||||
with open(license_path, 'r') as license_file:
|
||||
license_text = cgi.escape(license_file.read(), quote=True)
|
||||
output_license_file.write(license_text)
|
||||
output_license_file.write('\n')
|
||||
output_license_file.write('```\n\n')
|
||||
|
||||
output_license_file.close()
|
||||
output_license_file.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Generate WebRTC LICENSE.md')
|
||||
parser.add_argument(
|
||||
'--verbose', action='store_true', default=False, help='Debug logging.')
|
||||
parser.add_argument(
|
||||
'--target',
|
||||
required=True,
|
||||
action='append',
|
||||
default=[],
|
||||
help='Name of the GN target to generate a license for')
|
||||
parser.add_argument('output_dir', help='Directory to output LICENSE.md to.')
|
||||
parser.add_argument(
|
||||
'buildfile_dirs',
|
||||
nargs='+',
|
||||
help='Directories containing gn generated ninja files')
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser(description='Generate WebRTC LICENSE.md')
|
||||
parser.add_argument('--verbose',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Debug logging.')
|
||||
parser.add_argument('--target',
|
||||
required=True,
|
||||
action='append',
|
||||
default=[],
|
||||
help='Name of the GN target to generate a license for')
|
||||
parser.add_argument('output_dir',
|
||||
help='Directory to output LICENSE.md to.')
|
||||
parser.add_argument('buildfile_dirs',
|
||||
nargs='+',
|
||||
help='Directories containing gn generated ninja files')
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
builder = LicenseBuilder(args.buildfile_dirs, args.target)
|
||||
builder.GenerateLicenseText(args.output_dir)
|
||||
builder = LicenseBuilder(args.buildfile_dirs, args.target)
|
||||
builder.GenerateLicenseText(args.output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -16,10 +16,9 @@ from generate_licenses import LicenseBuilder
|
||||
|
||||
|
||||
class TestLicenseBuilder(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def _FakeRunGN(buildfile_dir, target):
|
||||
return """
|
||||
@staticmethod
|
||||
def _FakeRunGN(buildfile_dir, target):
|
||||
return """
|
||||
{
|
||||
"target1": {
|
||||
"deps": [
|
||||
@ -32,91 +31,93 @@ class TestLicenseBuilder(unittest.TestCase):
|
||||
}
|
||||
"""
|
||||
|
||||
def testParseLibraryName(self):
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname1:c'),
|
||||
'libname1')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname2:c(d)'),
|
||||
'libname2')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname3/c:d(e)'),
|
||||
'libname3')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/not_third_party/c'), None)
|
||||
def testParseLibraryName(self):
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname1:c'),
|
||||
'libname1')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName(
|
||||
'//a/b/third_party/libname2:c(d)'), 'libname2')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName(
|
||||
'//a/b/third_party/libname3/c:d(e)'), 'libname3')
|
||||
self.assertEquals(
|
||||
LicenseBuilder._ParseLibraryName('//a/b/not_third_party/c'), None)
|
||||
|
||||
def testParseLibrarySimpleMatch(self):
|
||||
builder = LicenseBuilder([], [], {}, {})
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname:c'), 'libname')
|
||||
def testParseLibrarySimpleMatch(self):
|
||||
builder = LicenseBuilder([], [], {}, {})
|
||||
self.assertEquals(builder._ParseLibrary('//a/b/third_party/libname:c'),
|
||||
'libname')
|
||||
|
||||
def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self):
|
||||
lib_dict = {
|
||||
'libname:foo.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], lib_dict, {})
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname:bar_java'), 'libname')
|
||||
def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self):
|
||||
lib_dict = {
|
||||
'libname:foo.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], lib_dict, {})
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname:bar_java'),
|
||||
'libname')
|
||||
|
||||
def testParseLibraryRegExMatch(self):
|
||||
lib_regex_dict = {
|
||||
'libname:foo.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname:foo_bar_java'),
|
||||
'libname:foo.*')
|
||||
def testParseLibraryRegExMatch(self):
|
||||
lib_regex_dict = {
|
||||
'libname:foo.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname:foo_bar_java'),
|
||||
'libname:foo.*')
|
||||
|
||||
def testParseLibraryRegExMatchWithSubDirectory(self):
|
||||
lib_regex_dict = {
|
||||
'libname/foo:bar.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname/foo:bar_java'),
|
||||
'libname/foo:bar.*')
|
||||
def testParseLibraryRegExMatchWithSubDirectory(self):
|
||||
lib_regex_dict = {
|
||||
'libname/foo:bar.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname/foo:bar_java'),
|
||||
'libname/foo:bar.*')
|
||||
|
||||
def testParseLibraryRegExMatchWithStarInside(self):
|
||||
lib_regex_dict = {
|
||||
'libname/foo.*bar.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary('//a/b/third_party/libname/fooHAHA:bar_java'),
|
||||
'libname/foo.*bar.*')
|
||||
def testParseLibraryRegExMatchWithStarInside(self):
|
||||
lib_regex_dict = {
|
||||
'libname/foo.*bar.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._ParseLibrary(
|
||||
'//a/b/third_party/libname/fooHAHA:bar_java'),
|
||||
'libname/foo.*bar.*')
|
||||
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGetThirdPartyLibrariesWithoutRegex(self):
|
||||
builder = LicenseBuilder([], [], {}, {})
|
||||
self.assertEquals(
|
||||
builder._GetThirdPartyLibraries('out/arm', 'target1'),
|
||||
set(['libname1', 'libname2', 'libname3']))
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGetThirdPartyLibrariesWithoutRegex(self):
|
||||
builder = LicenseBuilder([], [], {}, {})
|
||||
self.assertEquals(
|
||||
builder._GetThirdPartyLibraries('out/arm', 'target1'),
|
||||
set(['libname1', 'libname2', 'libname3']))
|
||||
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGetThirdPartyLibrariesWithRegex(self):
|
||||
lib_regex_dict = {
|
||||
'libname2:c.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._GetThirdPartyLibraries('out/arm', 'target1'),
|
||||
set(['libname1', 'libname2:c.*', 'libname3']))
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGetThirdPartyLibrariesWithRegex(self):
|
||||
lib_regex_dict = {
|
||||
'libname2:c.*': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder([], [], {}, lib_regex_dict)
|
||||
self.assertEquals(
|
||||
builder._GetThirdPartyLibraries('out/arm', 'target1'),
|
||||
set(['libname1', 'libname2:c.*', 'libname3']))
|
||||
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGenerateLicenseTextFailIfUnknownLibrary(self):
|
||||
lib_dict = {
|
||||
'simple_library': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder(['dummy_dir'], ['dummy_target'], lib_dict, {})
|
||||
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
|
||||
def testGenerateLicenseTextFailIfUnknownLibrary(self):
|
||||
lib_dict = {
|
||||
'simple_library': ['path/to/LICENSE'],
|
||||
}
|
||||
builder = LicenseBuilder(['dummy_dir'], ['dummy_target'], lib_dict, {})
|
||||
|
||||
with self.assertRaises(Exception) as context:
|
||||
builder.GenerateLicenseText('dummy/dir')
|
||||
with self.assertRaises(Exception) as context:
|
||||
builder.GenerateLicenseText('dummy/dir')
|
||||
|
||||
self.assertEquals(
|
||||
context.exception.message,
|
||||
'Missing licenses for following third_party targets: '
|
||||
'libname1, libname2, libname3')
|
||||
self.assertEquals(
|
||||
context.exception.message,
|
||||
'Missing licenses for following third_party targets: '
|
||||
'libname1, libname2, libname3')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -6,31 +6,31 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Configuration class for network emulation."""
|
||||
|
||||
|
||||
class ConnectionConfig(object):
|
||||
"""Configuration containing the characteristics of a network connection."""
|
||||
"""Configuration containing the characteristics of a network connection."""
|
||||
|
||||
def __init__(self, num, name, receive_bw_kbps, send_bw_kbps, delay_ms,
|
||||
packet_loss_percent, queue_slots):
|
||||
self.num = num
|
||||
self.name = name
|
||||
self.receive_bw_kbps = receive_bw_kbps
|
||||
self.send_bw_kbps = send_bw_kbps
|
||||
self.delay_ms = delay_ms
|
||||
self.packet_loss_percent = packet_loss_percent
|
||||
self.queue_slots = queue_slots
|
||||
def __init__(self, num, name, receive_bw_kbps, send_bw_kbps, delay_ms,
|
||||
packet_loss_percent, queue_slots):
|
||||
self.num = num
|
||||
self.name = name
|
||||
self.receive_bw_kbps = receive_bw_kbps
|
||||
self.send_bw_kbps = send_bw_kbps
|
||||
self.delay_ms = delay_ms
|
||||
self.packet_loss_percent = packet_loss_percent
|
||||
self.queue_slots = queue_slots
|
||||
|
||||
def __str__(self):
|
||||
"""String representing the configuration.
|
||||
def __str__(self):
|
||||
"""String representing the configuration.
|
||||
|
||||
Returns:
|
||||
A string formatted and padded like this example:
|
||||
12 Name 375 kbps 375 kbps 10 145 ms 0.1 %
|
||||
"""
|
||||
left_aligned_name = self.name.ljust(24, ' ')
|
||||
return '%2s %24s %5s kbps %5s kbps %4s %5s ms %3s %%' % (
|
||||
self.num, left_aligned_name, self.receive_bw_kbps, self.send_bw_kbps,
|
||||
self.queue_slots, self.delay_ms, self.packet_loss_percent)
|
||||
left_aligned_name = self.name.ljust(24, ' ')
|
||||
return '%2s %24s %5s kbps %5s kbps %4s %5s ms %3s %%' % (
|
||||
self.num, left_aligned_name, self.receive_bw_kbps,
|
||||
self.send_bw_kbps, self.queue_slots, self.delay_ms,
|
||||
self.packet_loss_percent)
|
||||
|
||||
@ -6,10 +6,8 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Script for constraining traffic on the local machine."""
|
||||
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
import socket
|
||||
@ -18,7 +16,6 @@ import sys
|
||||
import config
|
||||
import network_emulator
|
||||
|
||||
|
||||
_DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# Default port range to apply network constraints on.
|
||||
@ -41,7 +38,7 @@ _PRESETS = [
|
||||
config.ConnectionConfig(12, 'Wifi, Average Case', 40000, 33000, 1, 0, 100),
|
||||
config.ConnectionConfig(13, 'Wifi, Good', 45000, 40000, 1, 0, 100),
|
||||
config.ConnectionConfig(14, 'Wifi, Lossy', 40000, 33000, 1, 0, 100),
|
||||
]
|
||||
]
|
||||
_PRESETS_DICT = dict((p.num, p) for p in _PRESETS)
|
||||
|
||||
_DEFAULT_PRESET_ID = 2
|
||||
@ -49,147 +46,170 @@ _DEFAULT_PRESET = _PRESETS_DICT[_DEFAULT_PRESET_ID]
|
||||
|
||||
|
||||
class NonStrippingEpilogOptionParser(optparse.OptionParser):
|
||||
"""Custom parser to let us show the epilog without weird line breaking."""
|
||||
"""Custom parser to let us show the epilog without weird line breaking."""
|
||||
|
||||
def format_epilog(self, formatter):
|
||||
return self.epilog
|
||||
def format_epilog(self, formatter):
|
||||
return self.epilog
|
||||
|
||||
|
||||
def _GetExternalIp():
|
||||
"""Finds out the machine's external IP by connecting to google.com."""
|
||||
external_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
external_socket.connect(('google.com', 80))
|
||||
return external_socket.getsockname()[0]
|
||||
"""Finds out the machine's external IP by connecting to google.com."""
|
||||
external_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
external_socket.connect(('google.com', 80))
|
||||
return external_socket.getsockname()[0]
|
||||
|
||||
|
||||
def _ParseArgs():
|
||||
"""Define and parse the command-line arguments."""
|
||||
presets_string = '\n'.join(str(p) for p in _PRESETS)
|
||||
parser = NonStrippingEpilogOptionParser(epilog=(
|
||||
'\nAvailable presets:\n'
|
||||
' Bandwidth (kbps) Packet\n'
|
||||
'ID Name Receive Send Queue Delay loss \n'
|
||||
'-- ---- --------- -------- ----- ------- ------\n'
|
||||
'%s\n' % presets_string))
|
||||
parser.add_option('-p', '--preset', type='int', default=_DEFAULT_PRESET_ID,
|
||||
help=('ConnectionConfig configuration, specified by ID. '
|
||||
'Default: %default'))
|
||||
parser.add_option('-r', '--receive-bw', type='int',
|
||||
default=_DEFAULT_PRESET.receive_bw_kbps,
|
||||
help=('Receive bandwidth in kilobit/s. Default: %default'))
|
||||
parser.add_option('-s', '--send-bw', type='int',
|
||||
default=_DEFAULT_PRESET.send_bw_kbps,
|
||||
help=('Send bandwidth in kilobit/s. Default: %default'))
|
||||
parser.add_option('-d', '--delay', type='int',
|
||||
default=_DEFAULT_PRESET.delay_ms,
|
||||
help=('Delay in ms. Default: %default'))
|
||||
parser.add_option('-l', '--packet-loss', type='float',
|
||||
default=_DEFAULT_PRESET.packet_loss_percent,
|
||||
help=('Packet loss in %. Default: %default'))
|
||||
parser.add_option('-q', '--queue', type='int',
|
||||
default=_DEFAULT_PRESET.queue_slots,
|
||||
help=('Queue size as number of slots. Default: %default'))
|
||||
parser.add_option('--port-range', default='%s,%s' % _DEFAULT_PORT_RANGE,
|
||||
help=('Range of ports for constrained network. Specify as '
|
||||
'two comma separated integers. Default: %default'))
|
||||
parser.add_option('--target-ip', default=None,
|
||||
help=('The interface IP address to apply the rules for. '
|
||||
'Default: the external facing interface IP address.'))
|
||||
parser.add_option('-v', '--verbose', action='store_true', default=False,
|
||||
help=('Turn on verbose output. Will print all \'ipfw\' '
|
||||
'commands that are executed.'))
|
||||
"""Define and parse the command-line arguments."""
|
||||
presets_string = '\n'.join(str(p) for p in _PRESETS)
|
||||
parser = NonStrippingEpilogOptionParser(epilog=(
|
||||
'\nAvailable presets:\n'
|
||||
' Bandwidth (kbps) Packet\n'
|
||||
'ID Name Receive Send Queue Delay loss \n'
|
||||
'-- ---- --------- -------- ----- ------- ------\n'
|
||||
'%s\n' % presets_string))
|
||||
parser.add_option('-p',
|
||||
'--preset',
|
||||
type='int',
|
||||
default=_DEFAULT_PRESET_ID,
|
||||
help=('ConnectionConfig configuration, specified by ID. '
|
||||
'Default: %default'))
|
||||
parser.add_option(
|
||||
'-r',
|
||||
'--receive-bw',
|
||||
type='int',
|
||||
default=_DEFAULT_PRESET.receive_bw_kbps,
|
||||
help=('Receive bandwidth in kilobit/s. Default: %default'))
|
||||
parser.add_option('-s',
|
||||
'--send-bw',
|
||||
type='int',
|
||||
default=_DEFAULT_PRESET.send_bw_kbps,
|
||||
help=('Send bandwidth in kilobit/s. Default: %default'))
|
||||
parser.add_option('-d',
|
||||
'--delay',
|
||||
type='int',
|
||||
default=_DEFAULT_PRESET.delay_ms,
|
||||
help=('Delay in ms. Default: %default'))
|
||||
parser.add_option('-l',
|
||||
'--packet-loss',
|
||||
type='float',
|
||||
default=_DEFAULT_PRESET.packet_loss_percent,
|
||||
help=('Packet loss in %. Default: %default'))
|
||||
parser.add_option(
|
||||
'-q',
|
||||
'--queue',
|
||||
type='int',
|
||||
default=_DEFAULT_PRESET.queue_slots,
|
||||
help=('Queue size as number of slots. Default: %default'))
|
||||
parser.add_option(
|
||||
'--port-range',
|
||||
default='%s,%s' % _DEFAULT_PORT_RANGE,
|
||||
help=('Range of ports for constrained network. Specify as '
|
||||
'two comma separated integers. Default: %default'))
|
||||
parser.add_option(
|
||||
'--target-ip',
|
||||
default=None,
|
||||
help=('The interface IP address to apply the rules for. '
|
||||
'Default: the external facing interface IP address.'))
|
||||
parser.add_option('-v',
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=('Turn on verbose output. Will print all \'ipfw\' '
|
||||
'commands that are executed.'))
|
||||
|
||||
options = parser.parse_args()[0]
|
||||
options = parser.parse_args()[0]
|
||||
|
||||
# Find preset by ID, if specified.
|
||||
if options.preset and not _PRESETS_DICT.has_key(options.preset):
|
||||
parser.error('Invalid preset: %s' % options.preset)
|
||||
# Find preset by ID, if specified.
|
||||
if options.preset and not _PRESETS_DICT.has_key(options.preset):
|
||||
parser.error('Invalid preset: %s' % options.preset)
|
||||
|
||||
# Simple validation of the IP address, if supplied.
|
||||
if options.target_ip:
|
||||
# Simple validation of the IP address, if supplied.
|
||||
if options.target_ip:
|
||||
try:
|
||||
socket.inet_aton(options.target_ip)
|
||||
except socket.error:
|
||||
parser.error('Invalid IP address specified: %s' %
|
||||
options.target_ip)
|
||||
|
||||
# Convert port range into the desired tuple format.
|
||||
try:
|
||||
socket.inet_aton(options.target_ip)
|
||||
except socket.error:
|
||||
parser.error('Invalid IP address specified: %s' % options.target_ip)
|
||||
if isinstance(options.port_range, str):
|
||||
options.port_range = tuple(
|
||||
int(port) for port in options.port_range.split(','))
|
||||
if len(options.port_range) != 2:
|
||||
parser.error(
|
||||
'Invalid port range specified, please specify two '
|
||||
'integers separated by a comma.')
|
||||
except ValueError:
|
||||
parser.error('Invalid port range specified.')
|
||||
|
||||
# Convert port range into the desired tuple format.
|
||||
try:
|
||||
if isinstance(options.port_range, str):
|
||||
options.port_range = tuple(int(port) for port in
|
||||
options.port_range.split(','))
|
||||
if len(options.port_range) != 2:
|
||||
parser.error('Invalid port range specified, please specify two '
|
||||
'integers separated by a comma.')
|
||||
except ValueError:
|
||||
parser.error('Invalid port range specified.')
|
||||
|
||||
_InitLogging(options.verbose)
|
||||
return options
|
||||
_InitLogging(options.verbose)
|
||||
return options
|
||||
|
||||
|
||||
def _InitLogging(verbose):
|
||||
"""Setup logging."""
|
||||
log_level = _DEFAULT_LOG_LEVEL
|
||||
if verbose:
|
||||
log_level = logging.DEBUG
|
||||
logging.basicConfig(level=log_level, format='%(message)s')
|
||||
"""Setup logging."""
|
||||
log_level = _DEFAULT_LOG_LEVEL
|
||||
if verbose:
|
||||
log_level = logging.DEBUG
|
||||
logging.basicConfig(level=log_level, format='%(message)s')
|
||||
|
||||
|
||||
def main():
|
||||
options = _ParseArgs()
|
||||
options = _ParseArgs()
|
||||
|
||||
# Build a configuration object. Override any preset configuration settings if
|
||||
# a value of a setting was also given as a flag.
|
||||
connection_config = _PRESETS_DICT[options.preset]
|
||||
if options.receive_bw is not _DEFAULT_PRESET.receive_bw_kbps:
|
||||
connection_config.receive_bw_kbps = options.receive_bw
|
||||
if options.send_bw is not _DEFAULT_PRESET.send_bw_kbps:
|
||||
connection_config.send_bw_kbps = options.send_bw
|
||||
if options.delay is not _DEFAULT_PRESET.delay_ms:
|
||||
connection_config.delay_ms = options.delay
|
||||
if options.packet_loss is not _DEFAULT_PRESET.packet_loss_percent:
|
||||
connection_config.packet_loss_percent = options.packet_loss
|
||||
if options.queue is not _DEFAULT_PRESET.queue_slots:
|
||||
connection_config.queue_slots = options.queue
|
||||
emulator = network_emulator.NetworkEmulator(connection_config,
|
||||
options.port_range)
|
||||
try:
|
||||
emulator.CheckPermissions()
|
||||
except network_emulator.NetworkEmulatorError as e:
|
||||
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
|
||||
return -1
|
||||
# Build a configuration object. Override any preset configuration settings if
|
||||
# a value of a setting was also given as a flag.
|
||||
connection_config = _PRESETS_DICT[options.preset]
|
||||
if options.receive_bw is not _DEFAULT_PRESET.receive_bw_kbps:
|
||||
connection_config.receive_bw_kbps = options.receive_bw
|
||||
if options.send_bw is not _DEFAULT_PRESET.send_bw_kbps:
|
||||
connection_config.send_bw_kbps = options.send_bw
|
||||
if options.delay is not _DEFAULT_PRESET.delay_ms:
|
||||
connection_config.delay_ms = options.delay
|
||||
if options.packet_loss is not _DEFAULT_PRESET.packet_loss_percent:
|
||||
connection_config.packet_loss_percent = options.packet_loss
|
||||
if options.queue is not _DEFAULT_PRESET.queue_slots:
|
||||
connection_config.queue_slots = options.queue
|
||||
emulator = network_emulator.NetworkEmulator(connection_config,
|
||||
options.port_range)
|
||||
try:
|
||||
emulator.CheckPermissions()
|
||||
except network_emulator.NetworkEmulatorError as e:
|
||||
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
|
||||
return -1
|
||||
|
||||
if not options.target_ip:
|
||||
external_ip = _GetExternalIp()
|
||||
else:
|
||||
external_ip = options.target_ip
|
||||
if not options.target_ip:
|
||||
external_ip = _GetExternalIp()
|
||||
else:
|
||||
external_ip = options.target_ip
|
||||
|
||||
logging.info('Constraining traffic to/from IP: %s', external_ip)
|
||||
try:
|
||||
emulator.Emulate(external_ip)
|
||||
logging.info(
|
||||
'Started network emulation with the following configuration:\n'
|
||||
' Receive bandwidth: %s kbps (%s kB/s)\n'
|
||||
' Send bandwidth : %s kbps (%s kB/s)\n'
|
||||
' Delay : %s ms\n'
|
||||
' Packet loss : %s %%\n'
|
||||
' Queue slots : %s', connection_config.receive_bw_kbps,
|
||||
connection_config.receive_bw_kbps / 8,
|
||||
connection_config.send_bw_kbps, connection_config.send_bw_kbps / 8,
|
||||
connection_config.delay_ms, connection_config.packet_loss_percent,
|
||||
connection_config.queue_slots)
|
||||
logging.info('Affected traffic: IP traffic on ports %s-%s',
|
||||
options.port_range[0], options.port_range[1])
|
||||
raw_input('Press Enter to abort Network Emulation...')
|
||||
logging.info('Flushing all Dummynet rules...')
|
||||
network_emulator.Cleanup()
|
||||
logging.info('Completed Network Emulation.')
|
||||
return 0
|
||||
except network_emulator.NetworkEmulatorError as e:
|
||||
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
|
||||
return -2
|
||||
|
||||
logging.info('Constraining traffic to/from IP: %s', external_ip)
|
||||
try:
|
||||
emulator.Emulate(external_ip)
|
||||
logging.info('Started network emulation with the following configuration:\n'
|
||||
' Receive bandwidth: %s kbps (%s kB/s)\n'
|
||||
' Send bandwidth : %s kbps (%s kB/s)\n'
|
||||
' Delay : %s ms\n'
|
||||
' Packet loss : %s %%\n'
|
||||
' Queue slots : %s',
|
||||
connection_config.receive_bw_kbps,
|
||||
connection_config.receive_bw_kbps/8,
|
||||
connection_config.send_bw_kbps,
|
||||
connection_config.send_bw_kbps/8,
|
||||
connection_config.delay_ms,
|
||||
connection_config.packet_loss_percent,
|
||||
connection_config.queue_slots)
|
||||
logging.info('Affected traffic: IP traffic on ports %s-%s',
|
||||
options.port_range[0], options.port_range[1])
|
||||
raw_input('Press Enter to abort Network Emulation...')
|
||||
logging.info('Flushing all Dummynet rules...')
|
||||
network_emulator.Cleanup()
|
||||
logging.info('Completed Network Emulation.')
|
||||
return 0
|
||||
except network_emulator.NetworkEmulatorError as e:
|
||||
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
|
||||
return -2
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Script for constraining traffic on the local machine."""
|
||||
|
||||
import ctypes
|
||||
@ -17,7 +16,7 @@ import sys
|
||||
|
||||
|
||||
class NetworkEmulatorError(BaseException):
|
||||
"""Exception raised for errors in the network emulator.
|
||||
"""Exception raised for errors in the network emulator.
|
||||
|
||||
Attributes:
|
||||
fail_msg: User defined error message.
|
||||
@ -27,81 +26,88 @@ class NetworkEmulatorError(BaseException):
|
||||
stderr: Error output of running the command.
|
||||
"""
|
||||
|
||||
def __init__(self, fail_msg, cmd=None, returncode=None, output=None,
|
||||
error=None):
|
||||
BaseException.__init__(self, fail_msg)
|
||||
self.fail_msg = fail_msg
|
||||
self.cmd = cmd
|
||||
self.returncode = returncode
|
||||
self.output = output
|
||||
self.error = error
|
||||
def __init__(self,
|
||||
fail_msg,
|
||||
cmd=None,
|
||||
returncode=None,
|
||||
output=None,
|
||||
error=None):
|
||||
BaseException.__init__(self, fail_msg)
|
||||
self.fail_msg = fail_msg
|
||||
self.cmd = cmd
|
||||
self.returncode = returncode
|
||||
self.output = output
|
||||
self.error = error
|
||||
|
||||
|
||||
class NetworkEmulator(object):
|
||||
"""A network emulator that can constrain the network using Dummynet."""
|
||||
"""A network emulator that can constrain the network using Dummynet."""
|
||||
|
||||
def __init__(self, connection_config, port_range):
|
||||
"""Constructor.
|
||||
def __init__(self, connection_config, port_range):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
connection_config: A config.ConnectionConfig object containing the
|
||||
characteristics for the connection to be emulation.
|
||||
port_range: Tuple containing two integers defining the port range.
|
||||
"""
|
||||
self._pipe_counter = 0
|
||||
self._rule_counter = 0
|
||||
self._port_range = port_range
|
||||
self._connection_config = connection_config
|
||||
self._pipe_counter = 0
|
||||
self._rule_counter = 0
|
||||
self._port_range = port_range
|
||||
self._connection_config = connection_config
|
||||
|
||||
def Emulate(self, target_ip):
|
||||
"""Starts a network emulation by setting up Dummynet rules.
|
||||
def Emulate(self, target_ip):
|
||||
"""Starts a network emulation by setting up Dummynet rules.
|
||||
|
||||
Args:
|
||||
target_ip: The IP address of the interface that shall be that have the
|
||||
network constraints applied to it.
|
||||
"""
|
||||
receive_pipe_id = self._CreateDummynetPipe(
|
||||
self._connection_config.receive_bw_kbps,
|
||||
self._connection_config.delay_ms,
|
||||
self._connection_config.packet_loss_percent,
|
||||
self._connection_config.queue_slots)
|
||||
logging.debug('Created receive pipe: %s', receive_pipe_id)
|
||||
send_pipe_id = self._CreateDummynetPipe(
|
||||
self._connection_config.send_bw_kbps,
|
||||
self._connection_config.delay_ms,
|
||||
self._connection_config.packet_loss_percent,
|
||||
self._connection_config.queue_slots)
|
||||
logging.debug('Created send pipe: %s', send_pipe_id)
|
||||
receive_pipe_id = self._CreateDummynetPipe(
|
||||
self._connection_config.receive_bw_kbps,
|
||||
self._connection_config.delay_ms,
|
||||
self._connection_config.packet_loss_percent,
|
||||
self._connection_config.queue_slots)
|
||||
logging.debug('Created receive pipe: %s', receive_pipe_id)
|
||||
send_pipe_id = self._CreateDummynetPipe(
|
||||
self._connection_config.send_bw_kbps,
|
||||
self._connection_config.delay_ms,
|
||||
self._connection_config.packet_loss_percent,
|
||||
self._connection_config.queue_slots)
|
||||
logging.debug('Created send pipe: %s', send_pipe_id)
|
||||
|
||||
# Adding the rules will start the emulation.
|
||||
incoming_rule_id = self._CreateDummynetRule(receive_pipe_id, 'any',
|
||||
target_ip, self._port_range)
|
||||
logging.debug('Created incoming rule: %s', incoming_rule_id)
|
||||
outgoing_rule_id = self._CreateDummynetRule(send_pipe_id, target_ip,
|
||||
'any', self._port_range)
|
||||
logging.debug('Created outgoing rule: %s', outgoing_rule_id)
|
||||
# Adding the rules will start the emulation.
|
||||
incoming_rule_id = self._CreateDummynetRule(receive_pipe_id, 'any',
|
||||
target_ip,
|
||||
self._port_range)
|
||||
logging.debug('Created incoming rule: %s', incoming_rule_id)
|
||||
outgoing_rule_id = self._CreateDummynetRule(send_pipe_id, target_ip,
|
||||
'any', self._port_range)
|
||||
logging.debug('Created outgoing rule: %s', outgoing_rule_id)
|
||||
|
||||
@staticmethod
|
||||
def CheckPermissions():
|
||||
"""Checks if permissions are available to run Dummynet commands.
|
||||
@staticmethod
|
||||
def CheckPermissions():
|
||||
"""Checks if permissions are available to run Dummynet commands.
|
||||
|
||||
Raises:
|
||||
NetworkEmulatorError: If permissions to run Dummynet commands are not
|
||||
available.
|
||||
"""
|
||||
try:
|
||||
if os.getuid() != 0:
|
||||
raise NetworkEmulatorError('You must run this script with sudo.')
|
||||
except AttributeError:
|
||||
try:
|
||||
if os.getuid() != 0:
|
||||
raise NetworkEmulatorError(
|
||||
'You must run this script with sudo.')
|
||||
except AttributeError:
|
||||
|
||||
# AttributeError will be raised on Windows.
|
||||
if ctypes.windll.shell32.IsUserAnAdmin() == 0:
|
||||
raise NetworkEmulatorError('You must run this script with administrator'
|
||||
' privileges.')
|
||||
# AttributeError will be raised on Windows.
|
||||
if ctypes.windll.shell32.IsUserAnAdmin() == 0:
|
||||
raise NetworkEmulatorError(
|
||||
'You must run this script with administrator'
|
||||
' privileges.')
|
||||
|
||||
def _CreateDummynetRule(self, pipe_id, from_address, to_address,
|
||||
port_range):
|
||||
"""Creates a network emulation rule and returns its ID.
|
||||
def _CreateDummynetRule(self, pipe_id, from_address, to_address,
|
||||
port_range):
|
||||
"""Creates a network emulation rule and returns its ID.
|
||||
|
||||
Args:
|
||||
pipe_id: integer ID of the pipe.
|
||||
@ -115,18 +121,22 @@ class NetworkEmulator(object):
|
||||
The ID of the rule, starting at 100. The rule ID increments with 100 for
|
||||
each rule being added.
|
||||
"""
|
||||
self._rule_counter += 100
|
||||
add_part = ['add', self._rule_counter, 'pipe', pipe_id,
|
||||
'ip', 'from', from_address, 'to', to_address]
|
||||
_RunIpfwCommand(add_part + ['src-port', '%s-%s' % port_range],
|
||||
'Failed to add Dummynet src-port rule.')
|
||||
_RunIpfwCommand(add_part + ['dst-port', '%s-%s' % port_range],
|
||||
'Failed to add Dummynet dst-port rule.')
|
||||
return self._rule_counter
|
||||
self._rule_counter += 100
|
||||
add_part = [
|
||||
'add', self._rule_counter, 'pipe', pipe_id, 'ip', 'from',
|
||||
from_address, 'to', to_address
|
||||
]
|
||||
_RunIpfwCommand(add_part +
|
||||
['src-port', '%s-%s' % port_range],
|
||||
'Failed to add Dummynet src-port rule.')
|
||||
_RunIpfwCommand(add_part +
|
||||
['dst-port', '%s-%s' % port_range],
|
||||
'Failed to add Dummynet dst-port rule.')
|
||||
return self._rule_counter
|
||||
|
||||
def _CreateDummynetPipe(self, bandwidth_kbps, delay_ms, packet_loss_percent,
|
||||
queue_slots):
|
||||
"""Creates a Dummynet pipe and return its ID.
|
||||
def _CreateDummynetPipe(self, bandwidth_kbps, delay_ms,
|
||||
packet_loss_percent, queue_slots):
|
||||
"""Creates a Dummynet pipe and return its ID.
|
||||
|
||||
Args:
|
||||
bandwidth_kbps: Bandwidth.
|
||||
@ -136,32 +146,34 @@ class NetworkEmulator(object):
|
||||
Returns:
|
||||
The ID of the pipe, starting at 1.
|
||||
"""
|
||||
self._pipe_counter += 1
|
||||
cmd = ['pipe', self._pipe_counter, 'config',
|
||||
'bw', str(bandwidth_kbps/8) + 'KByte/s',
|
||||
'delay', '%sms' % delay_ms,
|
||||
'plr', (packet_loss_percent/100.0),
|
||||
'queue', queue_slots]
|
||||
error_message = 'Failed to create Dummynet pipe. '
|
||||
if sys.platform.startswith('linux'):
|
||||
error_message += ('Make sure you have loaded the ipfw_mod.ko module to '
|
||||
'your kernel (sudo insmod /path/to/ipfw_mod.ko).')
|
||||
_RunIpfwCommand(cmd, error_message)
|
||||
return self._pipe_counter
|
||||
self._pipe_counter += 1
|
||||
cmd = [
|
||||
'pipe', self._pipe_counter, 'config', 'bw',
|
||||
str(bandwidth_kbps / 8) + 'KByte/s', 'delay',
|
||||
'%sms' % delay_ms, 'plr', (packet_loss_percent / 100.0), 'queue',
|
||||
queue_slots
|
||||
]
|
||||
error_message = 'Failed to create Dummynet pipe. '
|
||||
if sys.platform.startswith('linux'):
|
||||
error_message += (
|
||||
'Make sure you have loaded the ipfw_mod.ko module to '
|
||||
'your kernel (sudo insmod /path/to/ipfw_mod.ko).')
|
||||
_RunIpfwCommand(cmd, error_message)
|
||||
return self._pipe_counter
|
||||
|
||||
|
||||
def Cleanup():
|
||||
"""Stops the network emulation by flushing all Dummynet rules.
|
||||
"""Stops the network emulation by flushing all Dummynet rules.
|
||||
|
||||
Notice that this will flush any rules that may have been created previously
|
||||
before starting the emulation.
|
||||
"""
|
||||
_RunIpfwCommand(['-f', 'flush'],
|
||||
'Failed to flush Dummynet rules!')
|
||||
_RunIpfwCommand(['-f', 'pipe', 'flush'],
|
||||
'Failed to flush Dummynet pipes!')
|
||||
_RunIpfwCommand(['-f', 'flush'], 'Failed to flush Dummynet rules!')
|
||||
_RunIpfwCommand(['-f', 'pipe', 'flush'], 'Failed to flush Dummynet pipes!')
|
||||
|
||||
|
||||
def _RunIpfwCommand(command, fail_msg=None):
|
||||
"""Executes a command and prefixes the appropriate command for
|
||||
"""Executes a command and prefixes the appropriate command for
|
||||
Windows or Linux/UNIX.
|
||||
|
||||
Args:
|
||||
@ -172,18 +184,19 @@ def _RunIpfwCommand(command, fail_msg=None):
|
||||
NetworkEmulatorError: If command fails a message is set by the fail_msg
|
||||
parameter.
|
||||
"""
|
||||
if sys.platform == 'win32':
|
||||
ipfw_command = ['ipfw.exe']
|
||||
else:
|
||||
ipfw_command = ['sudo', '-n', 'ipfw']
|
||||
if sys.platform == 'win32':
|
||||
ipfw_command = ['ipfw.exe']
|
||||
else:
|
||||
ipfw_command = ['sudo', '-n', 'ipfw']
|
||||
|
||||
cmd_list = ipfw_command[:] + [str(x) for x in command]
|
||||
cmd_string = ' '.join(cmd_list)
|
||||
logging.debug('Running command: %s', cmd_string)
|
||||
process = subprocess.Popen(cmd_list, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
if process.returncode != 0:
|
||||
raise NetworkEmulatorError(fail_msg, cmd_string, process.returncode, output,
|
||||
error)
|
||||
return output.strip()
|
||||
cmd_list = ipfw_command[:] + [str(x) for x in command]
|
||||
cmd_string = ' '.join(cmd_list)
|
||||
logging.debug('Running command: %s', cmd_string)
|
||||
process = subprocess.Popen(cmd_list,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
output, error = process.communicate()
|
||||
if process.returncode != 0:
|
||||
raise NetworkEmulatorError(fail_msg, cmd_string, process.returncode,
|
||||
output, error)
|
||||
return output.strip()
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
|
||||
import httplib2
|
||||
import json
|
||||
import subprocess
|
||||
@ -20,19 +19,19 @@ from tracing.value.diagnostics import reserved_infos
|
||||
|
||||
|
||||
def _GenerateOauthToken():
|
||||
args = ['luci-auth', 'token']
|
||||
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
if p.wait() == 0:
|
||||
output = p.stdout.read()
|
||||
return output.strip()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Error generating authentication token.\nStdout: %s\nStderr:%s' %
|
||||
(p.stdout.read(), p.stderr.read()))
|
||||
args = ['luci-auth', 'token']
|
||||
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
if p.wait() == 0:
|
||||
output = p.stdout.read()
|
||||
return output.strip()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Error generating authentication token.\nStdout: %s\nStderr:%s' %
|
||||
(p.stdout.read(), p.stderr.read()))
|
||||
|
||||
|
||||
def _SendHistogramSet(url, histograms, oauth_token):
|
||||
"""Make a HTTP POST with the given JSON to the Performance Dashboard.
|
||||
"""Make a HTTP POST with the given JSON to the Performance Dashboard.
|
||||
|
||||
Args:
|
||||
url: URL of Performance Dashboard instance, e.g.
|
||||
@ -40,83 +39,87 @@ def _SendHistogramSet(url, histograms, oauth_token):
|
||||
histograms: a histogram set object that contains the data to be sent.
|
||||
oauth_token: An oauth token to use for authorization.
|
||||
"""
|
||||
headers = {'Authorization': 'Bearer %s' % oauth_token}
|
||||
headers = {'Authorization': 'Bearer %s' % oauth_token}
|
||||
|
||||
serialized = json.dumps(_ApplyHacks(histograms.AsDicts()), indent=4)
|
||||
serialized = json.dumps(_ApplyHacks(histograms.AsDicts()), indent=4)
|
||||
|
||||
if url.startswith('http://localhost'):
|
||||
# The catapult server turns off compression in developer mode.
|
||||
data = serialized
|
||||
else:
|
||||
data = zlib.compress(serialized)
|
||||
if url.startswith('http://localhost'):
|
||||
# The catapult server turns off compression in developer mode.
|
||||
data = serialized
|
||||
else:
|
||||
data = zlib.compress(serialized)
|
||||
|
||||
print 'Sending %d bytes to %s.' % (len(data), url + '/add_histograms')
|
||||
print 'Sending %d bytes to %s.' % (len(data), url + '/add_histograms')
|
||||
|
||||
http = httplib2.Http()
|
||||
response, content = http.request(url + '/add_histograms', method='POST',
|
||||
body=data, headers=headers)
|
||||
return response, content
|
||||
http = httplib2.Http()
|
||||
response, content = http.request(url + '/add_histograms',
|
||||
method='POST',
|
||||
body=data,
|
||||
headers=headers)
|
||||
return response, content
|
||||
|
||||
|
||||
# TODO(https://crbug.com/1029452): HACKHACK
|
||||
# Remove once we have doubles in the proto and handle -infinity correctly.
|
||||
def _ApplyHacks(dicts):
|
||||
for d in dicts:
|
||||
if 'running' in d:
|
||||
def _NoInf(value):
|
||||
if value == float('inf'):
|
||||
return histogram.JS_MAX_VALUE
|
||||
if value == float('-inf'):
|
||||
return -histogram.JS_MAX_VALUE
|
||||
return value
|
||||
d['running'] = [_NoInf(value) for value in d['running']]
|
||||
for d in dicts:
|
||||
if 'running' in d:
|
||||
|
||||
return dicts
|
||||
def _NoInf(value):
|
||||
if value == float('inf'):
|
||||
return histogram.JS_MAX_VALUE
|
||||
if value == float('-inf'):
|
||||
return -histogram.JS_MAX_VALUE
|
||||
return value
|
||||
|
||||
d['running'] = [_NoInf(value) for value in d['running']]
|
||||
|
||||
return dicts
|
||||
|
||||
|
||||
def _LoadHistogramSetFromProto(options):
|
||||
hs = histogram_set.HistogramSet()
|
||||
with options.input_results_file as f:
|
||||
hs.ImportProto(f.read())
|
||||
hs = histogram_set.HistogramSet()
|
||||
with options.input_results_file as f:
|
||||
hs.ImportProto(f.read())
|
||||
|
||||
return hs
|
||||
return hs
|
||||
|
||||
|
||||
def _AddBuildInfo(histograms, options):
|
||||
common_diagnostics = {
|
||||
reserved_infos.MASTERS: options.perf_dashboard_machine_group,
|
||||
reserved_infos.BOTS: options.bot,
|
||||
reserved_infos.POINT_ID: options.commit_position,
|
||||
reserved_infos.BENCHMARKS: options.test_suite,
|
||||
reserved_infos.WEBRTC_REVISIONS: str(options.webrtc_git_hash),
|
||||
reserved_infos.BUILD_URLS: options.build_page_url,
|
||||
}
|
||||
common_diagnostics = {
|
||||
reserved_infos.MASTERS: options.perf_dashboard_machine_group,
|
||||
reserved_infos.BOTS: options.bot,
|
||||
reserved_infos.POINT_ID: options.commit_position,
|
||||
reserved_infos.BENCHMARKS: options.test_suite,
|
||||
reserved_infos.WEBRTC_REVISIONS: str(options.webrtc_git_hash),
|
||||
reserved_infos.BUILD_URLS: options.build_page_url,
|
||||
}
|
||||
|
||||
for k, v in common_diagnostics.items():
|
||||
histograms.AddSharedDiagnosticToAllHistograms(
|
||||
k.name, generic_set.GenericSet([v]))
|
||||
for k, v in common_diagnostics.items():
|
||||
histograms.AddSharedDiagnosticToAllHistograms(
|
||||
k.name, generic_set.GenericSet([v]))
|
||||
|
||||
|
||||
def _DumpOutput(histograms, output_file):
|
||||
with output_file:
|
||||
json.dump(_ApplyHacks(histograms.AsDicts()), output_file, indent=4)
|
||||
with output_file:
|
||||
json.dump(_ApplyHacks(histograms.AsDicts()), output_file, indent=4)
|
||||
|
||||
|
||||
def UploadToDashboard(options):
|
||||
histograms = _LoadHistogramSetFromProto(options)
|
||||
_AddBuildInfo(histograms, options)
|
||||
histograms = _LoadHistogramSetFromProto(options)
|
||||
_AddBuildInfo(histograms, options)
|
||||
|
||||
if options.output_json_file:
|
||||
_DumpOutput(histograms, options.output_json_file)
|
||||
if options.output_json_file:
|
||||
_DumpOutput(histograms, options.output_json_file)
|
||||
|
||||
oauth_token = _GenerateOauthToken()
|
||||
response, content = _SendHistogramSet(
|
||||
options.dashboard_url, histograms, oauth_token)
|
||||
oauth_token = _GenerateOauthToken()
|
||||
response, content = _SendHistogramSet(options.dashboard_url, histograms,
|
||||
oauth_token)
|
||||
|
||||
if response.status == 200:
|
||||
print 'Received 200 from dashboard.'
|
||||
return 0
|
||||
else:
|
||||
print('Upload failed with %d: %s\n\n%s' % (response.status, response.reason,
|
||||
content))
|
||||
return 1
|
||||
if response.status == 200:
|
||||
print 'Received 200 from dashboard.'
|
||||
return 0
|
||||
else:
|
||||
print('Upload failed with %d: %s\n\n%s' %
|
||||
(response.status, response.reason, content))
|
||||
return 1
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Adds build info to perf results and uploads them.
|
||||
|
||||
The tests don't know which bot executed the tests or at what revision, so we
|
||||
@ -24,78 +23,93 @@ import sys
|
||||
|
||||
|
||||
def _CreateParser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--perf-dashboard-machine-group', required=True,
|
||||
help='The "master" the bots are grouped under. This '
|
||||
'string is the group in the the perf dashboard path '
|
||||
'group/bot/perf_id/metric/subtest.')
|
||||
parser.add_argument('--bot', required=True,
|
||||
help='The bot running the test (e.g. '
|
||||
'webrtc-win-large-tests).')
|
||||
parser.add_argument('--test-suite', required=True,
|
||||
help='The key for the test in the dashboard (i.e. what '
|
||||
'you select in the top-level test suite selector in the '
|
||||
'dashboard')
|
||||
parser.add_argument('--webrtc-git-hash', required=True,
|
||||
help='webrtc.googlesource.com commit hash.')
|
||||
parser.add_argument('--commit-position', type=int, required=True,
|
||||
help='Commit pos corresponding to the git hash.')
|
||||
parser.add_argument('--build-page-url', required=True,
|
||||
help='URL to the build page for this build.')
|
||||
parser.add_argument('--dashboard-url', required=True,
|
||||
help='Which dashboard to use.')
|
||||
parser.add_argument('--input-results-file', type=argparse.FileType(),
|
||||
required=True,
|
||||
help='A JSON file with output from WebRTC tests.')
|
||||
parser.add_argument('--output-json-file', type=argparse.FileType('w'),
|
||||
help='Where to write the output (for debugging).')
|
||||
parser.add_argument('--outdir', required=True,
|
||||
help='Path to the local out/ dir (usually out/Default)')
|
||||
return parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--perf-dashboard-machine-group',
|
||||
required=True,
|
||||
help='The "master" the bots are grouped under. This '
|
||||
'string is the group in the the perf dashboard path '
|
||||
'group/bot/perf_id/metric/subtest.')
|
||||
parser.add_argument('--bot',
|
||||
required=True,
|
||||
help='The bot running the test (e.g. '
|
||||
'webrtc-win-large-tests).')
|
||||
parser.add_argument(
|
||||
'--test-suite',
|
||||
required=True,
|
||||
help='The key for the test in the dashboard (i.e. what '
|
||||
'you select in the top-level test suite selector in the '
|
||||
'dashboard')
|
||||
parser.add_argument('--webrtc-git-hash',
|
||||
required=True,
|
||||
help='webrtc.googlesource.com commit hash.')
|
||||
parser.add_argument('--commit-position',
|
||||
type=int,
|
||||
required=True,
|
||||
help='Commit pos corresponding to the git hash.')
|
||||
parser.add_argument('--build-page-url',
|
||||
required=True,
|
||||
help='URL to the build page for this build.')
|
||||
parser.add_argument('--dashboard-url',
|
||||
required=True,
|
||||
help='Which dashboard to use.')
|
||||
parser.add_argument('--input-results-file',
|
||||
type=argparse.FileType(),
|
||||
required=True,
|
||||
help='A JSON file with output from WebRTC tests.')
|
||||
parser.add_argument('--output-json-file',
|
||||
type=argparse.FileType('w'),
|
||||
help='Where to write the output (for debugging).')
|
||||
parser.add_argument(
|
||||
'--outdir',
|
||||
required=True,
|
||||
help='Path to the local out/ dir (usually out/Default)')
|
||||
return parser
|
||||
|
||||
|
||||
def _ConfigurePythonPath(options):
|
||||
# We just yank the python scripts we require into the PYTHONPATH. You could
|
||||
# also imagine a solution where we use for instance protobuf:py_proto_runtime
|
||||
# to copy catapult and protobuf code to out/. This is the convention in
|
||||
# Chromium and WebRTC python scripts. We do need to build histogram_pb2
|
||||
# however, so that's why we add out/ to sys.path below.
|
||||
#
|
||||
# It would be better if there was an equivalent to py_binary in GN, but
|
||||
# there's not.
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkout_root = os.path.abspath(
|
||||
os.path.join(script_dir, os.pardir, os.pardir))
|
||||
# We just yank the python scripts we require into the PYTHONPATH. You could
|
||||
# also imagine a solution where we use for instance protobuf:py_proto_runtime
|
||||
# to copy catapult and protobuf code to out/. This is the convention in
|
||||
# Chromium and WebRTC python scripts. We do need to build histogram_pb2
|
||||
# however, so that's why we add out/ to sys.path below.
|
||||
#
|
||||
# It would be better if there was an equivalent to py_binary in GN, but
|
||||
# there's not.
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
checkout_root = os.path.abspath(
|
||||
os.path.join(script_dir, os.pardir, os.pardir))
|
||||
|
||||
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'catapult',
|
||||
'tracing'))
|
||||
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'protobuf',
|
||||
'python'))
|
||||
sys.path.insert(
|
||||
0, os.path.join(checkout_root, 'third_party', 'catapult', 'tracing'))
|
||||
sys.path.insert(
|
||||
0, os.path.join(checkout_root, 'third_party', 'protobuf', 'python'))
|
||||
|
||||
# The webrtc_dashboard_upload gn rule will build the protobuf stub for python,
|
||||
# so put it in the path for this script before we attempt to import it.
|
||||
histogram_proto_path = os.path.join(
|
||||
options.outdir, 'pyproto', 'tracing', 'tracing', 'proto')
|
||||
sys.path.insert(0, histogram_proto_path)
|
||||
# The webrtc_dashboard_upload gn rule will build the protobuf stub for python,
|
||||
# so put it in the path for this script before we attempt to import it.
|
||||
histogram_proto_path = os.path.join(options.outdir, 'pyproto', 'tracing',
|
||||
'tracing', 'proto')
|
||||
sys.path.insert(0, histogram_proto_path)
|
||||
|
||||
# Fail early in case the proto hasn't been built.
|
||||
from tracing.proto import histogram_proto
|
||||
if not histogram_proto.HAS_PROTO:
|
||||
raise ImportError('Could not find histogram_pb2. You need to build the '
|
||||
'webrtc_dashboard_upload target before invoking this '
|
||||
'script. Expected to find '
|
||||
'histogram_pb2.py in %s.' % histogram_proto_path)
|
||||
# Fail early in case the proto hasn't been built.
|
||||
from tracing.proto import histogram_proto
|
||||
if not histogram_proto.HAS_PROTO:
|
||||
raise ImportError(
|
||||
'Could not find histogram_pb2. You need to build the '
|
||||
'webrtc_dashboard_upload target before invoking this '
|
||||
'script. Expected to find '
|
||||
'histogram_pb2.py in %s.' % histogram_proto_path)
|
||||
|
||||
|
||||
def main(args):
|
||||
parser = _CreateParser()
|
||||
options = parser.parse_args(args)
|
||||
parser = _CreateParser()
|
||||
options = parser.parse_args(args)
|
||||
|
||||
_ConfigurePythonPath(options)
|
||||
_ConfigurePythonPath(options)
|
||||
|
||||
import catapult_uploader
|
||||
import catapult_uploader
|
||||
|
||||
return catapult_uploader.UploadToDashboard(options)
|
||||
|
||||
return catapult_uploader.UploadToDashboard(options)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""This script helps to invoke gn and ninja
|
||||
which lie in depot_tools repository."""
|
||||
|
||||
@ -19,11 +18,11 @@ import tempfile
|
||||
|
||||
|
||||
def FindSrcDirPath():
|
||||
"""Returns the abs path to the src/ dir of the project."""
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
while os.path.basename(src_dir) != 'src':
|
||||
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
|
||||
return src_dir
|
||||
"""Returns the abs path to the src/ dir of the project."""
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
while os.path.basename(src_dir) != 'src':
|
||||
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
|
||||
return src_dir
|
||||
|
||||
|
||||
SRC_DIR = FindSrcDirPath()
|
||||
@ -32,16 +31,16 @@ import find_depot_tools
|
||||
|
||||
|
||||
def RunGnCommand(args, root_dir=None):
|
||||
"""Runs `gn` with provided args and return error if any."""
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
|
||||
] + args
|
||||
subprocess.check_output(command, cwd=root_dir)
|
||||
except subprocess.CalledProcessError as err:
|
||||
return err.output
|
||||
return None
|
||||
"""Runs `gn` with provided args and return error if any."""
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
|
||||
] + args
|
||||
subprocess.check_output(command, cwd=root_dir)
|
||||
except subprocess.CalledProcessError as err:
|
||||
return err.output
|
||||
return None
|
||||
|
||||
|
||||
# GN_ERROR_RE matches the summary of an error output by `gn check`.
|
||||
@ -51,49 +50,49 @@ GN_ERROR_RE = re.compile(r'^ERROR .+(?:\n.*[^_\n].*$)+', re.MULTILINE)
|
||||
|
||||
|
||||
def RunGnCheck(root_dir=None):
|
||||
"""Runs `gn gen --check` with default args to detect mismatches between
|
||||
"""Runs `gn gen --check` with default args to detect mismatches between
|
||||
#includes and dependencies in the BUILD.gn files, as well as general build
|
||||
errors.
|
||||
|
||||
Returns a list of error summary strings.
|
||||
"""
|
||||
out_dir = tempfile.mkdtemp('gn')
|
||||
try:
|
||||
error = RunGnCommand(['gen', '--check', out_dir], root_dir)
|
||||
finally:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
return GN_ERROR_RE.findall(error) if error else []
|
||||
out_dir = tempfile.mkdtemp('gn')
|
||||
try:
|
||||
error = RunGnCommand(['gen', '--check', out_dir], root_dir)
|
||||
finally:
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
return GN_ERROR_RE.findall(error) if error else []
|
||||
|
||||
|
||||
def RunNinjaCommand(args, root_dir=None):
|
||||
"""Runs ninja quietly. Any failure (e.g. clang not found) is
|
||||
"""Runs ninja quietly. Any failure (e.g. clang not found) is
|
||||
silently discarded, since this is unlikely an error in submitted CL."""
|
||||
command = [
|
||||
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja')
|
||||
] + args
|
||||
p = subprocess.Popen(command, cwd=root_dir,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
out, _ = p.communicate()
|
||||
return out
|
||||
command = [os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja')] + args
|
||||
p = subprocess.Popen(command,
|
||||
cwd=root_dir,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
out, _ = p.communicate()
|
||||
return out
|
||||
|
||||
|
||||
def GetClangTidyPath():
|
||||
"""POC/WIP! Use the one we have, even it doesn't match clang's version."""
|
||||
tidy = ('third_party/android_ndk/toolchains/'
|
||||
'llvm/prebuilt/linux-x86_64/bin/clang-tidy')
|
||||
return os.path.join(SRC_DIR, tidy)
|
||||
"""POC/WIP! Use the one we have, even it doesn't match clang's version."""
|
||||
tidy = ('third_party/android_ndk/toolchains/'
|
||||
'llvm/prebuilt/linux-x86_64/bin/clang-tidy')
|
||||
return os.path.join(SRC_DIR, tidy)
|
||||
|
||||
|
||||
def GetCompilationDb(root_dir=None):
|
||||
"""Run ninja compdb tool to get proper flags, defines and include paths."""
|
||||
# The compdb tool expect a rule.
|
||||
commands = json.loads(RunNinjaCommand(['-t', 'compdb', 'cxx'], root_dir))
|
||||
# Turns 'file' field into a key.
|
||||
return {v['file']: v for v in commands}
|
||||
"""Run ninja compdb tool to get proper flags, defines and include paths."""
|
||||
# The compdb tool expect a rule.
|
||||
commands = json.loads(RunNinjaCommand(['-t', 'compdb', 'cxx'], root_dir))
|
||||
# Turns 'file' field into a key.
|
||||
return {v['file']: v for v in commands}
|
||||
|
||||
|
||||
def GetCompilationCommand(filepath, gn_args, work_dir):
|
||||
"""Get the whole command used to compile one cc file.
|
||||
"""Get the whole command used to compile one cc file.
|
||||
Typically, clang++ with flags, defines and include paths.
|
||||
|
||||
Args:
|
||||
@ -104,31 +103,30 @@ def GetCompilationCommand(filepath, gn_args, work_dir):
|
||||
Returns:
|
||||
Command as a list, ready to be consumed by subprocess.Popen.
|
||||
"""
|
||||
gn_errors = RunGnCommand(['gen'] + gn_args + [work_dir])
|
||||
if gn_errors:
|
||||
raise(RuntimeError(
|
||||
'FYI, cannot complete check due to gn error:\n%s\n'
|
||||
'Please open a bug.' % gn_errors))
|
||||
gn_errors = RunGnCommand(['gen'] + gn_args + [work_dir])
|
||||
if gn_errors:
|
||||
raise (RuntimeError('FYI, cannot complete check due to gn error:\n%s\n'
|
||||
'Please open a bug.' % gn_errors))
|
||||
|
||||
# Needed for single file compilation.
|
||||
commands = GetCompilationDb(work_dir)
|
||||
# Needed for single file compilation.
|
||||
commands = GetCompilationDb(work_dir)
|
||||
|
||||
# Path as referenced by ninja.
|
||||
rel_path = os.path.relpath(os.path.abspath(filepath), work_dir)
|
||||
# Path as referenced by ninja.
|
||||
rel_path = os.path.relpath(os.path.abspath(filepath), work_dir)
|
||||
|
||||
# Gather defines, include path and flags (such as -std=c++11).
|
||||
try:
|
||||
compilation_entry = commands[rel_path]
|
||||
except KeyError:
|
||||
raise ValueError('%s: Not found in compilation database.\n'
|
||||
'Please check the path.' % filepath)
|
||||
command = compilation_entry['command'].split()
|
||||
# Gather defines, include path and flags (such as -std=c++11).
|
||||
try:
|
||||
compilation_entry = commands[rel_path]
|
||||
except KeyError:
|
||||
raise ValueError('%s: Not found in compilation database.\n'
|
||||
'Please check the path.' % filepath)
|
||||
command = compilation_entry['command'].split()
|
||||
|
||||
# Remove troublesome flags. May trigger an error otherwise.
|
||||
if '-MMD' in command:
|
||||
command.remove('-MMD')
|
||||
if '-MF' in command:
|
||||
index = command.index('-MF')
|
||||
del command[index:index+2] # Remove filename as well.
|
||||
# Remove troublesome flags. May trigger an error otherwise.
|
||||
if '-MMD' in command:
|
||||
command.remove('-MMD')
|
||||
if '-MF' in command:
|
||||
index = command.index('-MF')
|
||||
del command[index:index + 2] # Remove filename as well.
|
||||
|
||||
return command
|
||||
return command
|
||||
|
||||
@ -14,19 +14,20 @@ import unittest
|
||||
#pylint: disable=relative-import
|
||||
import build_helpers
|
||||
|
||||
|
||||
TESTDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'testdata')
|
||||
|
||||
|
||||
class GnCheckTest(unittest.TestCase):
|
||||
def testCircularDependencyError(self):
|
||||
test_dir = os.path.join(TESTDATA_DIR, 'circular_dependency')
|
||||
expected_errors = ['ERROR Dependency cycle:\n'
|
||||
' //:bar ->\n //:foo ->\n //:bar']
|
||||
self.assertListEqual(expected_errors,
|
||||
build_helpers.RunGnCheck(test_dir))
|
||||
def testCircularDependencyError(self):
|
||||
test_dir = os.path.join(TESTDATA_DIR, 'circular_dependency')
|
||||
expected_errors = [
|
||||
'ERROR Dependency cycle:\n'
|
||||
' //:bar ->\n //:foo ->\n //:bar'
|
||||
]
|
||||
self.assertListEqual(expected_errors,
|
||||
build_helpers.RunGnCheck(test_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -11,12 +11,11 @@ import os
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
# TARGET_RE matches a GN target, and extracts the target name and the contents.
|
||||
TARGET_RE = re.compile(r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
|
||||
r'(?P<target_contents>.*?)'
|
||||
r'(?P=indent)}',
|
||||
re.MULTILINE | re.DOTALL)
|
||||
TARGET_RE = re.compile(
|
||||
r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
|
||||
r'(?P<target_contents>.*?)'
|
||||
r'(?P=indent)}', re.MULTILINE | re.DOTALL)
|
||||
|
||||
# SOURCES_RE matches a block of sources inside a GN target.
|
||||
SOURCES_RE = re.compile(
|
||||
@ -27,27 +26,27 @@ SOURCE_FILE_RE = re.compile(r'.*\"(?P<source_file>.*)\"')
|
||||
|
||||
|
||||
class NoBuildGnFoundError(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class WrongFileTypeError(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def _ReadFile(file_path):
|
||||
"""Returns the content of file_path in a string.
|
||||
"""Returns the content of file_path in a string.
|
||||
|
||||
Args:
|
||||
file_path: the path of the file to read.
|
||||
Returns:
|
||||
A string with the content of the file.
|
||||
"""
|
||||
with open(file_path) as f:
|
||||
return f.read()
|
||||
with open(file_path) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def GetBuildGnPathFromFilePath(file_path, file_exists_check, root_dir_path):
|
||||
"""Returns the BUILD.gn file responsible for file_path.
|
||||
"""Returns the BUILD.gn file responsible for file_path.
|
||||
|
||||
Args:
|
||||
file_path: the absolute path to the .h file to check.
|
||||
@ -59,23 +58,23 @@ def GetBuildGnPathFromFilePath(file_path, file_exists_check, root_dir_path):
|
||||
A string with the absolute path to the BUILD.gn file responsible to include
|
||||
file_path in a target.
|
||||
"""
|
||||
if not file_path.endswith('.h'):
|
||||
raise WrongFileTypeError(
|
||||
'File {} is not an header file (.h)'.format(file_path))
|
||||
candidate_dir = os.path.dirname(file_path)
|
||||
while candidate_dir.startswith(root_dir_path):
|
||||
candidate_build_gn_path = os.path.join(candidate_dir, 'BUILD.gn')
|
||||
if file_exists_check(candidate_build_gn_path):
|
||||
return candidate_build_gn_path
|
||||
else:
|
||||
candidate_dir = os.path.abspath(os.path.join(candidate_dir,
|
||||
os.pardir))
|
||||
raise NoBuildGnFoundError(
|
||||
'No BUILD.gn file found for file: `{}`'.format(file_path))
|
||||
if not file_path.endswith('.h'):
|
||||
raise WrongFileTypeError(
|
||||
'File {} is not an header file (.h)'.format(file_path))
|
||||
candidate_dir = os.path.dirname(file_path)
|
||||
while candidate_dir.startswith(root_dir_path):
|
||||
candidate_build_gn_path = os.path.join(candidate_dir, 'BUILD.gn')
|
||||
if file_exists_check(candidate_build_gn_path):
|
||||
return candidate_build_gn_path
|
||||
else:
|
||||
candidate_dir = os.path.abspath(
|
||||
os.path.join(candidate_dir, os.pardir))
|
||||
raise NoBuildGnFoundError(
|
||||
'No BUILD.gn file found for file: `{}`'.format(file_path))
|
||||
|
||||
|
||||
def IsHeaderInBuildGn(header_path, build_gn_path):
|
||||
"""Returns True if the header is listed in the BUILD.gn file.
|
||||
"""Returns True if the header is listed in the BUILD.gn file.
|
||||
|
||||
Args:
|
||||
header_path: the absolute path to the header to check.
|
||||
@ -86,15 +85,15 @@ def IsHeaderInBuildGn(header_path, build_gn_path):
|
||||
at least one GN target in the BUILD.gn file specified by
|
||||
the argument build_gn_path.
|
||||
"""
|
||||
target_abs_path = os.path.dirname(build_gn_path)
|
||||
build_gn_content = _ReadFile(build_gn_path)
|
||||
headers_in_build_gn = GetHeadersInBuildGnFileSources(build_gn_content,
|
||||
target_abs_path)
|
||||
return header_path in headers_in_build_gn
|
||||
target_abs_path = os.path.dirname(build_gn_path)
|
||||
build_gn_content = _ReadFile(build_gn_path)
|
||||
headers_in_build_gn = GetHeadersInBuildGnFileSources(
|
||||
build_gn_content, target_abs_path)
|
||||
return header_path in headers_in_build_gn
|
||||
|
||||
|
||||
def GetHeadersInBuildGnFileSources(file_content, target_abs_path):
|
||||
"""Returns a set with all the .h files in the file_content.
|
||||
"""Returns a set with all the .h files in the file_content.
|
||||
|
||||
Args:
|
||||
file_content: a string with the content of the BUILD.gn file.
|
||||
@ -105,15 +104,15 @@ def GetHeadersInBuildGnFileSources(file_content, target_abs_path):
|
||||
A set with all the headers (.h file) in the file_content.
|
||||
The set contains absolute paths.
|
||||
"""
|
||||
headers_in_sources = set([])
|
||||
for target_match in TARGET_RE.finditer(file_content):
|
||||
target_contents = target_match.group('target_contents')
|
||||
for sources_match in SOURCES_RE.finditer(target_contents):
|
||||
sources = sources_match.group('sources')
|
||||
for source_file_match in SOURCE_FILE_RE.finditer(sources):
|
||||
source_file = source_file_match.group('source_file')
|
||||
if source_file.endswith('.h'):
|
||||
source_file_tokens = string.split(source_file, '/')
|
||||
headers_in_sources.add(os.path.join(target_abs_path,
|
||||
*source_file_tokens))
|
||||
return headers_in_sources
|
||||
headers_in_sources = set([])
|
||||
for target_match in TARGET_RE.finditer(file_content):
|
||||
target_contents = target_match.group('target_contents')
|
||||
for sources_match in SOURCES_RE.finditer(target_contents):
|
||||
sources = sources_match.group('sources')
|
||||
for source_file_match in SOURCE_FILE_RE.finditer(sources):
|
||||
source_file = source_file_match.group('source_file')
|
||||
if source_file.endswith('.h'):
|
||||
source_file_tokens = string.split(source_file, '/')
|
||||
headers_in_sources.add(
|
||||
os.path.join(target_abs_path, *source_file_tokens))
|
||||
return headers_in_sources
|
||||
|
||||
@ -16,73 +16,67 @@ import check_orphan_headers
|
||||
|
||||
|
||||
def _GetRootBasedOnPlatform():
|
||||
if sys.platform.startswith('win'):
|
||||
return 'C:\\'
|
||||
else:
|
||||
return '/'
|
||||
if sys.platform.startswith('win'):
|
||||
return 'C:\\'
|
||||
else:
|
||||
return '/'
|
||||
|
||||
|
||||
def _GetPath(*path_chunks):
|
||||
return os.path.join(_GetRootBasedOnPlatform(),
|
||||
*path_chunks)
|
||||
return os.path.join(_GetRootBasedOnPlatform(), *path_chunks)
|
||||
|
||||
|
||||
class GetBuildGnPathFromFilePathTest(unittest.TestCase):
|
||||
def testGetBuildGnFromSameDirectory(self):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
expected_build_path = _GetPath('home', 'projects', 'webrtc', 'base',
|
||||
'BUILD.gn')
|
||||
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
|
||||
'base', 'BUILD.gn')
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
self.assertEqual(
|
||||
expected_build_path,
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(
|
||||
file_path, file_exists, src_dir_path))
|
||||
|
||||
def testGetBuildGnFromSameDirectory(self):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
expected_build_path = _GetPath('home', 'projects', 'webrtc', 'base',
|
||||
'BUILD.gn')
|
||||
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
|
||||
'base', 'BUILD.gn')
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
self.assertEqual(
|
||||
expected_build_path,
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
|
||||
file_exists,
|
||||
src_dir_path))
|
||||
def testGetBuildPathFromParentDirectory(self):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
expected_build_path = _GetPath('home', 'projects', 'webrtc',
|
||||
'BUILD.gn')
|
||||
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
|
||||
'BUILD.gn')
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
self.assertEqual(
|
||||
expected_build_path,
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(
|
||||
file_path, file_exists, src_dir_path))
|
||||
|
||||
def testGetBuildPathFromParentDirectory(self):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
expected_build_path = _GetPath('home', 'projects', 'webrtc',
|
||||
'BUILD.gn')
|
||||
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
|
||||
'BUILD.gn')
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
self.assertEqual(
|
||||
expected_build_path,
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
|
||||
file_exists,
|
||||
src_dir_path))
|
||||
def testExceptionIfNoBuildGnFilesAreFound(self):
|
||||
with self.assertRaises(check_orphan_headers.NoBuildGnFoundError):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
file_exists = lambda p: False
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(
|
||||
file_path, file_exists, src_dir_path)
|
||||
|
||||
def testExceptionIfNoBuildGnFilesAreFound(self):
|
||||
with self.assertRaises(check_orphan_headers.NoBuildGnFoundError):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
|
||||
file_exists = lambda p: False
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
|
||||
file_exists,
|
||||
src_dir_path)
|
||||
|
||||
def testExceptionIfFilePathIsNotAnHeader(self):
|
||||
with self.assertRaises(check_orphan_headers.WrongFileTypeError):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.cc')
|
||||
file_exists = lambda p: False
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
|
||||
file_exists,
|
||||
src_dir_path)
|
||||
def testExceptionIfFilePathIsNotAnHeader(self):
|
||||
with self.assertRaises(check_orphan_headers.WrongFileTypeError):
|
||||
file_path = _GetPath('home', 'projects', 'webrtc', 'base',
|
||||
'foo.cc')
|
||||
file_exists = lambda p: False
|
||||
src_dir_path = _GetPath('home', 'projects', 'webrtc')
|
||||
check_orphan_headers.GetBuildGnPathFromFilePath(
|
||||
file_path, file_exists, src_dir_path)
|
||||
|
||||
|
||||
class GetHeadersInBuildGnFileSourcesTest(unittest.TestCase):
|
||||
def testEmptyFileReturnsEmptySet(self):
|
||||
self.assertEqual(
|
||||
set([]),
|
||||
check_orphan_headers.GetHeadersInBuildGnFileSources('', '/a/b'))
|
||||
|
||||
def testEmptyFileReturnsEmptySet(self):
|
||||
self.assertEqual(
|
||||
set([]),
|
||||
check_orphan_headers.GetHeadersInBuildGnFileSources('', '/a/b'))
|
||||
|
||||
def testReturnsSetOfHeadersFromFileContent(self):
|
||||
file_content = """
|
||||
def testReturnsSetOfHeadersFromFileContent(self):
|
||||
file_content = """
|
||||
# Some comments
|
||||
if (is_android) {
|
||||
import("//a/b/c.gni")
|
||||
@ -107,17 +101,17 @@ class GetHeadersInBuildGnFileSourcesTest(unittest.TestCase):
|
||||
sources = ["baz/foo.h"]
|
||||
}
|
||||
"""
|
||||
target_abs_path = _GetPath('a', 'b')
|
||||
self.assertEqual(
|
||||
set([
|
||||
_GetPath('a', 'b', 'foo.h'),
|
||||
_GetPath('a', 'b', 'bar.h'),
|
||||
_GetPath('a', 'b', 'public_foo.h'),
|
||||
_GetPath('a', 'b', 'baz', 'foo.h'),
|
||||
]),
|
||||
check_orphan_headers.GetHeadersInBuildGnFileSources(file_content,
|
||||
target_abs_path))
|
||||
target_abs_path = _GetPath('a', 'b')
|
||||
self.assertEqual(
|
||||
set([
|
||||
_GetPath('a', 'b', 'foo.h'),
|
||||
_GetPath('a', 'b', 'bar.h'),
|
||||
_GetPath('a', 'b', 'public_foo.h'),
|
||||
_GetPath('a', 'b', 'baz', 'foo.h'),
|
||||
]),
|
||||
check_orphan_headers.GetHeadersInBuildGnFileSources(
|
||||
file_content, target_abs_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -14,12 +14,11 @@ import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
# TARGET_RE matches a GN target, and extracts the target name and the contents.
|
||||
TARGET_RE = re.compile(r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
|
||||
r'(?P<target_contents>.*?)'
|
||||
r'(?P=indent)}',
|
||||
re.MULTILINE | re.DOTALL)
|
||||
TARGET_RE = re.compile(
|
||||
r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
|
||||
r'(?P<target_contents>.*?)'
|
||||
r'(?P=indent)}', re.MULTILINE | re.DOTALL)
|
||||
|
||||
# SOURCES_RE matches a block of sources inside a GN target.
|
||||
SOURCES_RE = re.compile(r'sources \+?= \[(?P<sources>.*?)\]',
|
||||
@ -31,96 +30,107 @@ ERROR_MESSAGE = ("{build_file_path} in target '{target_name}':\n"
|
||||
|
||||
|
||||
class PackageBoundaryViolation(
|
||||
collections.namedtuple('PackageBoundaryViolation',
|
||||
'build_file_path target_name source_file subpackage')):
|
||||
def __str__(self):
|
||||
return ERROR_MESSAGE.format(**self._asdict())
|
||||
collections.namedtuple(
|
||||
'PackageBoundaryViolation',
|
||||
'build_file_path target_name source_file subpackage')):
|
||||
def __str__(self):
|
||||
return ERROR_MESSAGE.format(**self._asdict())
|
||||
|
||||
|
||||
def _BuildSubpackagesPattern(packages, query):
|
||||
"""Returns a regular expression that matches source files inside subpackages
|
||||
"""Returns a regular expression that matches source files inside subpackages
|
||||
of the given query."""
|
||||
query += os.path.sep
|
||||
length = len(query)
|
||||
pattern = r'\s*"(?P<source_file>(?P<subpackage>'
|
||||
pattern += '|'.join(re.escape(package[length:].replace(os.path.sep, '/'))
|
||||
for package in packages if package.startswith(query))
|
||||
pattern += r')/[\w\./]*)"'
|
||||
return re.compile(pattern)
|
||||
query += os.path.sep
|
||||
length = len(query)
|
||||
pattern = r'\s*"(?P<source_file>(?P<subpackage>'
|
||||
pattern += '|'.join(
|
||||
re.escape(package[length:].replace(os.path.sep, '/'))
|
||||
for package in packages if package.startswith(query))
|
||||
pattern += r')/[\w\./]*)"'
|
||||
return re.compile(pattern)
|
||||
|
||||
|
||||
def _ReadFileAndPrependLines(file_path):
|
||||
"""Reads the contents of a file."""
|
||||
with open(file_path) as f:
|
||||
return "".join(f.readlines())
|
||||
"""Reads the contents of a file."""
|
||||
with open(file_path) as f:
|
||||
return "".join(f.readlines())
|
||||
|
||||
|
||||
def _CheckBuildFile(build_file_path, packages):
|
||||
"""Iterates over all the targets of the given BUILD.gn file, and verifies that
|
||||
"""Iterates over all the targets of the given BUILD.gn file, and verifies that
|
||||
the source files referenced by it don't belong to any of it's subpackages.
|
||||
Returns an iterator over PackageBoundaryViolations for this package.
|
||||
"""
|
||||
package = os.path.dirname(build_file_path)
|
||||
subpackages_re = _BuildSubpackagesPattern(packages, package)
|
||||
package = os.path.dirname(build_file_path)
|
||||
subpackages_re = _BuildSubpackagesPattern(packages, package)
|
||||
|
||||
build_file_contents = _ReadFileAndPrependLines(build_file_path)
|
||||
for target_match in TARGET_RE.finditer(build_file_contents):
|
||||
target_name = target_match.group('target_name')
|
||||
target_contents = target_match.group('target_contents')
|
||||
for sources_match in SOURCES_RE.finditer(target_contents):
|
||||
sources = sources_match.group('sources')
|
||||
for subpackages_match in subpackages_re.finditer(sources):
|
||||
subpackage = subpackages_match.group('subpackage')
|
||||
source_file = subpackages_match.group('source_file')
|
||||
if subpackage:
|
||||
yield PackageBoundaryViolation(build_file_path,
|
||||
target_name, source_file, subpackage)
|
||||
build_file_contents = _ReadFileAndPrependLines(build_file_path)
|
||||
for target_match in TARGET_RE.finditer(build_file_contents):
|
||||
target_name = target_match.group('target_name')
|
||||
target_contents = target_match.group('target_contents')
|
||||
for sources_match in SOURCES_RE.finditer(target_contents):
|
||||
sources = sources_match.group('sources')
|
||||
for subpackages_match in subpackages_re.finditer(sources):
|
||||
subpackage = subpackages_match.group('subpackage')
|
||||
source_file = subpackages_match.group('source_file')
|
||||
if subpackage:
|
||||
yield PackageBoundaryViolation(build_file_path,
|
||||
target_name, source_file,
|
||||
subpackage)
|
||||
|
||||
|
||||
def CheckPackageBoundaries(root_dir, build_files=None):
|
||||
packages = [root for root, _, files in os.walk(root_dir)
|
||||
if 'BUILD.gn' in files]
|
||||
packages = [
|
||||
root for root, _, files in os.walk(root_dir) if 'BUILD.gn' in files
|
||||
]
|
||||
|
||||
if build_files is not None:
|
||||
if build_files is not None:
|
||||
for build_file_path in build_files:
|
||||
assert build_file_path.startswith(root_dir)
|
||||
else:
|
||||
build_files = [
|
||||
os.path.join(package, 'BUILD.gn') for package in packages
|
||||
]
|
||||
|
||||
messages = []
|
||||
for build_file_path in build_files:
|
||||
assert build_file_path.startswith(root_dir)
|
||||
else:
|
||||
build_files = [os.path.join(package, 'BUILD.gn') for package in packages]
|
||||
|
||||
messages = []
|
||||
for build_file_path in build_files:
|
||||
messages.extend(_CheckBuildFile(build_file_path, packages))
|
||||
return messages
|
||||
messages.extend(_CheckBuildFile(build_file_path, packages))
|
||||
return messages
|
||||
|
||||
|
||||
def main(argv):
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Script that checks package boundary violations in GN '
|
||||
'build files.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Script that checks package boundary violations in GN '
|
||||
'build files.')
|
||||
|
||||
parser.add_argument('root_dir', metavar='ROOT_DIR',
|
||||
help='The root directory that contains all BUILD.gn '
|
||||
'files to be processed.')
|
||||
parser.add_argument('build_files', metavar='BUILD_FILE', nargs='*',
|
||||
help='A list of BUILD.gn files to be processed. If no '
|
||||
'files are given, all BUILD.gn files under ROOT_DIR '
|
||||
'will be processed.')
|
||||
parser.add_argument('--max_messages', type=int, default=None,
|
||||
help='If set, the maximum number of violations to be '
|
||||
'displayed.')
|
||||
parser.add_argument('root_dir',
|
||||
metavar='ROOT_DIR',
|
||||
help='The root directory that contains all BUILD.gn '
|
||||
'files to be processed.')
|
||||
parser.add_argument('build_files',
|
||||
metavar='BUILD_FILE',
|
||||
nargs='*',
|
||||
help='A list of BUILD.gn files to be processed. If no '
|
||||
'files are given, all BUILD.gn files under ROOT_DIR '
|
||||
'will be processed.')
|
||||
parser.add_argument('--max_messages',
|
||||
type=int,
|
||||
default=None,
|
||||
help='If set, the maximum number of violations to be '
|
||||
'displayed.')
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
messages = CheckPackageBoundaries(args.root_dir, args.build_files)
|
||||
messages = messages[:args.max_messages]
|
||||
messages = CheckPackageBoundaries(args.root_dir, args.build_files)
|
||||
messages = messages[:args.max_messages]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if i > 0:
|
||||
print
|
||||
print message
|
||||
for i, message in enumerate(messages):
|
||||
if i > 0:
|
||||
print
|
||||
print message
|
||||
|
||||
return bool(messages)
|
||||
return bool(messages)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
|
||||
@ -15,58 +15,60 @@ import unittest
|
||||
#pylint: disable=relative-import
|
||||
from check_package_boundaries import CheckPackageBoundaries
|
||||
|
||||
|
||||
MSG_FORMAT = 'ERROR:check_package_boundaries.py: Unexpected %s.'
|
||||
TESTDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'testdata')
|
||||
|
||||
|
||||
def ReadPylFile(file_path):
|
||||
with open(file_path) as f:
|
||||
return ast.literal_eval(f.read())
|
||||
with open(file_path) as f:
|
||||
return ast.literal_eval(f.read())
|
||||
|
||||
|
||||
class UnitTest(unittest.TestCase):
|
||||
def _RunTest(self, test_dir, check_all_build_files=False):
|
||||
build_files = [os.path.join(test_dir, 'BUILD.gn')]
|
||||
if check_all_build_files:
|
||||
build_files = None
|
||||
def _RunTest(self, test_dir, check_all_build_files=False):
|
||||
build_files = [os.path.join(test_dir, 'BUILD.gn')]
|
||||
if check_all_build_files:
|
||||
build_files = None
|
||||
|
||||
messages = []
|
||||
for violation in CheckPackageBoundaries(test_dir, build_files):
|
||||
build_file_path = os.path.relpath(violation.build_file_path, test_dir)
|
||||
build_file_path = build_file_path.replace(os.path.sep, '/')
|
||||
messages.append(violation._replace(build_file_path=build_file_path))
|
||||
messages = []
|
||||
for violation in CheckPackageBoundaries(test_dir, build_files):
|
||||
build_file_path = os.path.relpath(violation.build_file_path,
|
||||
test_dir)
|
||||
build_file_path = build_file_path.replace(os.path.sep, '/')
|
||||
messages.append(
|
||||
violation._replace(build_file_path=build_file_path))
|
||||
|
||||
expected_messages = ReadPylFile(os.path.join(test_dir, 'expected.pyl'))
|
||||
self.assertListEqual(sorted(expected_messages), sorted(messages))
|
||||
expected_messages = ReadPylFile(os.path.join(test_dir, 'expected.pyl'))
|
||||
self.assertListEqual(sorted(expected_messages), sorted(messages))
|
||||
|
||||
def testNoErrors(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'no_errors'))
|
||||
def testNoErrors(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'no_errors'))
|
||||
|
||||
def testMultipleErrorsSingleTarget(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'multiple_errors_single_target'))
|
||||
def testMultipleErrorsSingleTarget(self):
|
||||
self._RunTest(
|
||||
os.path.join(TESTDATA_DIR, 'multiple_errors_single_target'))
|
||||
|
||||
def testMultipleErrorsMultipleTargets(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR,
|
||||
'multiple_errors_multiple_targets'))
|
||||
def testMultipleErrorsMultipleTargets(self):
|
||||
self._RunTest(
|
||||
os.path.join(TESTDATA_DIR, 'multiple_errors_multiple_targets'))
|
||||
|
||||
def testCommonPrefix(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'common_prefix'))
|
||||
def testCommonPrefix(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'common_prefix'))
|
||||
|
||||
def testAllBuildFiles(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'all_build_files'), True)
|
||||
def testAllBuildFiles(self):
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'all_build_files'), True)
|
||||
|
||||
def testSanitizeFilename(self):
|
||||
# The `dangerous_filename` test case contains a directory with '++' in its
|
||||
# name. If it's not properly escaped, a regex error would be raised.
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'dangerous_filename'), True)
|
||||
def testSanitizeFilename(self):
|
||||
# The `dangerous_filename` test case contains a directory with '++' in its
|
||||
# name. If it's not properly escaped, a regex error would be raised.
|
||||
self._RunTest(os.path.join(TESTDATA_DIR, 'dangerous_filename'), True)
|
||||
|
||||
def testRelativeFilename(self):
|
||||
test_dir = os.path.join(TESTDATA_DIR, 'all_build_files')
|
||||
with self.assertRaises(AssertionError):
|
||||
CheckPackageBoundaries(test_dir, ["BUILD.gn"])
|
||||
def testRelativeFilename(self):
|
||||
test_dir = os.path.join(TESTDATA_DIR, 'all_build_files')
|
||||
with self.assertRaises(AssertionError):
|
||||
CheckPackageBoundaries(test_dir, ["BUILD.gn"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@ -6,8 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
|
||||
"""This is a tool to transform a crt file into a C/C++ header.
|
||||
|
||||
Usage:
|
||||
@ -41,172 +39,180 @@ _VERBOSE = 'verbose'
|
||||
|
||||
|
||||
def main():
|
||||
"""The main entrypoint."""
|
||||
parser = OptionParser('usage %prog FILE')
|
||||
parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
|
||||
parser.add_option('-f', '--full_cert', dest='full_cert', action='store_true')
|
||||
options, args = parser.parse_args()
|
||||
if len(args) < 1:
|
||||
parser.error('No crt file specified.')
|
||||
return
|
||||
root_dir = _SplitCrt(args[0], options)
|
||||
_GenCFiles(root_dir, options)
|
||||
_Cleanup(root_dir)
|
||||
"""The main entrypoint."""
|
||||
parser = OptionParser('usage %prog FILE')
|
||||
parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
|
||||
parser.add_option('-f',
|
||||
'--full_cert',
|
||||
dest='full_cert',
|
||||
action='store_true')
|
||||
options, args = parser.parse_args()
|
||||
if len(args) < 1:
|
||||
parser.error('No crt file specified.')
|
||||
return
|
||||
root_dir = _SplitCrt(args[0], options)
|
||||
_GenCFiles(root_dir, options)
|
||||
_Cleanup(root_dir)
|
||||
|
||||
|
||||
def _SplitCrt(source_file, options):
|
||||
sub_file_blocks = []
|
||||
label_name = ''
|
||||
root_dir = os.path.dirname(os.path.abspath(source_file)) + '/'
|
||||
_PrintOutput(root_dir, options)
|
||||
f = open(source_file)
|
||||
for line in f:
|
||||
if line.startswith('# Label: '):
|
||||
sub_file_blocks.append(line)
|
||||
label = re.search(r'\".*\"', line)
|
||||
temp_label = label.group(0)
|
||||
end = len(temp_label)-1
|
||||
label_name = _SafeName(temp_label[1:end])
|
||||
elif line.startswith('-----END CERTIFICATE-----'):
|
||||
sub_file_blocks.append(line)
|
||||
new_file_name = root_dir + _PREFIX + label_name + _EXTENSION
|
||||
_PrintOutput('Generating: ' + new_file_name, options)
|
||||
new_file = open(new_file_name, 'w')
|
||||
for out_line in sub_file_blocks:
|
||||
new_file.write(out_line)
|
||||
new_file.close()
|
||||
sub_file_blocks = []
|
||||
else:
|
||||
sub_file_blocks.append(line)
|
||||
f.close()
|
||||
return root_dir
|
||||
sub_file_blocks = []
|
||||
label_name = ''
|
||||
root_dir = os.path.dirname(os.path.abspath(source_file)) + '/'
|
||||
_PrintOutput(root_dir, options)
|
||||
f = open(source_file)
|
||||
for line in f:
|
||||
if line.startswith('# Label: '):
|
||||
sub_file_blocks.append(line)
|
||||
label = re.search(r'\".*\"', line)
|
||||
temp_label = label.group(0)
|
||||
end = len(temp_label) - 1
|
||||
label_name = _SafeName(temp_label[1:end])
|
||||
elif line.startswith('-----END CERTIFICATE-----'):
|
||||
sub_file_blocks.append(line)
|
||||
new_file_name = root_dir + _PREFIX + label_name + _EXTENSION
|
||||
_PrintOutput('Generating: ' + new_file_name, options)
|
||||
new_file = open(new_file_name, 'w')
|
||||
for out_line in sub_file_blocks:
|
||||
new_file.write(out_line)
|
||||
new_file.close()
|
||||
sub_file_blocks = []
|
||||
else:
|
||||
sub_file_blocks.append(line)
|
||||
f.close()
|
||||
return root_dir
|
||||
|
||||
|
||||
def _GenCFiles(root_dir, options):
|
||||
output_header_file = open(root_dir + _GENERATED_FILE, 'w')
|
||||
output_header_file.write(_CreateOutputHeader())
|
||||
if options.full_cert:
|
||||
subject_name_list = _CreateArraySectionHeader(_SUBJECT_NAME_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
public_key_list = _CreateArraySectionHeader(_PUBLIC_KEY_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
certificate_list = _CreateArraySectionHeader(_CERTIFICATE_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
certificate_size_list = _CreateArraySectionHeader(_CERTIFICATE_SIZE_VARIABLE,
|
||||
_INT_TYPE, options)
|
||||
output_header_file = open(root_dir + _GENERATED_FILE, 'w')
|
||||
output_header_file.write(_CreateOutputHeader())
|
||||
if options.full_cert:
|
||||
subject_name_list = _CreateArraySectionHeader(_SUBJECT_NAME_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
public_key_list = _CreateArraySectionHeader(_PUBLIC_KEY_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
certificate_list = _CreateArraySectionHeader(_CERTIFICATE_VARIABLE,
|
||||
_CHAR_TYPE, options)
|
||||
certificate_size_list = _CreateArraySectionHeader(
|
||||
_CERTIFICATE_SIZE_VARIABLE, _INT_TYPE, options)
|
||||
|
||||
for _, _, files in os.walk(root_dir):
|
||||
for current_file in files:
|
||||
if current_file.startswith(_PREFIX):
|
||||
prefix_length = len(_PREFIX)
|
||||
length = len(current_file) - len(_EXTENSION)
|
||||
label = current_file[prefix_length:length]
|
||||
filtered_output, cert_size = _CreateCertSection(root_dir, current_file,
|
||||
label, options)
|
||||
output_header_file.write(filtered_output + '\n\n\n')
|
||||
if options.full_cert:
|
||||
subject_name_list += _AddLabelToArray(label, _SUBJECT_NAME_ARRAY)
|
||||
public_key_list += _AddLabelToArray(label, _PUBLIC_KEY_ARRAY)
|
||||
certificate_list += _AddLabelToArray(label, _CERTIFICATE_ARRAY)
|
||||
certificate_size_list += (' %s,\n') %(cert_size)
|
||||
for _, _, files in os.walk(root_dir):
|
||||
for current_file in files:
|
||||
if current_file.startswith(_PREFIX):
|
||||
prefix_length = len(_PREFIX)
|
||||
length = len(current_file) - len(_EXTENSION)
|
||||
label = current_file[prefix_length:length]
|
||||
filtered_output, cert_size = _CreateCertSection(
|
||||
root_dir, current_file, label, options)
|
||||
output_header_file.write(filtered_output + '\n\n\n')
|
||||
if options.full_cert:
|
||||
subject_name_list += _AddLabelToArray(
|
||||
label, _SUBJECT_NAME_ARRAY)
|
||||
public_key_list += _AddLabelToArray(
|
||||
label, _PUBLIC_KEY_ARRAY)
|
||||
certificate_list += _AddLabelToArray(label, _CERTIFICATE_ARRAY)
|
||||
certificate_size_list += (' %s,\n') % (cert_size)
|
||||
|
||||
if options.full_cert:
|
||||
subject_name_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(subject_name_list)
|
||||
public_key_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(public_key_list)
|
||||
certificate_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(certificate_list)
|
||||
certificate_size_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(certificate_size_list)
|
||||
output_header_file.write(_CreateOutputFooter())
|
||||
output_header_file.close()
|
||||
if options.full_cert:
|
||||
subject_name_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(subject_name_list)
|
||||
public_key_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(public_key_list)
|
||||
certificate_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(certificate_list)
|
||||
certificate_size_list += _CreateArraySectionFooter()
|
||||
output_header_file.write(certificate_size_list)
|
||||
output_header_file.write(_CreateOutputFooter())
|
||||
output_header_file.close()
|
||||
|
||||
|
||||
def _Cleanup(root_dir):
|
||||
for f in os.listdir(root_dir):
|
||||
if f.startswith(_PREFIX):
|
||||
os.remove(root_dir + f)
|
||||
for f in os.listdir(root_dir):
|
||||
if f.startswith(_PREFIX):
|
||||
os.remove(root_dir + f)
|
||||
|
||||
|
||||
def _CreateCertSection(root_dir, source_file, label, options):
|
||||
command = 'openssl x509 -in %s%s -noout -C' %(root_dir, source_file)
|
||||
_PrintOutput(command, options)
|
||||
output = commands.getstatusoutput(command)[1]
|
||||
renamed_output = output.replace('unsigned char XXX_',
|
||||
'const unsigned char ' + label + '_')
|
||||
filtered_output = ''
|
||||
cert_block = '^const unsigned char.*?};$'
|
||||
prog = re.compile(cert_block, re.IGNORECASE | re.MULTILINE | re.DOTALL)
|
||||
if not options.full_cert:
|
||||
filtered_output = prog.sub('', renamed_output, count=2)
|
||||
else:
|
||||
filtered_output = renamed_output
|
||||
command = 'openssl x509 -in %s%s -noout -C' % (root_dir, source_file)
|
||||
_PrintOutput(command, options)
|
||||
output = commands.getstatusoutput(command)[1]
|
||||
renamed_output = output.replace('unsigned char XXX_',
|
||||
'const unsigned char ' + label + '_')
|
||||
filtered_output = ''
|
||||
cert_block = '^const unsigned char.*?};$'
|
||||
prog = re.compile(cert_block, re.IGNORECASE | re.MULTILINE | re.DOTALL)
|
||||
if not options.full_cert:
|
||||
filtered_output = prog.sub('', renamed_output, count=2)
|
||||
else:
|
||||
filtered_output = renamed_output
|
||||
|
||||
cert_size_block = r'\d\d\d+'
|
||||
prog2 = re.compile(cert_size_block, re.MULTILINE | re.VERBOSE)
|
||||
result = prog2.findall(renamed_output)
|
||||
cert_size = result[len(result) - 1]
|
||||
cert_size_block = r'\d\d\d+'
|
||||
prog2 = re.compile(cert_size_block, re.MULTILINE | re.VERBOSE)
|
||||
result = prog2.findall(renamed_output)
|
||||
cert_size = result[len(result) - 1]
|
||||
|
||||
return filtered_output, cert_size
|
||||
return filtered_output, cert_size
|
||||
|
||||
|
||||
def _CreateOutputHeader():
|
||||
output = ('/*\n'
|
||||
' * Copyright 2004 The WebRTC Project Authors. All rights '
|
||||
'reserved.\n'
|
||||
' *\n'
|
||||
' * Use of this source code is governed by a BSD-style license\n'
|
||||
' * that can be found in the LICENSE file in the root of the '
|
||||
'source\n'
|
||||
' * tree. An additional intellectual property rights grant can be '
|
||||
'found\n'
|
||||
' * in the file PATENTS. All contributing project authors may\n'
|
||||
' * be found in the AUTHORS file in the root of the source tree.\n'
|
||||
' */\n\n'
|
||||
'#ifndef RTC_BASE_SSL_ROOTS_H_\n'
|
||||
'#define RTC_BASE_SSL_ROOTS_H_\n\n'
|
||||
'// This file is the root certificates in C form that are needed to'
|
||||
' connect to\n// Google.\n\n'
|
||||
'// It was generated with the following command line:\n'
|
||||
'// > python tools_webrtc/sslroots/generate_sslroots.py'
|
||||
'\n// https://pki.goog/roots.pem\n\n'
|
||||
'// clang-format off\n'
|
||||
'// Don\'t bother formatting generated code,\n'
|
||||
'// also it would breaks subject/issuer lines.\n\n')
|
||||
return output
|
||||
output = (
|
||||
'/*\n'
|
||||
' * Copyright 2004 The WebRTC Project Authors. All rights '
|
||||
'reserved.\n'
|
||||
' *\n'
|
||||
' * Use of this source code is governed by a BSD-style license\n'
|
||||
' * that can be found in the LICENSE file in the root of the '
|
||||
'source\n'
|
||||
' * tree. An additional intellectual property rights grant can be '
|
||||
'found\n'
|
||||
' * in the file PATENTS. All contributing project authors may\n'
|
||||
' * be found in the AUTHORS file in the root of the source tree.\n'
|
||||
' */\n\n'
|
||||
'#ifndef RTC_BASE_SSL_ROOTS_H_\n'
|
||||
'#define RTC_BASE_SSL_ROOTS_H_\n\n'
|
||||
'// This file is the root certificates in C form that are needed to'
|
||||
' connect to\n// Google.\n\n'
|
||||
'// It was generated with the following command line:\n'
|
||||
'// > python tools_webrtc/sslroots/generate_sslroots.py'
|
||||
'\n// https://pki.goog/roots.pem\n\n'
|
||||
'// clang-format off\n'
|
||||
'// Don\'t bother formatting generated code,\n'
|
||||
'// also it would breaks subject/issuer lines.\n\n')
|
||||
return output
|
||||
|
||||
|
||||
def _CreateOutputFooter():
|
||||
output = ('// clang-format on\n\n'
|
||||
'#endif // RTC_BASE_SSL_ROOTS_H_\n')
|
||||
return output
|
||||
output = ('// clang-format on\n\n' '#endif // RTC_BASE_SSL_ROOTS_H_\n')
|
||||
return output
|
||||
|
||||
|
||||
def _CreateArraySectionHeader(type_name, type_type, options):
|
||||
output = ('const %s kSSLCert%sList[] = {\n') %(type_type, type_name)
|
||||
_PrintOutput(output, options)
|
||||
return output
|
||||
output = ('const %s kSSLCert%sList[] = {\n') % (type_type, type_name)
|
||||
_PrintOutput(output, options)
|
||||
return output
|
||||
|
||||
|
||||
def _AddLabelToArray(label, type_name):
|
||||
return ' %s_%s,\n' %(label, type_name)
|
||||
return ' %s_%s,\n' % (label, type_name)
|
||||
|
||||
|
||||
def _CreateArraySectionFooter():
|
||||
return '};\n\n'
|
||||
return '};\n\n'
|
||||
|
||||
|
||||
def _SafeName(original_file_name):
|
||||
bad_chars = ' -./\\()áéíőú'
|
||||
replacement_chars = ''
|
||||
for _ in bad_chars:
|
||||
replacement_chars += '_'
|
||||
translation_table = string.maketrans(bad_chars, replacement_chars)
|
||||
return original_file_name.translate(translation_table)
|
||||
bad_chars = ' -./\\()áéíőú'
|
||||
replacement_chars = ''
|
||||
for _ in bad_chars:
|
||||
replacement_chars += '_'
|
||||
translation_table = string.maketrans(bad_chars, replacement_chars)
|
||||
return original_file_name.translate(translation_table)
|
||||
|
||||
|
||||
def _PrintOutput(output, options):
|
||||
if options.verbose:
|
||||
print output
|
||||
if options.verbose:
|
||||
print output
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -53,7 +53,6 @@
|
||||
#
|
||||
# * This has only been tested on gPrecise.
|
||||
|
||||
|
||||
import os
|
||||
import os.path
|
||||
import shlex
|
||||
@ -62,25 +61,26 @@ import sys
|
||||
|
||||
# Flags from YCM's default config.
|
||||
_DEFAULT_FLAGS = [
|
||||
'-DUSE_CLANG_COMPLETER',
|
||||
'-std=c++11',
|
||||
'-x',
|
||||
'c++',
|
||||
'-DUSE_CLANG_COMPLETER',
|
||||
'-std=c++11',
|
||||
'-x',
|
||||
'c++',
|
||||
]
|
||||
|
||||
_HEADER_ALTERNATES = ('.cc', '.cpp', '.c', '.mm', '.m')
|
||||
|
||||
_EXTENSION_FLAGS = {
|
||||
'.m': ['-x', 'objective-c'],
|
||||
'.mm': ['-x', 'objective-c++'],
|
||||
'.m': ['-x', 'objective-c'],
|
||||
'.mm': ['-x', 'objective-c++'],
|
||||
}
|
||||
|
||||
|
||||
def PathExists(*args):
|
||||
return os.path.exists(os.path.join(*args))
|
||||
return os.path.exists(os.path.join(*args))
|
||||
|
||||
|
||||
def FindWebrtcSrcFromFilename(filename):
|
||||
"""Searches for the root of the WebRTC checkout.
|
||||
"""Searches for the root of the WebRTC checkout.
|
||||
|
||||
Simply checks parent directories until it finds .gclient and src/.
|
||||
|
||||
@ -90,20 +90,20 @@ def FindWebrtcSrcFromFilename(filename):
|
||||
Returns:
|
||||
(String) Path of 'src/', or None if unable to find.
|
||||
"""
|
||||
curdir = os.path.normpath(os.path.dirname(filename))
|
||||
while not (os.path.basename(curdir) == 'src'
|
||||
and PathExists(curdir, 'DEPS')
|
||||
and (PathExists(curdir, '..', '.gclient')
|
||||
or PathExists(curdir, '.git'))):
|
||||
nextdir = os.path.normpath(os.path.join(curdir, '..'))
|
||||
if nextdir == curdir:
|
||||
return None
|
||||
curdir = nextdir
|
||||
return curdir
|
||||
curdir = os.path.normpath(os.path.dirname(filename))
|
||||
while not (os.path.basename(curdir) == 'src'
|
||||
and PathExists(curdir, 'DEPS') and
|
||||
(PathExists(curdir, '..', '.gclient')
|
||||
or PathExists(curdir, '.git'))):
|
||||
nextdir = os.path.normpath(os.path.join(curdir, '..'))
|
||||
if nextdir == curdir:
|
||||
return None
|
||||
curdir = nextdir
|
||||
return curdir
|
||||
|
||||
|
||||
def GetDefaultSourceFile(webrtc_root, filename):
|
||||
"""Returns the default source file to use as an alternative to |filename|.
|
||||
"""Returns the default source file to use as an alternative to |filename|.
|
||||
|
||||
Compile flags used to build the default source file is assumed to be a
|
||||
close-enough approximation for building |filename|.
|
||||
@ -115,13 +115,13 @@ def GetDefaultSourceFile(webrtc_root, filename):
|
||||
Returns:
|
||||
(String) Absolute path to substitute source file.
|
||||
"""
|
||||
if 'test.' in filename:
|
||||
return os.path.join(webrtc_root, 'base', 'logging_unittest.cc')
|
||||
return os.path.join(webrtc_root, 'base', 'logging.cc')
|
||||
if 'test.' in filename:
|
||||
return os.path.join(webrtc_root, 'base', 'logging_unittest.cc')
|
||||
return os.path.join(webrtc_root, 'base', 'logging.cc')
|
||||
|
||||
|
||||
def GetNinjaBuildOutputsForSourceFile(out_dir, filename):
|
||||
"""Returns a list of build outputs for filename.
|
||||
"""Returns a list of build outputs for filename.
|
||||
|
||||
The list is generated by invoking 'ninja -t query' tool to retrieve a list of
|
||||
inputs and outputs of |filename|. This list is then filtered to only include
|
||||
@ -135,32 +135,35 @@ def GetNinjaBuildOutputsForSourceFile(out_dir, filename):
|
||||
(List of Strings) List of target names. Will return [] if |filename| doesn't
|
||||
yield any .o or .obj outputs.
|
||||
"""
|
||||
# Ninja needs the path to the source file relative to the output build
|
||||
# directory.
|
||||
rel_filename = os.path.relpath(filename, out_dir)
|
||||
# Ninja needs the path to the source file relative to the output build
|
||||
# directory.
|
||||
rel_filename = os.path.relpath(filename, out_dir)
|
||||
|
||||
p = subprocess.Popen(['ninja', '-C', out_dir, '-t', 'query', rel_filename],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
universal_newlines=True)
|
||||
stdout, _ = p.communicate()
|
||||
if p.returncode != 0:
|
||||
return []
|
||||
p = subprocess.Popen(['ninja', '-C', out_dir, '-t', 'query', rel_filename],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True)
|
||||
stdout, _ = p.communicate()
|
||||
if p.returncode != 0:
|
||||
return []
|
||||
|
||||
# The output looks like:
|
||||
# ../../relative/path/to/source.cc:
|
||||
# outputs:
|
||||
# obj/reative/path/to/target.source.o
|
||||
# obj/some/other/target2.source.o
|
||||
# another/target.txt
|
||||
#
|
||||
outputs_text = stdout.partition('\n outputs:\n')[2]
|
||||
output_lines = [line.strip() for line in outputs_text.split('\n')]
|
||||
return [target for target in output_lines
|
||||
if target and (target.endswith('.o') or target.endswith('.obj'))]
|
||||
# The output looks like:
|
||||
# ../../relative/path/to/source.cc:
|
||||
# outputs:
|
||||
# obj/reative/path/to/target.source.o
|
||||
# obj/some/other/target2.source.o
|
||||
# another/target.txt
|
||||
#
|
||||
outputs_text = stdout.partition('\n outputs:\n')[2]
|
||||
output_lines = [line.strip() for line in outputs_text.split('\n')]
|
||||
return [
|
||||
target for target in output_lines
|
||||
if target and (target.endswith('.o') or target.endswith('.obj'))
|
||||
]
|
||||
|
||||
|
||||
def GetClangCommandLineForNinjaOutput(out_dir, build_target):
|
||||
"""Returns the Clang command line for building |build_target|
|
||||
"""Returns the Clang command line for building |build_target|
|
||||
|
||||
Asks ninja for the list of commands used to build |filename| and returns the
|
||||
final Clang invocation.
|
||||
@ -173,24 +176,25 @@ def GetClangCommandLineForNinjaOutput(out_dir, build_target):
|
||||
(String or None) Clang command line or None if a Clang command line couldn't
|
||||
be determined.
|
||||
"""
|
||||
p = subprocess.Popen(['ninja', '-v', '-C', out_dir,
|
||||
'-t', 'commands', build_target],
|
||||
stdout=subprocess.PIPE, universal_newlines=True)
|
||||
stdout, _ = p.communicate()
|
||||
if p.returncode != 0:
|
||||
return None
|
||||
p = subprocess.Popen(
|
||||
['ninja', '-v', '-C', out_dir, '-t', 'commands', build_target],
|
||||
stdout=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
stdout, _ = p.communicate()
|
||||
if p.returncode != 0:
|
||||
return None
|
||||
|
||||
# Ninja will return multiple build steps for all dependencies up to
|
||||
# |build_target|. The build step we want is the last Clang invocation, which
|
||||
# is expected to be the one that outputs |build_target|.
|
||||
for line in reversed(stdout.split('\n')):
|
||||
if 'clang' in line:
|
||||
return line
|
||||
return None
|
||||
# Ninja will return multiple build steps for all dependencies up to
|
||||
# |build_target|. The build step we want is the last Clang invocation, which
|
||||
# is expected to be the one that outputs |build_target|.
|
||||
for line in reversed(stdout.split('\n')):
|
||||
if 'clang' in line:
|
||||
return line
|
||||
return None
|
||||
|
||||
|
||||
def GetClangCommandLineFromNinjaForSource(out_dir, filename):
|
||||
"""Returns a Clang command line used to build |filename|.
|
||||
"""Returns a Clang command line used to build |filename|.
|
||||
|
||||
The same source file could be built multiple times using different tool
|
||||
chains. In such cases, this command returns the first Clang invocation. We
|
||||
@ -206,17 +210,17 @@ def GetClangCommandLineFromNinjaForSource(out_dir, filename):
|
||||
(String or None): Command line for Clang invocation using |filename| as a
|
||||
source. Returns None if no such command line could be found.
|
||||
"""
|
||||
build_targets = GetNinjaBuildOutputsForSourceFile(out_dir, filename)
|
||||
for build_target in build_targets:
|
||||
command_line = GetClangCommandLineForNinjaOutput(out_dir, build_target)
|
||||
if command_line:
|
||||
return command_line
|
||||
return None
|
||||
build_targets = GetNinjaBuildOutputsForSourceFile(out_dir, filename)
|
||||
for build_target in build_targets:
|
||||
command_line = GetClangCommandLineForNinjaOutput(out_dir, build_target)
|
||||
if command_line:
|
||||
return command_line
|
||||
return None
|
||||
|
||||
|
||||
def GetClangOptionsFromCommandLine(clang_commandline, out_dir,
|
||||
additional_flags):
|
||||
"""Extracts relevant command line options from |clang_commandline|
|
||||
"""Extracts relevant command line options from |clang_commandline|
|
||||
|
||||
Args:
|
||||
clang_commandline: (String) Full Clang invocation.
|
||||
@ -228,46 +232,47 @@ def GetClangOptionsFromCommandLine(clang_commandline, out_dir,
|
||||
(List of Strings) The list of command line flags for this source file. Can
|
||||
be empty.
|
||||
"""
|
||||
clang_flags = [] + additional_flags
|
||||
clang_flags = [] + additional_flags
|
||||
|
||||
# Parse flags that are important for YCM's purposes.
|
||||
clang_tokens = shlex.split(clang_commandline)
|
||||
for flag_index, flag in enumerate(clang_tokens):
|
||||
if flag.startswith('-I'):
|
||||
# Relative paths need to be resolved, because they're relative to the
|
||||
# output dir, not the source.
|
||||
if flag[2] == '/':
|
||||
clang_flags.append(flag)
|
||||
else:
|
||||
abs_path = os.path.normpath(os.path.join(out_dir, flag[2:]))
|
||||
clang_flags.append('-I' + abs_path)
|
||||
elif flag.startswith('-std'):
|
||||
clang_flags.append(flag)
|
||||
elif flag.startswith('-') and flag[1] in 'DWFfmO':
|
||||
if flag == '-Wno-deprecated-register' or flag == '-Wno-header-guard':
|
||||
# These flags causes libclang (3.3) to crash. Remove it until things
|
||||
# are fixed.
|
||||
continue
|
||||
clang_flags.append(flag)
|
||||
elif flag == '-isysroot':
|
||||
# On Mac -isysroot <path> is used to find the system headers.
|
||||
# Copy over both flags.
|
||||
if flag_index + 1 < len(clang_tokens):
|
||||
clang_flags.append(flag)
|
||||
clang_flags.append(clang_tokens[flag_index + 1])
|
||||
elif flag.startswith('--sysroot='):
|
||||
# On Linux we use a sysroot image.
|
||||
sysroot_path = flag.lstrip('--sysroot=')
|
||||
if sysroot_path.startswith('/'):
|
||||
clang_flags.append(flag)
|
||||
else:
|
||||
abs_path = os.path.normpath(os.path.join(out_dir, sysroot_path))
|
||||
clang_flags.append('--sysroot=' + abs_path)
|
||||
return clang_flags
|
||||
# Parse flags that are important for YCM's purposes.
|
||||
clang_tokens = shlex.split(clang_commandline)
|
||||
for flag_index, flag in enumerate(clang_tokens):
|
||||
if flag.startswith('-I'):
|
||||
# Relative paths need to be resolved, because they're relative to the
|
||||
# output dir, not the source.
|
||||
if flag[2] == '/':
|
||||
clang_flags.append(flag)
|
||||
else:
|
||||
abs_path = os.path.normpath(os.path.join(out_dir, flag[2:]))
|
||||
clang_flags.append('-I' + abs_path)
|
||||
elif flag.startswith('-std'):
|
||||
clang_flags.append(flag)
|
||||
elif flag.startswith('-') and flag[1] in 'DWFfmO':
|
||||
if flag == '-Wno-deprecated-register' or flag == '-Wno-header-guard':
|
||||
# These flags causes libclang (3.3) to crash. Remove it until things
|
||||
# are fixed.
|
||||
continue
|
||||
clang_flags.append(flag)
|
||||
elif flag == '-isysroot':
|
||||
# On Mac -isysroot <path> is used to find the system headers.
|
||||
# Copy over both flags.
|
||||
if flag_index + 1 < len(clang_tokens):
|
||||
clang_flags.append(flag)
|
||||
clang_flags.append(clang_tokens[flag_index + 1])
|
||||
elif flag.startswith('--sysroot='):
|
||||
# On Linux we use a sysroot image.
|
||||
sysroot_path = flag.lstrip('--sysroot=')
|
||||
if sysroot_path.startswith('/'):
|
||||
clang_flags.append(flag)
|
||||
else:
|
||||
abs_path = os.path.normpath(os.path.join(
|
||||
out_dir, sysroot_path))
|
||||
clang_flags.append('--sysroot=' + abs_path)
|
||||
return clang_flags
|
||||
|
||||
|
||||
def GetClangOptionsFromNinjaForFilename(webrtc_root, filename):
|
||||
"""Returns the Clang command line options needed for building |filename|.
|
||||
"""Returns the Clang command line options needed for building |filename|.
|
||||
|
||||
Command line options are based on the command used by ninja for building
|
||||
|filename|. If |filename| is a .h file, uses its companion .cc or .cpp file.
|
||||
@ -283,54 +288,55 @@ def GetClangOptionsFromNinjaForFilename(webrtc_root, filename):
|
||||
(List of Strings) The list of command line flags for this source file. Can
|
||||
be empty.
|
||||
"""
|
||||
if not webrtc_root:
|
||||
return []
|
||||
if not webrtc_root:
|
||||
return []
|
||||
|
||||
# Generally, everyone benefits from including WebRTC's src/, because all of
|
||||
# WebRTC's includes are relative to that.
|
||||
additional_flags = ['-I' + os.path.join(webrtc_root)]
|
||||
# Generally, everyone benefits from including WebRTC's src/, because all of
|
||||
# WebRTC's includes are relative to that.
|
||||
additional_flags = ['-I' + os.path.join(webrtc_root)]
|
||||
|
||||
# Version of Clang used to compile WebRTC can be newer then version of
|
||||
# libclang that YCM uses for completion. So it's possible that YCM's libclang
|
||||
# doesn't know about some used warning options, which causes compilation
|
||||
# warnings (and errors, because of '-Werror');
|
||||
additional_flags.append('-Wno-unknown-warning-option')
|
||||
# Version of Clang used to compile WebRTC can be newer then version of
|
||||
# libclang that YCM uses for completion. So it's possible that YCM's libclang
|
||||
# doesn't know about some used warning options, which causes compilation
|
||||
# warnings (and errors, because of '-Werror');
|
||||
additional_flags.append('-Wno-unknown-warning-option')
|
||||
|
||||
sys.path.append(os.path.join(webrtc_root, 'tools', 'vim'))
|
||||
from ninja_output import GetNinjaOutputDirectory
|
||||
out_dir = GetNinjaOutputDirectory(webrtc_root)
|
||||
sys.path.append(os.path.join(webrtc_root, 'tools', 'vim'))
|
||||
from ninja_output import GetNinjaOutputDirectory
|
||||
out_dir = GetNinjaOutputDirectory(webrtc_root)
|
||||
|
||||
basename, extension = os.path.splitext(filename)
|
||||
if extension == '.h':
|
||||
candidates = [basename + ext for ext in _HEADER_ALTERNATES]
|
||||
else:
|
||||
candidates = [filename]
|
||||
basename, extension = os.path.splitext(filename)
|
||||
if extension == '.h':
|
||||
candidates = [basename + ext for ext in _HEADER_ALTERNATES]
|
||||
else:
|
||||
candidates = [filename]
|
||||
|
||||
clang_line = None
|
||||
buildable_extension = extension
|
||||
for candidate in candidates:
|
||||
clang_line = GetClangCommandLineFromNinjaForSource(out_dir, candidate)
|
||||
if clang_line:
|
||||
buildable_extension = os.path.splitext(candidate)[1]
|
||||
break
|
||||
clang_line = None
|
||||
buildable_extension = extension
|
||||
for candidate in candidates:
|
||||
clang_line = GetClangCommandLineFromNinjaForSource(out_dir, candidate)
|
||||
if clang_line:
|
||||
buildable_extension = os.path.splitext(candidate)[1]
|
||||
break
|
||||
|
||||
additional_flags += _EXTENSION_FLAGS.get(buildable_extension, [])
|
||||
additional_flags += _EXTENSION_FLAGS.get(buildable_extension, [])
|
||||
|
||||
if not clang_line:
|
||||
# If ninja didn't know about filename or it's companion files, then try a
|
||||
# default build target. It is possible that the file is new, or build.ninja
|
||||
# is stale.
|
||||
clang_line = GetClangCommandLineFromNinjaForSource(
|
||||
out_dir, GetDefaultSourceFile(webrtc_root, filename))
|
||||
if not clang_line:
|
||||
# If ninja didn't know about filename or it's companion files, then try a
|
||||
# default build target. It is possible that the file is new, or build.ninja
|
||||
# is stale.
|
||||
clang_line = GetClangCommandLineFromNinjaForSource(
|
||||
out_dir, GetDefaultSourceFile(webrtc_root, filename))
|
||||
|
||||
if not clang_line:
|
||||
return additional_flags
|
||||
if not clang_line:
|
||||
return additional_flags
|
||||
|
||||
return GetClangOptionsFromCommandLine(clang_line, out_dir, additional_flags)
|
||||
return GetClangOptionsFromCommandLine(clang_line, out_dir,
|
||||
additional_flags)
|
||||
|
||||
|
||||
def FlagsForFile(filename):
|
||||
"""This is the main entry point for YCM. Its interface is fixed.
|
||||
"""This is the main entry point for YCM. Its interface is fixed.
|
||||
|
||||
Args:
|
||||
filename: (String) Path to source file being edited.
|
||||
@ -340,18 +346,16 @@ def FlagsForFile(filename):
|
||||
'flags': (List of Strings) Command line flags.
|
||||
'do_cache': (Boolean) True if the result should be cached.
|
||||
"""
|
||||
abs_filename = os.path.abspath(filename)
|
||||
webrtc_root = FindWebrtcSrcFromFilename(abs_filename)
|
||||
clang_flags = GetClangOptionsFromNinjaForFilename(webrtc_root, abs_filename)
|
||||
abs_filename = os.path.abspath(filename)
|
||||
webrtc_root = FindWebrtcSrcFromFilename(abs_filename)
|
||||
clang_flags = GetClangOptionsFromNinjaForFilename(webrtc_root,
|
||||
abs_filename)
|
||||
|
||||
# If clang_flags could not be determined, then assume that was due to a
|
||||
# transient failure. Preventing YCM from caching the flags allows us to try to
|
||||
# determine the flags again.
|
||||
should_cache_flags_for_file = bool(clang_flags)
|
||||
# If clang_flags could not be determined, then assume that was due to a
|
||||
# transient failure. Preventing YCM from caching the flags allows us to try to
|
||||
# determine the flags again.
|
||||
should_cache_flags_for_file = bool(clang_flags)
|
||||
|
||||
final_flags = _DEFAULT_FLAGS + clang_flags
|
||||
final_flags = _DEFAULT_FLAGS + clang_flags
|
||||
|
||||
return {
|
||||
'flags': final_flags,
|
||||
'do_cache': should_cache_flags_for_file
|
||||
}
|
||||
return {'flags': final_flags, 'do_cache': should_cache_flags_for_file}
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
"""Generate graphs for data generated by loopback tests.
|
||||
|
||||
Usage examples:
|
||||
@ -34,14 +33,14 @@ import numpy
|
||||
|
||||
# Fields
|
||||
DROPPED = 0
|
||||
INPUT_TIME = 1 # ms (timestamp)
|
||||
SEND_TIME = 2 # ms (timestamp)
|
||||
RECV_TIME = 3 # ms (timestamp)
|
||||
RENDER_TIME = 4 # ms (timestamp)
|
||||
ENCODED_FRAME_SIZE = 5 # bytes
|
||||
INPUT_TIME = 1 # ms (timestamp)
|
||||
SEND_TIME = 2 # ms (timestamp)
|
||||
RECV_TIME = 3 # ms (timestamp)
|
||||
RENDER_TIME = 4 # ms (timestamp)
|
||||
ENCODED_FRAME_SIZE = 5 # bytes
|
||||
PSNR = 6
|
||||
SSIM = 7
|
||||
ENCODE_TIME = 8 # ms (time interval)
|
||||
ENCODE_TIME = 8 # ms (time interval)
|
||||
|
||||
TOTAL_RAW_FIELDS = 9
|
||||
|
||||
@ -78,111 +77,116 @@ _FIELDS = [
|
||||
NAME_TO_ID = {field[1]: field[0] for field in _FIELDS}
|
||||
ID_TO_TITLE = {field[0]: field[2] for field in _FIELDS}
|
||||
|
||||
|
||||
def FieldArgToId(arg):
|
||||
if arg == "none":
|
||||
return None
|
||||
if arg in NAME_TO_ID:
|
||||
return NAME_TO_ID[arg]
|
||||
if arg + "_ms" in NAME_TO_ID:
|
||||
return NAME_TO_ID[arg + "_ms"]
|
||||
raise Exception("Unrecognized field name \"{}\"".format(arg))
|
||||
if arg == "none":
|
||||
return None
|
||||
if arg in NAME_TO_ID:
|
||||
return NAME_TO_ID[arg]
|
||||
if arg + "_ms" in NAME_TO_ID:
|
||||
return NAME_TO_ID[arg + "_ms"]
|
||||
raise Exception("Unrecognized field name \"{}\"".format(arg))
|
||||
|
||||
|
||||
class PlotLine(object):
|
||||
"""Data for a single graph line."""
|
||||
"""Data for a single graph line."""
|
||||
|
||||
def __init__(self, label, values, flags):
|
||||
self.label = label
|
||||
self.values = values
|
||||
self.flags = flags
|
||||
def __init__(self, label, values, flags):
|
||||
self.label = label
|
||||
self.values = values
|
||||
self.flags = flags
|
||||
|
||||
|
||||
class Data(object):
|
||||
"""Object representing one full stack test."""
|
||||
"""Object representing one full stack test."""
|
||||
|
||||
def __init__(self, filename):
|
||||
self.title = ""
|
||||
self.length = 0
|
||||
self.samples = defaultdict(list)
|
||||
def __init__(self, filename):
|
||||
self.title = ""
|
||||
self.length = 0
|
||||
self.samples = defaultdict(list)
|
||||
|
||||
self._ReadSamples(filename)
|
||||
self._ReadSamples(filename)
|
||||
|
||||
def _ReadSamples(self, filename):
|
||||
"""Reads graph data from the given file."""
|
||||
f = open(filename)
|
||||
it = iter(f)
|
||||
def _ReadSamples(self, filename):
|
||||
"""Reads graph data from the given file."""
|
||||
f = open(filename)
|
||||
it = iter(f)
|
||||
|
||||
self.title = it.next().strip()
|
||||
self.length = int(it.next())
|
||||
field_names = [name.strip() for name in it.next().split()]
|
||||
field_ids = [NAME_TO_ID[name] for name in field_names]
|
||||
self.title = it.next().strip()
|
||||
self.length = int(it.next())
|
||||
field_names = [name.strip() for name in it.next().split()]
|
||||
field_ids = [NAME_TO_ID[name] for name in field_names]
|
||||
|
||||
for field_id in field_ids:
|
||||
self.samples[field_id] = [0.0] * self.length
|
||||
for field_id in field_ids:
|
||||
self.samples[field_id] = [0.0] * self.length
|
||||
|
||||
for sample_id in xrange(self.length):
|
||||
for col, value in enumerate(it.next().split()):
|
||||
self.samples[field_ids[col]][sample_id] = float(value)
|
||||
for sample_id in xrange(self.length):
|
||||
for col, value in enumerate(it.next().split()):
|
||||
self.samples[field_ids[col]][sample_id] = float(value)
|
||||
|
||||
self._SubtractFirstInputTime()
|
||||
self._GenerateAdditionalData()
|
||||
self._SubtractFirstInputTime()
|
||||
self._GenerateAdditionalData()
|
||||
|
||||
f.close()
|
||||
f.close()
|
||||
|
||||
def _SubtractFirstInputTime(self):
|
||||
offset = self.samples[INPUT_TIME][0]
|
||||
for field in [INPUT_TIME, SEND_TIME, RECV_TIME, RENDER_TIME]:
|
||||
if field in self.samples:
|
||||
self.samples[field] = [x - offset for x in self.samples[field]]
|
||||
def _SubtractFirstInputTime(self):
|
||||
offset = self.samples[INPUT_TIME][0]
|
||||
for field in [INPUT_TIME, SEND_TIME, RECV_TIME, RENDER_TIME]:
|
||||
if field in self.samples:
|
||||
self.samples[field] = [x - offset for x in self.samples[field]]
|
||||
|
||||
def _GenerateAdditionalData(self):
|
||||
"""Calculates sender time, receiver time etc. from the raw data."""
|
||||
s = self.samples
|
||||
last_render_time = 0
|
||||
for field_id in [SENDER_TIME, RECEIVER_TIME, END_TO_END, RENDERED_DELTA]:
|
||||
s[field_id] = [0] * self.length
|
||||
def _GenerateAdditionalData(self):
|
||||
"""Calculates sender time, receiver time etc. from the raw data."""
|
||||
s = self.samples
|
||||
last_render_time = 0
|
||||
for field_id in [
|
||||
SENDER_TIME, RECEIVER_TIME, END_TO_END, RENDERED_DELTA
|
||||
]:
|
||||
s[field_id] = [0] * self.length
|
||||
|
||||
for k in range(self.length):
|
||||
s[SENDER_TIME][k] = s[SEND_TIME][k] - s[INPUT_TIME][k]
|
||||
for k in range(self.length):
|
||||
s[SENDER_TIME][k] = s[SEND_TIME][k] - s[INPUT_TIME][k]
|
||||
|
||||
decoded_time = s[RENDER_TIME][k]
|
||||
s[RECEIVER_TIME][k] = decoded_time - s[RECV_TIME][k]
|
||||
s[END_TO_END][k] = decoded_time - s[INPUT_TIME][k]
|
||||
if not s[DROPPED][k]:
|
||||
if k > 0:
|
||||
s[RENDERED_DELTA][k] = decoded_time - last_render_time
|
||||
last_render_time = decoded_time
|
||||
decoded_time = s[RENDER_TIME][k]
|
||||
s[RECEIVER_TIME][k] = decoded_time - s[RECV_TIME][k]
|
||||
s[END_TO_END][k] = decoded_time - s[INPUT_TIME][k]
|
||||
if not s[DROPPED][k]:
|
||||
if k > 0:
|
||||
s[RENDERED_DELTA][k] = decoded_time - last_render_time
|
||||
last_render_time = decoded_time
|
||||
|
||||
def _Hide(self, values):
|
||||
"""
|
||||
def _Hide(self, values):
|
||||
"""
|
||||
Replaces values for dropped frames with None.
|
||||
These values are then skipped by the Plot() method.
|
||||
"""
|
||||
|
||||
return [None if self.samples[DROPPED][k] else values[k]
|
||||
for k in range(len(values))]
|
||||
return [
|
||||
None if self.samples[DROPPED][k] else values[k]
|
||||
for k in range(len(values))
|
||||
]
|
||||
|
||||
def AddSamples(self, config, target_lines_list):
|
||||
"""Creates graph lines from the current data set with given config."""
|
||||
for field in config.fields:
|
||||
# field is None means the user wants just to skip the color.
|
||||
if field is None:
|
||||
target_lines_list.append(None)
|
||||
continue
|
||||
def AddSamples(self, config, target_lines_list):
|
||||
"""Creates graph lines from the current data set with given config."""
|
||||
for field in config.fields:
|
||||
# field is None means the user wants just to skip the color.
|
||||
if field is None:
|
||||
target_lines_list.append(None)
|
||||
continue
|
||||
|
||||
field_id = field & FIELD_MASK
|
||||
values = self.samples[field_id]
|
||||
field_id = field & FIELD_MASK
|
||||
values = self.samples[field_id]
|
||||
|
||||
if field & HIDE_DROPPED:
|
||||
values = self._Hide(values)
|
||||
if field & HIDE_DROPPED:
|
||||
values = self._Hide(values)
|
||||
|
||||
target_lines_list.append(PlotLine(
|
||||
self.title + " " + ID_TO_TITLE[field_id],
|
||||
values, field & ~FIELD_MASK))
|
||||
target_lines_list.append(
|
||||
PlotLine(self.title + " " + ID_TO_TITLE[field_id], values,
|
||||
field & ~FIELD_MASK))
|
||||
|
||||
|
||||
def AverageOverCycle(values, length):
|
||||
"""
|
||||
"""
|
||||
Returns the list:
|
||||
[
|
||||
avg(values[0], values[length], ...),
|
||||
@ -194,221 +198,272 @@ def AverageOverCycle(values, length):
|
||||
Skips None values when calculating the average value.
|
||||
"""
|
||||
|
||||
total = [0.0] * length
|
||||
count = [0] * length
|
||||
for k, val in enumerate(values):
|
||||
if val is not None:
|
||||
total[k % length] += val
|
||||
count[k % length] += 1
|
||||
total = [0.0] * length
|
||||
count = [0] * length
|
||||
for k, val in enumerate(values):
|
||||
if val is not None:
|
||||
total[k % length] += val
|
||||
count[k % length] += 1
|
||||
|
||||
result = [0.0] * length
|
||||
for k in range(length):
|
||||
result[k] = total[k] / count[k] if count[k] else None
|
||||
return result
|
||||
result = [0.0] * length
|
||||
for k in range(length):
|
||||
result[k] = total[k] / count[k] if count[k] else None
|
||||
return result
|
||||
|
||||
|
||||
class PlotConfig(object):
|
||||
"""Object representing a single graph."""
|
||||
"""Object representing a single graph."""
|
||||
|
||||
def __init__(self, fields, data_list, cycle_length=None, frames=None,
|
||||
offset=0, output_filename=None, title="Graph"):
|
||||
self.fields = fields
|
||||
self.data_list = data_list
|
||||
self.cycle_length = cycle_length
|
||||
self.frames = frames
|
||||
self.offset = offset
|
||||
self.output_filename = output_filename
|
||||
self.title = title
|
||||
def __init__(self,
|
||||
fields,
|
||||
data_list,
|
||||
cycle_length=None,
|
||||
frames=None,
|
||||
offset=0,
|
||||
output_filename=None,
|
||||
title="Graph"):
|
||||
self.fields = fields
|
||||
self.data_list = data_list
|
||||
self.cycle_length = cycle_length
|
||||
self.frames = frames
|
||||
self.offset = offset
|
||||
self.output_filename = output_filename
|
||||
self.title = title
|
||||
|
||||
def Plot(self, ax1):
|
||||
lines = []
|
||||
for data in self.data_list:
|
||||
if not data:
|
||||
# Add None lines to skip the colors.
|
||||
lines.extend([None] * len(self.fields))
|
||||
else:
|
||||
data.AddSamples(self, lines)
|
||||
def Plot(self, ax1):
|
||||
lines = []
|
||||
for data in self.data_list:
|
||||
if not data:
|
||||
# Add None lines to skip the colors.
|
||||
lines.extend([None] * len(self.fields))
|
||||
else:
|
||||
data.AddSamples(self, lines)
|
||||
|
||||
def _SliceValues(values):
|
||||
if self.offset:
|
||||
values = values[self.offset:]
|
||||
if self.frames:
|
||||
values = values[:self.frames]
|
||||
return values
|
||||
def _SliceValues(values):
|
||||
if self.offset:
|
||||
values = values[self.offset:]
|
||||
if self.frames:
|
||||
values = values[:self.frames]
|
||||
return values
|
||||
|
||||
length = None
|
||||
for line in lines:
|
||||
if line is None:
|
||||
continue
|
||||
length = None
|
||||
for line in lines:
|
||||
if line is None:
|
||||
continue
|
||||
|
||||
line.values = _SliceValues(line.values)
|
||||
if self.cycle_length:
|
||||
line.values = AverageOverCycle(line.values, self.cycle_length)
|
||||
line.values = _SliceValues(line.values)
|
||||
if self.cycle_length:
|
||||
line.values = AverageOverCycle(line.values, self.cycle_length)
|
||||
|
||||
if length is None:
|
||||
length = len(line.values)
|
||||
elif length != len(line.values):
|
||||
raise Exception("All arrays should have the same length!")
|
||||
if length is None:
|
||||
length = len(line.values)
|
||||
elif length != len(line.values):
|
||||
raise Exception("All arrays should have the same length!")
|
||||
|
||||
ax1.set_xlabel("Frame", fontsize="large")
|
||||
if any(line.flags & RIGHT_Y_AXIS for line in lines if line):
|
||||
ax2 = ax1.twinx()
|
||||
ax2.set_xlabel("Frame", fontsize="large")
|
||||
else:
|
||||
ax2 = None
|
||||
ax1.set_xlabel("Frame", fontsize="large")
|
||||
if any(line.flags & RIGHT_Y_AXIS for line in lines if line):
|
||||
ax2 = ax1.twinx()
|
||||
ax2.set_xlabel("Frame", fontsize="large")
|
||||
else:
|
||||
ax2 = None
|
||||
|
||||
# Have to implement color_cycle manually, due to two scales in a graph.
|
||||
color_cycle = ["b", "r", "g", "c", "m", "y", "k"]
|
||||
color_iter = itertools.cycle(color_cycle)
|
||||
# Have to implement color_cycle manually, due to two scales in a graph.
|
||||
color_cycle = ["b", "r", "g", "c", "m", "y", "k"]
|
||||
color_iter = itertools.cycle(color_cycle)
|
||||
|
||||
for line in lines:
|
||||
if not line:
|
||||
color_iter.next()
|
||||
continue
|
||||
for line in lines:
|
||||
if not line:
|
||||
color_iter.next()
|
||||
continue
|
||||
|
||||
if self.cycle_length:
|
||||
x = numpy.array(range(self.cycle_length))
|
||||
else:
|
||||
x = numpy.array(range(self.offset, self.offset + len(line.values)))
|
||||
y = numpy.array(line.values)
|
||||
ax = ax2 if line.flags & RIGHT_Y_AXIS else ax1
|
||||
ax.Plot(x, y, "o-", label=line.label, markersize=3.0, linewidth=1.0,
|
||||
color=color_iter.next())
|
||||
if self.cycle_length:
|
||||
x = numpy.array(range(self.cycle_length))
|
||||
else:
|
||||
x = numpy.array(
|
||||
range(self.offset, self.offset + len(line.values)))
|
||||
y = numpy.array(line.values)
|
||||
ax = ax2 if line.flags & RIGHT_Y_AXIS else ax1
|
||||
ax.Plot(x,
|
||||
y,
|
||||
"o-",
|
||||
label=line.label,
|
||||
markersize=3.0,
|
||||
linewidth=1.0,
|
||||
color=color_iter.next())
|
||||
|
||||
ax1.grid(True)
|
||||
if ax2:
|
||||
ax1.legend(loc="upper left", shadow=True, fontsize="large")
|
||||
ax2.legend(loc="upper right", shadow=True, fontsize="large")
|
||||
else:
|
||||
ax1.legend(loc="best", shadow=True, fontsize="large")
|
||||
ax1.grid(True)
|
||||
if ax2:
|
||||
ax1.legend(loc="upper left", shadow=True, fontsize="large")
|
||||
ax2.legend(loc="upper right", shadow=True, fontsize="large")
|
||||
else:
|
||||
ax1.legend(loc="best", shadow=True, fontsize="large")
|
||||
|
||||
|
||||
def LoadFiles(filenames):
|
||||
result = []
|
||||
for filename in filenames:
|
||||
if filename in LoadFiles.cache:
|
||||
result.append(LoadFiles.cache[filename])
|
||||
else:
|
||||
data = Data(filename)
|
||||
LoadFiles.cache[filename] = data
|
||||
result.append(data)
|
||||
return result
|
||||
result = []
|
||||
for filename in filenames:
|
||||
if filename in LoadFiles.cache:
|
||||
result.append(LoadFiles.cache[filename])
|
||||
else:
|
||||
data = Data(filename)
|
||||
LoadFiles.cache[filename] = data
|
||||
result.append(data)
|
||||
return result
|
||||
|
||||
|
||||
LoadFiles.cache = {}
|
||||
|
||||
|
||||
def GetParser():
|
||||
class CustomAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if "ordered_args" not in namespace:
|
||||
namespace.ordered_args = []
|
||||
namespace.ordered_args.append((self.dest, values))
|
||||
class CustomAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if "ordered_args" not in namespace:
|
||||
namespace.ordered_args = []
|
||||
namespace.ordered_args.append((self.dest, values))
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
|
||||
parser.add_argument(
|
||||
"-c", "--cycle_length", nargs=1, action=CustomAction,
|
||||
type=int, help="Cycle length over which to average the values.")
|
||||
parser.add_argument(
|
||||
"-f", "--field", nargs=1, action=CustomAction,
|
||||
help="Name of the field to show. Use 'none' to skip a color.")
|
||||
parser.add_argument("-r", "--right", nargs=0, action=CustomAction,
|
||||
help="Use right Y axis for given field.")
|
||||
parser.add_argument("-d", "--drop", nargs=0, action=CustomAction,
|
||||
help="Hide values for dropped frames.")
|
||||
parser.add_argument("-o", "--offset", nargs=1, action=CustomAction, type=int,
|
||||
help="Frame offset.")
|
||||
parser.add_argument("-n", "--next", nargs=0, action=CustomAction,
|
||||
help="Separator for multiple graphs.")
|
||||
parser.add_argument(
|
||||
"--frames", nargs=1, action=CustomAction, type=int,
|
||||
help="Frame count to show or take into account while averaging.")
|
||||
parser.add_argument("-t", "--title", nargs=1, action=CustomAction,
|
||||
help="Title of the graph.")
|
||||
parser.add_argument(
|
||||
"-O", "--output_filename", nargs=1, action=CustomAction,
|
||||
help="Use to save the graph into a file. "
|
||||
"Otherwise, a window will be shown.")
|
||||
parser.add_argument(
|
||||
"files", nargs="+", action=CustomAction,
|
||||
help="List of text-based files generated by loopback tests.")
|
||||
return parser
|
||||
parser.add_argument("-c",
|
||||
"--cycle_length",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
type=int,
|
||||
help="Cycle length over which to average the values.")
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--field",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
help="Name of the field to show. Use 'none' to skip a color.")
|
||||
parser.add_argument("-r",
|
||||
"--right",
|
||||
nargs=0,
|
||||
action=CustomAction,
|
||||
help="Use right Y axis for given field.")
|
||||
parser.add_argument("-d",
|
||||
"--drop",
|
||||
nargs=0,
|
||||
action=CustomAction,
|
||||
help="Hide values for dropped frames.")
|
||||
parser.add_argument("-o",
|
||||
"--offset",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
type=int,
|
||||
help="Frame offset.")
|
||||
parser.add_argument("-n",
|
||||
"--next",
|
||||
nargs=0,
|
||||
action=CustomAction,
|
||||
help="Separator for multiple graphs.")
|
||||
parser.add_argument(
|
||||
"--frames",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
type=int,
|
||||
help="Frame count to show or take into account while averaging.")
|
||||
parser.add_argument("-t",
|
||||
"--title",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
help="Title of the graph.")
|
||||
parser.add_argument("-O",
|
||||
"--output_filename",
|
||||
nargs=1,
|
||||
action=CustomAction,
|
||||
help="Use to save the graph into a file. "
|
||||
"Otherwise, a window will be shown.")
|
||||
parser.add_argument(
|
||||
"files",
|
||||
nargs="+",
|
||||
action=CustomAction,
|
||||
help="List of text-based files generated by loopback tests.")
|
||||
return parser
|
||||
|
||||
|
||||
def _PlotConfigFromArgs(args, graph_num):
|
||||
# Pylint complains about using kwargs, so have to do it this way.
|
||||
cycle_length = None
|
||||
frames = None
|
||||
offset = 0
|
||||
output_filename = None
|
||||
title = "Graph"
|
||||
# Pylint complains about using kwargs, so have to do it this way.
|
||||
cycle_length = None
|
||||
frames = None
|
||||
offset = 0
|
||||
output_filename = None
|
||||
title = "Graph"
|
||||
|
||||
fields = []
|
||||
files = []
|
||||
mask = 0
|
||||
for key, values in args:
|
||||
if key == "cycle_length":
|
||||
cycle_length = values[0]
|
||||
elif key == "frames":
|
||||
frames = values[0]
|
||||
elif key == "offset":
|
||||
offset = values[0]
|
||||
elif key == "output_filename":
|
||||
output_filename = values[0]
|
||||
elif key == "title":
|
||||
title = values[0]
|
||||
elif key == "drop":
|
||||
mask |= HIDE_DROPPED
|
||||
elif key == "right":
|
||||
mask |= RIGHT_Y_AXIS
|
||||
elif key == "field":
|
||||
field_id = FieldArgToId(values[0])
|
||||
fields.append(field_id | mask if field_id is not None else None)
|
||||
mask = 0 # Reset mask after the field argument.
|
||||
elif key == "files":
|
||||
files.extend(values)
|
||||
fields = []
|
||||
files = []
|
||||
mask = 0
|
||||
for key, values in args:
|
||||
if key == "cycle_length":
|
||||
cycle_length = values[0]
|
||||
elif key == "frames":
|
||||
frames = values[0]
|
||||
elif key == "offset":
|
||||
offset = values[0]
|
||||
elif key == "output_filename":
|
||||
output_filename = values[0]
|
||||
elif key == "title":
|
||||
title = values[0]
|
||||
elif key == "drop":
|
||||
mask |= HIDE_DROPPED
|
||||
elif key == "right":
|
||||
mask |= RIGHT_Y_AXIS
|
||||
elif key == "field":
|
||||
field_id = FieldArgToId(values[0])
|
||||
fields.append(field_id | mask if field_id is not None else None)
|
||||
mask = 0 # Reset mask after the field argument.
|
||||
elif key == "files":
|
||||
files.extend(values)
|
||||
|
||||
if not files:
|
||||
raise Exception("Missing file argument(s) for graph #{}".format(graph_num))
|
||||
if not fields:
|
||||
raise Exception("Missing field argument(s) for graph #{}".format(graph_num))
|
||||
if not files:
|
||||
raise Exception(
|
||||
"Missing file argument(s) for graph #{}".format(graph_num))
|
||||
if not fields:
|
||||
raise Exception(
|
||||
"Missing field argument(s) for graph #{}".format(graph_num))
|
||||
|
||||
return PlotConfig(fields, LoadFiles(files), cycle_length=cycle_length,
|
||||
frames=frames, offset=offset, output_filename=output_filename,
|
||||
title=title)
|
||||
return PlotConfig(fields,
|
||||
LoadFiles(files),
|
||||
cycle_length=cycle_length,
|
||||
frames=frames,
|
||||
offset=offset,
|
||||
output_filename=output_filename,
|
||||
title=title)
|
||||
|
||||
|
||||
def PlotConfigsFromArgs(args):
|
||||
"""Generates plot configs for given command line arguments."""
|
||||
# The way it works:
|
||||
# First we detect separators -n/--next and split arguments into groups, one
|
||||
# for each plot. For each group, we partially parse it with
|
||||
# argparse.ArgumentParser, modified to remember the order of arguments.
|
||||
# Then we traverse the argument list and fill the PlotConfig.
|
||||
args = itertools.groupby(args, lambda x: x in ["-n", "--next"])
|
||||
prep_args = list(list(group) for match, group in args if not match)
|
||||
"""Generates plot configs for given command line arguments."""
|
||||
# The way it works:
|
||||
# First we detect separators -n/--next and split arguments into groups, one
|
||||
# for each plot. For each group, we partially parse it with
|
||||
# argparse.ArgumentParser, modified to remember the order of arguments.
|
||||
# Then we traverse the argument list and fill the PlotConfig.
|
||||
args = itertools.groupby(args, lambda x: x in ["-n", "--next"])
|
||||
prep_args = list(list(group) for match, group in args if not match)
|
||||
|
||||
parser = GetParser()
|
||||
plot_configs = []
|
||||
for index, raw_args in enumerate(prep_args):
|
||||
graph_args = parser.parse_args(raw_args).ordered_args
|
||||
plot_configs.append(_PlotConfigFromArgs(graph_args, index))
|
||||
return plot_configs
|
||||
parser = GetParser()
|
||||
plot_configs = []
|
||||
for index, raw_args in enumerate(prep_args):
|
||||
graph_args = parser.parse_args(raw_args).ordered_args
|
||||
plot_configs.append(_PlotConfigFromArgs(graph_args, index))
|
||||
return plot_configs
|
||||
|
||||
|
||||
def ShowOrSavePlots(plot_configs):
|
||||
for config in plot_configs:
|
||||
fig = plt.figure(figsize=(14.0, 10.0))
|
||||
ax = fig.add_subPlot(1, 1, 1)
|
||||
for config in plot_configs:
|
||||
fig = plt.figure(figsize=(14.0, 10.0))
|
||||
ax = fig.add_subPlot(1, 1, 1)
|
||||
|
||||
plt.title(config.title)
|
||||
config.Plot(ax)
|
||||
if config.output_filename:
|
||||
print "Saving to", config.output_filename
|
||||
fig.savefig(config.output_filename)
|
||||
plt.close(fig)
|
||||
plt.title(config.title)
|
||||
config.Plot(ax)
|
||||
if config.output_filename:
|
||||
print "Saving to", config.output_filename
|
||||
fig.savefig(config.output_filename)
|
||||
plt.close(fig)
|
||||
|
||||
plt.show()
|
||||
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ShowOrSavePlots(PlotConfigsFromArgs(sys.argv[1:]))
|
||||
ShowOrSavePlots(PlotConfigsFromArgs(sys.argv[1:]))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user