import copy
import socket
from collections import deque
from collections.abc import Mapping
from importlib.metadata import version as get_version

import numpy
# Version string of the installed "pyicat_plus" distribution, looked up from
# its package metadata once, when this module is imported.
release = get_version("pyicat_plus")
[docs]
def assert_equal_hdf5_dict(adict: Mapping, expected: Mapping):
    """Check that HDF5 content is the same.

    Both arguments are nested mappings (groups) whose leaves are dataset
    values. They are walked breadth-first and compared key by key:

    * each level must have exactly the same keys;
    * a mapping in ``expected`` is compared recursively;
    * a numpy array in ``expected`` whose size is not 1 (including empty
      arrays) must be matched by a numpy array with the same shape, dtype and
      values; for float, complex and datetime dtypes NaN/NaT entries compare
      equal when they sit at the same positions;
    * a list or tuple in ``expected`` is compared element-wise as a list;
    * anything else (including size-1 arrays) is compared with ``==``.

    :param adict: actual content.
    :param expected: expected content.
    :raises AssertionError: on the first difference found. The message starts
        with the ``/``-separated path of the offending group or dataset.
    """
    # Explicit raises instead of ``assert`` so the check survives ``python -O``.
    queue = deque([("", adict, expected)])
    while queue:
        path, actual_group, expected_group = queue.popleft()
        where = path or "/"
        try:
            existing_keys = set(actual_group.keys())
        except AttributeError:
            raise AssertionError(
                f"{where}: expected a group, got {type(actual_group).__name__}"
            ) from None
        expected_keys = set(expected_group.keys())
        missing_keys = expected_keys - existing_keys
        if missing_keys:
            raise AssertionError(f"{where}: missing keys: {missing_keys}")
        unexpected_keys = existing_keys - expected_keys
        if unexpected_keys:
            raise AssertionError(f"{where}: unexpected keys: {unexpected_keys}")

        for key, expected_value in expected_group.items():
            item_path = f"{path}/{key}"
            actual_value = actual_group[key]
            if isinstance(expected_value, Mapping):
                queue.append((item_path, actual_value, expected_value))
            elif isinstance(expected_value, numpy.ndarray) and expected_value.size != 1:
                # Size 0 goes here too: ``bool(empty == empty)`` is ambiguous.
                _assert_equal_arrays(item_path, actual_value, expected_value)
            elif isinstance(expected_value, (list, tuple)):
                if not (list(actual_value) == list(expected_value)):
                    raise AssertionError(
                        f"{item_path}: {actual_value!r} != {expected_value!r}"
                    )
            elif not (actual_value == expected_value):
                raise AssertionError(
                    f"{item_path}: {actual_value!r} != {expected_value!r}"
                )


def _assert_equal_arrays(path: str, actual, expected: numpy.ndarray) -> None:
    """Assert that ``actual`` is an array with the shape, dtype and values of ``expected``."""
    if not isinstance(actual, numpy.ndarray):
        raise AssertionError(
            f"{path}: expected a numpy array, got {type(actual).__name__}"
        )
    if actual.shape != expected.shape:
        raise AssertionError(f"{path}: shape {actual.shape} != {expected.shape}")
    if actual.dtype != expected.dtype:
        raise AssertionError(f"{path}: dtype {actual.dtype} != {expected.dtype}")
    actual_flat = actual.ravel()
    expected_flat = expected.ravel()
    if expected.dtype.kind in "fcmM":
        # NaN/NaT never compare equal: require them at the same positions and
        # compare only the remaining entries. ``isnan`` is restricted to these
        # kinds because it raises TypeError on string/bytes dtypes.
        actual_nan = numpy.isnan(actual_flat)
        expected_nan = numpy.isnan(expected_flat)
        if not numpy.array_equal(actual_nan, expected_nan):
            raise AssertionError(f"{path}: NaN positions differ")
        keep = ~expected_nan
        equal = actual_flat[keep] == expected_flat[keep]
    else:
        equal = actual_flat == expected_flat
    if not numpy.all(equal):
        raise AssertionError(f"{path}: values differ")
[docs]
def assert_equal_dataset_message(actual: dict, expected: dict) -> None:
    """Check that ICAT ingester dataset messages are the same.

    Before comparing, values the caller cannot predict are filled into a copy
    of ``expected``, but only where ``expected`` does not already provide them:

    * ``startDate`` / ``endDate`` and the ``startDate`` / ``endDate``
      parameters are taken from ``actual``;
    * the ``software`` parameter defaults to ``"pyicat-plus_v" + release`` and
      the ``machine`` parameter to ``socket.getfqdn()``;
    * ``datafile`` and ``sample.parameter`` default to empty lists.

    ``parameter`` lists are compared independently of their order. Neither
    argument is modified.

    :param actual: message produced by the code under test.
    :param expected: expected message.
    :raises AssertionError: when the messages differ.

    Errors raised while reading the dates out of ``actual`` (e.g. a
    ``KeyError`` when one is missing) propagate unchanged.
    """
    # Work on copies: filling in defaults and sorting parameter lists must not
    # leak into the caller's dicts (an ``expected`` template reused for several
    # comparisons would otherwise keep the first message's dates).
    actual = copy.deepcopy(actual)
    expected = copy.deepcopy(expected)

    # Cannot know the exact time
    _set_expected_to_actual_value(actual, expected, "startDate")
    _set_expected_to_actual_value(actual, expected, "endDate")
    _set_expected_to_actual_value(actual, expected, "parameter", "startDate")
    _set_expected_to_actual_value(actual, expected, "parameter", "endDate")

    # Always set when missing
    _set_expected_value(expected, "pyicat-plus_v" + release, "parameter", "software")
    _set_expected_value(expected, socket.getfqdn(), "parameter", "machine")

    # XSD defaults (excluding None)
    _set_expected_value(expected, [], "datafile")
    _set_expected_value(expected, [], "sample", "parameter")

    # Parameter list can be in any order
    _sort_parameters(actual)
    _sort_parameters(expected)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if actual != expected:
        raise AssertionError(
            f"dataset messages differ:\nactual:   {actual!r}\nexpected: {expected!r}"
        )
[docs]
def assert_equal_investigation_message(actual: dict, expected: dict) -> None:
    """Check that ICAT ingester investigation messages are the same.

    ``startDate`` cannot be predicted by the caller, so it is taken from
    ``actual`` unless ``expected`` already provides it. Neither argument is
    modified.

    :param actual: message produced by the code under test.
    :param expected: expected message.
    :raises AssertionError: when the messages differ.
    """
    # Copy so the filled-in startDate does not leak into the caller's dict.
    expected = copy.deepcopy(expected)
    _set_expected_to_actual_value(actual, expected, "startDate")
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if actual != expected:
        raise AssertionError(
            f"investigation messages differ:\nactual:   {actual!r}\nexpected: {expected!r}"
        )
def _set_expected_value(expected: dict, value: str, *keys) -> None:
"""Set expected value when not provided."""
expected_parent = expected
for i, key in enumerate(keys[:-1]):
if key == "parameter":
next_key = keys[i + 1]
param_list = expected_parent.setdefault("parameter", [])
if not isinstance(param_list, list):
raise TypeError(
f"Expected list at '{'.'.join(keys[:i+1])}', got {type(param_list).__name__}"
)
# Do not overwrite existing parameter
for param in param_list:
if param.get("name") == next_key:
return
param_list.append({"name": next_key, "value": value})
return
expected_parent = expected_parent.setdefault(key, {})
_ = expected_parent.setdefault(keys[-1], value)
def _set_expected_to_actual_value(actual: dict, expected: dict, *keys: str) -> None:
"""Set expected value to actual value when not provided."""
if not keys:
return
value = actual
for i, key in enumerate(keys):
if not isinstance(value, dict):
raise TypeError(
f"Expected dict while traversing path '{'.'.join(keys)}', "
f"but got {type(value).__name__} at key '{key}'."
)
if key not in value:
path = ".".join(keys[:i]) or "<root>"
raise KeyError(
f"Missing key '{key}' in '{list(value)}' under '{path}' "
f"while traversing '{'.'.join(keys)}'"
)
value = value[key]
if key == "parameter":
# parameter is a list of {"name": ..., "value": ...}
next_key = keys[i + 1]
if not isinstance(value, list):
raise TypeError(
f"Expected list at '{'.'.join(keys[:i+1])}', "
f"got {type(value).__name__}"
)
match = next(
(p["value"] for p in value if p.get("name") == next_key),
None,
)
if match is None:
raise KeyError(
f"Parameter '{next_key}' not found in "
f"{[p.get('name') for p in value]}"
)
value = match
break
_set_expected_value(expected, value, *keys)
def _sort_parameters(d: dict) -> None:
"""Recursively sort all 'parameter' lists by their 'name' field in a nested dict."""
if not isinstance(d, Mapping):
return
for key, value in d.items():
if key == "parameter" and isinstance(value, list):
# Sort the list of dicts by 'name'
value.sort(key=lambda p: p.get("name", ""))
elif isinstance(value, dict):
_sort_parameters(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
_sort_parameters(item)