Source code for pyicat_plus.utils.sync_store

import concurrent.futures
import datetime
import json
import logging
import os
from glob import glob
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from tqdm import tqdm

from . import path_utils
from . import raw_data
from . import sync_types

logger = logging.getLogger("ICAT SYNC")


[docs] class ExperimentalSessionStore: def __init__( self, cache_dir: Optional[str] = None, save_dir: Optional[str] = None, raw_data_format: str = "esrfv3", invalidate_cache: bool = False, ) -> None: """ :param cache_dir: one JSON file for each cached experimental session :param save_dir: one CSV+BASH file for each sync status :param raw_data_format: will be used as subdirectory name :param invalidate_cache: invalidating session when loaded from cache """ self._save_dir = save_dir self._cache_dir = cache_dir self._raw_data_format = raw_data_format self._invalidate_cache = invalidate_cache self._save_subdir = None self._cache_subdir = None if save_dir is not None: self._save_subdir = os.path.join(save_dir, raw_data_format) if cache_dir is not None: self._cache_subdir = os.path.join(cache_dir, raw_data_format) self._exp_sessions: Dict[str, sync_types.ExperimentalSession] = dict() self._load_all_session_files() self._start_time = datetime.datetime.now() def __enter__(self) -> "ExperimentalSessionStore": self._start_time = datetime.datetime.now() return self def __exit__(self, *_) -> bool: self.close() return False
[docs] def get_session(self, session_dir: str) -> Optional[sync_types.ExperimentalSession]: return self._exp_sessions.get(session_dir)
[docs] def add_session(self, exp_session: sync_types.ExperimentalSession) -> None: """Add session to the in-memory cache and when enabled save it on disk.""" self._exp_sessions[exp_session.session_dir] = exp_session if not self._cache_subdir: return filename = self._cache_filename(exp_session) data = exp_session.as_dict() os.makedirs(self._cache_subdir, mode=0o775, exist_ok=True) _json_save(data, filename)
[docs] def remove_session(self, exp_session: sync_types.ExperimentalSession) -> None: """Remove session from the in-memory cache and when enabled from disk.""" self._exp_sessions.pop(exp_session.session_dir, None) if not self._cache_subdir: return filename = self._cache_filename(exp_session) try: os.remove(filename) except FileNotFoundError: pass logger.info("Removed session cache %s", filename)
[docs] def close(self) -> None: try: self._save_overview() finally: self._print_final()
def _load_all_session_files(self) -> None: """Load all the session files when enabled and delete the ones for sessions which directory no longer exists.""" if not self._cache_subdir: return logger.info("Loading cache from %s ...", self._cache_subdir) filenames = glob(os.path.join(self._cache_subdir, "*.json")) logger.info("Loading cache of %d sessions ...", len(filenames)) data = _executor_map_with_progress(_json_load, filenames) if self._invalidate_cache: logger.info("Invalidate cache of %d sessions ...", len(data)) if self._invalidate_cache: func = _parse_session_cache_with_invalidation else: func = _parse_session_cache exp_sessions = _executor_map_with_progress(func, filenames, data) for exp_session in exp_sessions: if exp_session is not None: self._exp_sessions[exp_session.session_dir] = exp_session logger.info("Loaded cache of %d sessions.", len(self._exp_sessions)) def _cache_filename(self, exp_session: sync_types.ExperimentalSession) -> str: basename = ( f"{exp_session.proposal}_{exp_session.beamline}_{exp_session.session}.json" ) return os.path.join(self._cache_subdir, basename) def _save_overview(self) -> None: if not self._save_subdir: return logger.info( "Saving synchronization overview of %d sessions ...", len(self._exp_sessions), ) os.makedirs(self._save_subdir, mode=0o775, exist_ok=True) save_columns = [ "beamline", "proposal", "timeslot", "session", "url", "search_url", "path", *sync_types.ExperimentalSession.DATASET_STATUSES, ] save_category = [ "todo", "ok", "empty", "has_invalid", "todo_ongoing", "not_in_icat", ] register_cmd_params = "" if self._save_dir: register_cmd_params += f"--save-dir {self._save_dir} " if self._cache_dir: register_cmd_params += f"--cache-dir {self._cache_dir} " register_cmd_params += '"$@"' with concurrent.futures.ThreadPoolExecutor() as executor: # Clear CSV files and add header stats_files = dict( (s, os.path.join(self._save_subdir, s + ".csv")) for s in save_category ) def _init_csv(filename): _save_csv_line(filename, save_columns, clear=True) try: os.chmod(filename, 0o0664) except PermissionError: pass _ = list(executor.map(_init_csv, stats_files.values())) # Clear bash files and add header cmd_files = dict( (s, os.path.join(self._save_subdir, s + ".sh")) for s in save_category ) def _init_cmd(filename): _save_file_line(filename, "#!/bin/bash", clear=True) _save_file_line(filename, "") _save_file_line( filename, "# Run a cached copy because this script can be modified during execution", ) _save_file_line(filename, 'if [ -z "$SCRIPT_CACHED" ]; then') _save_file_line(filename, ' SCRIPT_CACHED=$(<"$0")') _save_file_line(filename, ' script_name=$(basename "$0")') _save_file_line( filename, ' echo "$SCRIPT_CACHED" > /tmp/cached_$script_name' ) _save_file_line( filename, ' SCRIPT_CACHED=1 /bin/bash /tmp/cached_$script_name "$@"', ) _save_file_line(filename, " exit $?") _save_file_line(filename, "fi") _save_file_line(filename, "") try: os.chmod(filename, 0o0774) except PermissionError: pass _ = list(executor.map(_init_cmd, cmd_files.values())) # Add sessions CSV and BASH files def _save_session_line(exp_session): icat_investigation = exp_session.icat_investigation if not icat_investigation: save_category = "not_in_icat" else: statuses = exp_session.dataset_statuses() if statuses & { "not_uploaded", "unregistered", "registered_without_files", }: if icat_investigation and icat_investigation.ongoing: save_category = "todo_ongoing" else: save_category = "todo" elif "invalid" in statuses: save_category = "has_invalid" elif "registered" in statuses: save_category = "ok" else: save_category = "empty" tabular_data = exp_session.tabular_data() stats = [tabular_data[k] for k in save_columns] _save_csv_line(stats_files[save_category], stats) register_cmd = f"icat-sync-raw --beamline {exp_session.beamline} --proposal {exp_session.proposal} --session {exp_session.session} {register_cmd_params}" _save_file_line(cmd_files[save_category], register_cmd) _ = list(executor.map(_save_session_line, self._exp_sessions.values())) logger.info("Saving synchronization finished.") def _print_final(self): end_time = datetime.datetime.now() logger.info("") logger.info("Total synchronization time was %s", end_time - self._start_time) if self._cache_subdir: logger.info("Cache was used from %s", self._cache_subdir) if self._save_subdir: logger.info("Results can be found in %s", self._save_subdir) logger.info( "To fix issues run %s", os.path.join(self._save_subdir, "todo.sh") )
def _executor_map_with_progress(func: Callable, *args) -> List[Any]: results = [] with concurrent.futures.ThreadPoolExecutor() as executor: with tqdm(total=len(args[0])) as progress: def update_progress(future): progress.update(1) results.append(future.result()) futures = [executor.submit(func, *_args) for _args in zip(*args)] for future in futures: future.add_done_callback(update_progress) result = concurrent.futures.wait(futures) for future in result.done: if future.exception(): raise future.exception() return results def _parse_session_cache( filename: str, data: dict ) -> Optional[sync_types.ExperimentalSession]: try: return sync_types.ExperimentalSession.from_dict(data) except Exception as e: logger.info("Invalidate %s: wrong cache format (%s)", filename, e) try: os.remove(filename) except FileNotFoundError: pass def _parse_session_cache_with_invalidation( filename: str, data: dict ) -> Optional[sync_types.ExperimentalSession]: exp_session = _parse_session_cache(filename, data) if exp_session is None: return invalidation_reason = _session_invalidation_reason(exp_session) if invalidation_reason: logger.info("Invalidate %s: %s", exp_session.session_dir, invalidation_reason) try: os.remove(filename) except FileNotFoundError: pass return return exp_session def _session_invalidation_reason( exp_session: sync_types.ExperimentalSession, ) -> Optional[str]: if not os.path.isdir(exp_session.session_dir): return "no longer exists" if not exp_session.icat_investigation: # Cached session does not contain any datasets when # no ICAT investigation was found return cached = {path_utils.markdir(dset.path) for dset in exp_session.iter_datasets()} dataset_filters = raw_data.get_dataset_filters( exp_session.raw_root_dir, raw_data_format=exp_session.raw_data_format ) actual = { path_utils.markdir(path) for dataset_filter in dataset_filters for path in glob(path_utils.markdir(dataset_filter)) } if cached != actual: nnew = len(actual - cached) nremoved = len(cached - actual) return f"has {nnew} new datasets and {nremoved} datasets were removed" def _json_save(data: Dict[str, Any], filename: str) -> None: try: with open(filename, "w") as fp: json.dump(data, fp, default=_json_serializer, indent=2) except Exception as e: raise IOError(f"writing failed for file {filename}") from e def _json_load(filename: str) -> Dict[str, Any]: try: with open(filename, "r") as fp: return json.load(fp, object_pairs_hook=_json_pair_deserializer) except Exception as e: raise IOError(f"reading failed for file {filename}") from e def _json_serializer(obj: Any): """Serialize datetime and date objects to string.""" if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() else: raise TypeError( "Object of type '{}' is not JSON serializable".format(type(obj)) ) def _json_pair_deserializer(items: List[Tuple[str, Any]]) -> Dict[str, Any]: """Deserialize string to datetime or date objects.""" result = dict() for key, value in items: if "date" in key and value: try: value = datetime.date.fromisoformat(value) except ValueError: value = datetime.datetime.fromisoformat(value) result[key] = value return result def _save_csv_line(filename: str, row: List[Any], clear: bool = False) -> None: line = ",".join([str(item).replace(",", " ").strip() for item in row]) _save_file_line(filename, line, clear) def _save_file_line(filename: str, line: str, clear: bool = False) -> None: mode = "w" if clear else "a" with open(filename, mode) as f: f.write(f"{line}\n")