from datetime import timedelta
from hashlib import blake2b
import os
from pathlib import Path, PurePath
import re
import tempfile
import threading
from typing import AsyncIterator, Callable, Dict, Iterable, Optional, Tuple, Union
import aiohttp
from confuse.core import ConfigView
from lark import Token, Transformer
from qastle import parse
# Access to thread local storage.
threadLocal = threading.local()
async def default_client_session() -> aiohttp.ClientSession:
"""
Return a client session per thread.
"""
client = getattr(threadLocal, "client_session", None)
if client is None:
connector = aiohttp.TCPConnector(limit=20)
client = await aiohttp.ClientSession(connector=connector).__aenter__()
threadLocal.client_session = client
return client
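# Illustrative sketch (not part of this module's API): a caller inside async
# code could reuse the shared per-thread session like this; the helper name
# and URL below are made up for the example.
async def _example_fetch_status(url: str) -> int:
    session = await default_client_session()
    async with session.get(url) as response:
        return response.status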
def write_query_log(
request_id: str,
n_files: Optional[int],
n_skip: int,
time: timedelta,
success: bool,
path_to_log_dir: Path,
):
"""
Log the status of a run to a csv file.
"""
l_file = path_to_log_dir / "log.csv"
if l_file.parent.exists():
if not l_file.exists():
l_file.write_text("RequestID,n_files,n_skip,time_sec,no_error\n")
with l_file.open(mode="a") as f:
s_text = "1" if success else "0"
f.write(
f"{request_id},{n_files if n_files is not None else -1},"
f"{n_skip},{time.total_seconds()},{s_text}\n"
)
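# Illustrative sketch: how a run might be logged; all argument values and the
# helper name are made up for the example.
def _example_log_run() -> None:
    write_query_log(
        request_id="1234-abcd",
        n_files=10,
        n_skip=2,
        time=timedelta(seconds=12.5),
        success=True,
        path_to_log_dir=Path(tempfile.gettempdir()),
    )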
class log_adaptor:
"""
Helper class to allow easy mocking.
"""
@staticmethod
def write_query_log(
request_id: str,
n_files: Optional[int],
n_skip: int,
time: timedelta,
success: bool,
path_to_log_dir: Path,
):
write_query_log(request_id, n_files, n_skip, time, success, path_to_log_dir)
class ServiceXException(Exception):
"Raised when something has gone wrong in the ServiceX remote service"
def __init__(self, msg):
super().__init__(msg)
class ServiceXFatalTransformException(Exception):
"Raised when something has gone wrong in the ServiceX remote service"
def __init__(self, msg):
super().__init__(msg)
class ServiceXUnknownRequestID(Exception):
"Raised when we try to access ServiceX with a request ID it does not know about"
def __init__(self, msg):
super().__init__(msg)
class ServiceXNoFilesInCache(Exception):
"""Raised when we try to access ServiceX with a request ID that it
knows about, but no longer has the results cached - and they
need to be regenerated"
"""
def __init__(self, msg):
super().__init__(msg)
class ServiceXUnknownDataRequestID(Exception):
"""Raised when we expect ServiceX to know about the data for a request id, but it does not.
Possible causes:
- ServiceX has been reset and all previously run requests have been removed
- ServiceX's cache rolled over this entry, and it wasn't cached locally.
"""
def __init__(self, msg):
super().__init__(msg)
class ServiceXFailedFileTransform(Exception):
"Raised when a file(s) fail to transform"
def __init__(self, msg):
super().__init__(msg)
def sanitize_filename(fname: str):
"No matter the string given, make it an acceptable filename"
return fname.replace("*", "_").replace(";", "_").replace(":", "_")
# Datasets can be specified in multiple ways
# to ServiceX. This datatype encapsulates them
# all.
DatasetType = Union[
str, # A single URI with a scheme (e.g. rucio://did_name, http://, root://)
Iterable[str], # A list of http:// or root:// files to be directly accessed
]
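# Illustrative examples of the two accepted forms (dataset and file names
# are made up):
#   "rucio://mc16_13TeV:some.dataset.name"
#   ["root://eospublic.cern.ch//file1.root", "root://eospublic.cern.ch//file2.root"]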
StatusUpdateCallback = Callable[[Optional[int], int, int, int], None]
# The signature of the callback. Arguments are:
# 1. Total number of files (None if not yet known)
# 2. Processed
# 3. Downloaded
# 4. Failed
StatusUpdateFactory = Callable[[DatasetType, Optional[str], bool], StatusUpdateCallback]
# Factory that builds the status update callback. Arguments are the dataset,
# an optional title, and True if a download progress bar can be displayed
# as well.
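# Illustrative sketch of a minimal factory matching StatusUpdateFactory: it
# ignores its arguments and returns a callback that just prints the counts.
# The function names are made up for the example.
def _example_print_factory(
    ds_name: DatasetType, title: Optional[str], downloading: bool
) -> StatusUpdateCallback:
    def _update(total: Optional[int], processed: int, downloaded: int, failed: int) -> None:
        print(f"total={total} processed={processed} downloaded={downloaded} failed={failed}")

    return _update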
class _status_update_wrapper:
"""
Internal class to make it easier to deal with the updater
"""
def __init__(self, callback: Optional[StatusUpdateCallback] = None):
self._callback = callback
self.reset()
def reset(self):
"Reset back to zero"
self._total: Optional[int] = None
self._processed: Optional[int] = None
self._downloaded: Optional[int] = None
self._failed: int = 0
self._remaining: Optional[int] = None
@property
def total(self) -> Optional[int]:
return self._total
@property
def failed(self) -> int:
return self._failed
def broadcast(self):
"Send an update back to the system"
if self._callback is not None:
processed = self._processed if self._processed is not None else 0
downloaded = self._downloaded if self._downloaded is not None else 0
failed = self._failed if self._failed is not None else 0
self._callback(self._total, processed, downloaded, failed)
def _update_total(self):
"Update total number of files, without letting inconsistencies change things"
if (
self._processed is not None
and self._remaining is not None
and self._failed is not None
):
total = self._remaining + self._processed + self._failed
if self._total is None:
self._total = total
else:
if self._total < total:
self._total = total
def update(
self,
processed: Optional[int] = None,
downloaded: Optional[int] = None,
remaining: Optional[int] = None,
failed: Optional[int] = None,
):
if processed is not None:
self._processed = processed
if downloaded is not None:
self._downloaded = downloaded
if failed is not None:
self._failed = failed
if remaining is not None:
self._remaining = remaining
self._update_total()
def inc(self, downloaded: Optional[int] = None):
if downloaded is not None:
if self._downloaded is None:
self._downloaded = downloaded
else:
self._downloaded += downloaded
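# Illustrative sketch of driving the wrapper by hand (values made up); with
# these updates the callback is invoked with (10, 3, 2, 0).
def _example_drive_notifier() -> None:
    notifier = _status_update_wrapper(print)  # any StatusUpdateCallback will do
    notifier.update(processed=3, remaining=7, failed=0)
    notifier.inc(downloaded=2)
    notifier.broadcast()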
TransformTuple = Tuple[Optional[int], int, Optional[int]]
async def stream_status_updates(
stream: AsyncIterator[TransformTuple], notifier: _status_update_wrapper
):
"""
As transform status updates stream by, update the notifier with the new
values.
"""
async for p in stream:
remaining, processed, failed = p
notifier.update(processed=processed, failed=failed, remaining=remaining)
notifier.broadcast()
yield p
async def stream_unique_updates_only(stream: AsyncIterator[TransformTuple]):
"""
As status updates go by, only let through values that have changed
"""
last_p: Optional[TransformTuple] = None
async for p in stream:
if p != last_p:
last_p = p
yield p
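# Illustrative sketch: chaining the two stream helpers over a status stream
# (the source iterator and the function name are made up for the example).
async def _example_watch(status_stream: AsyncIterator[TransformTuple]) -> None:
    notifier = _status_update_wrapper()
    async for _ in stream_status_updates(stream_unique_updates_only(status_stream), notifier):
        pass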
def _run_default_wrapper(
ds_name: DatasetType, title: Optional[str], downloading: bool
) -> StatusUpdateCallback:
"""Create a feedback object for everyone to use to pass feedback to. Uses tqdm (default).
Args:
ds_name (str): The dataset name
downloading (bool): True if we will be downloading the files, not just processing them.
Returns:
StatusUpdateCallback: The updater callback
"""
return _default_wrapper_mgr(ds_name, title, downloading).update
def _null_progress_feedback(
ds_name: DatasetType, title: Optional[str], downloading: bool
) -> None:
"""
Internal routine to create a feedback object that does not
give anyone feedback!
"""
return None
def dataset_as_name(
sample_name: Optional[DatasetType],
title: Optional[str] = None,
max_len: Optional[int] = 20,
) -> str:
"""Returns a name that is suitable for a progress bar
Args:
sample_name (Optional[DatasetType]): The sample info
title (Optional[str]): The title specified with the request
max_len (Optional[int]): The maximum length of the title. None
means no length limit.
Returns:
str: Name to use for a progress bar
"""
if title is not None:
name = title
else:
if sample_name is not None:
name = (
sample_name
if isinstance(sample_name, str)
else f"[{', '.join(sample_name)}]"
)
else:
name = "<none>"
if (max_len is not None) and (len(name) > max_len):
name = name[0:max_len] + "..."
return name
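# For example (illustrative input): dataset_as_name("mc16_13TeV.Zee.DAOD_PHYS.e1234_p5678")
# returns "mc16_13TeV.Zee.DAOD_..." after truncation to the default 20 characters.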
# Hack to work around tqdm not understanding vscode notebooks (see #186).
if "VSCODE_PID" in os.environ.keys(): # pragma: no cover
from tqdm.notebook import tqdm
else:
from tqdm.auto import tqdm
class _default_wrapper_mgr:
"Default progress bar"
def __init__(
self,
sample_name: Optional[DatasetType] = None,
title: Optional[str] = None,
show_download_bar: bool = True,
):
self._tqdm_p: Optional[tqdm] = None
self._tqdm_d: Optional[tqdm] = None
self._show_download_progress = show_download_bar
self._name = dataset_as_name(sample_name, title)
def _init_tqdm(self):
if self._tqdm_p is not None:
return
self._tqdm_p = tqdm(
total=9e9,
desc=self._name,
unit="file",
leave=False,
dynamic_ncols=True,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]",
)
if self._show_download_progress:
self._tqdm_d = tqdm(
total=9e9,
desc=f" {self._name} Downloaded",
unit="file",
leave=False,
dynamic_ncols=True,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]",
)
def _update_bar(
self, bar: Optional[tqdm], total: Optional[int], num: int, failed: int
):
assert bar is not None, "Internal error - bar was not initialized"
if total is not None:
if bar.total != total:
# There is a bug in the tqdm library if running in a notebook
# See https://github.com/tqdm/tqdm/issues/688
# This forces us to do the reset, sadly.
old_num = bar.n # type: int
bar.reset(total)
bar.update(old_num)
bar.update(num - bar.n)
bar.refresh()
if total is not None and (num + failed) == total:
bar.close()
def update(
self, total: Optional[int], processed: int, downloaded: int, failed: int
):
self._init_tqdm()
self._update_bar(self._tqdm_p, total, processed, failed)
if self._show_download_progress:
self._update_bar(self._tqdm_d, total, downloaded, failed)
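# Illustrative sketch of the default progress factory in action (dataset name
# and function name are made up); it is not called anywhere in this module.
def _example_progress() -> None:
    update = _run_default_wrapper("rucio://some.dataset", None, True)
    update(4, 1, 0, 0)  # 1 of 4 files processed, none downloaded, none failed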
# Changes in the json that won't affect the result
_json_keys_to_ignore_for_hash = ["workers", "title"]
def _query_cache_hash(json_query: Dict[str, str]) -> str:
"""
Return a hash of the query suitable for use as a cache key.
"""
hasher = blake2b(digest_size=20)
for k, v in json_query.items():
if k not in _json_keys_to_ignore_for_hash:
if k == "selection":
v = clean_linq(v)
hasher.update(k.encode())
hasher.update(str(v).encode())
hash = hasher.hexdigest()
return hash
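# Illustrative sketch: two queries differing only in an ignored key ("title")
# produce the same hash (query contents are made up).
def _example_same_hash() -> bool:
    q1 = {"did": "rucio://some.dataset", "selection": "(call EventDataset)", "title": "run 1"}
    q2 = {"did": "rucio://some.dataset", "selection": "(call EventDataset)", "title": "run 2"}
    return _query_cache_hash(q1) == _query_cache_hash(q2)  # True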
def _string_hash(s_list: Iterable[Union[str, Iterable[str]]]) -> str:
"""
Return a hash for an input list of strings.
"""
hasher = blake2b(digest_size=20)
for v in s_list:
if isinstance(v, str):
hasher.update(v.encode())
else:
for v1 in v:
hasher.update(v1.encode())
hash = hasher.hexdigest()
return hash
def clean_linq(q: str) -> str:
"""
Normalize the variables in a qastle expression. Should make the
linq expression more easily comparable even if the algorithm that
generates the underlying variable numbers changes. If the selection string is
not valid qastle, it is returned unaltered and no error is raised.
Arguments
q String containing the qastle code
Returns
clean_q Sanitized query - lambda arguments are given a uniform source code
labeling so that two queries with the same structure are marked as the same.
"""
from collections import namedtuple
ParseTracker = namedtuple("ParseTracker", ["text", "info"])
arg_index = 0
def new_arg():
nonlocal arg_index
s = f"a{arg_index}"
arg_index += 1
return s
def _replace_strings(s: str, rep: Dict[str, str]) -> str:
"Replace all strings in dict that are found in s"
for k in rep.keys():
s = re.sub(f"\\b{k}\\b", rep[k], s)
return s
class translator(Transformer):
def record(self, children):
if (
len(children) == 0
or isinstance(children[0], Token)
and children[0].type == "WHITESPACE"
): # type: ignore
return ""
else:
return children[0].text # type: ignore
def expression(self, children):
for child in children:
if isinstance(child, ParseTracker):
return child
raise SyntaxError("Expression does not contain a node")
def atom(self, children):
child = children[0]
return ParseTracker(child.value, child.value)
def composite(self, children):
fields = []
node_type: Optional[str] = None
for child in children:
if isinstance(child, Token):
if child.type == "NODE_TYPE": # type: ignore
node_type = child.value # type: ignore
elif isinstance(child, ParseTracker):
fields.append(child)
assert node_type is not None
if node_type == "lambda":
if len(fields) != 2:
raise ServiceXException(
f'The qastle "{q}" is not valid - found a lambda '
f"expression with {len(fields)} arguments - not the required 2!"
)
arg_list = [f.info for f in fields[0].info]
arg_mapping = {old: new_arg() for old in arg_list}
fields[0] = ParseTracker(
f'(list {" ".join(arg_mapping[k] for k in arg_list)})',
[ParseTracker(arg_mapping[f], arg_mapping[f]) for f in arg_list],
)
fields[1] = ParseTracker(
_replace_strings(fields[1].text, arg_mapping), fields[1].info
)
return ParseTracker(
f'({node_type} {" ".join([f.text for f in fields])})', fields
)
try:
tree = parse(q)
new_tree = translator().transform(tree)
assert isinstance(new_tree, str)
return new_tree
except Exception:
return q
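# Illustrative sketch: two structurally identical lambdas that differ only in
# argument naming should normalize to the same text (queries are made up, and
# the comparison assumes both parse as valid qastle).
def _example_clean_linq_equal() -> bool:
    q1 = "(lambda (list e) (attr e 'pt'))"
    q2 = "(lambda (list evt) (attr evt 'pt'))"
    return clean_linq(q1) == clean_linq(q2)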
# Below is copied and very lightly modified from the backoff library
# See https://github.com/litl/backoff/issues/131
from backoff import full_jitter # noqa: E402
from backoff._common import ( # noqa: E402
_prepare_logger,
_config_handlers,
_log_giveup,
_log_backoff,
_maybe_call,
_init_wait_gen,
_next_wait,
)
import logging # noqa: E402
from backoff._async import ( # noqa: E402
_ensure_coroutines,
_ensure_coroutine,
_call_handlers,
)
import asyncio # noqa: E402
import functools # noqa: E402
import datetime # noqa: E402
def retry_exception_itr(
target,
wait_gen,
exception,
max_tries,
max_time,
jitter,
giveup,
on_success,
on_backoff,
on_giveup,
wait_gen_kwargs,
):
on_success = _ensure_coroutines(on_success)
on_backoff = _ensure_coroutines(on_backoff)
on_giveup = _ensure_coroutines(on_giveup)
giveup = _ensure_coroutine(giveup)
# Easy to implement, please report if you need this.
assert not asyncio.iscoroutinefunction(max_tries)
assert not asyncio.iscoroutinefunction(jitter)
@functools.wraps(target)
async def retry(*args, **kwargs):
# change names because python 2.x doesn't have nonlocal
max_tries_ = _maybe_call(max_tries)
max_time_ = _maybe_call(max_time)
tries = 0
start = datetime.datetime.now()
wait = _init_wait_gen(wait_gen, wait_gen_kwargs)
while True:
tries += 1
elapsed = timedelta.total_seconds(datetime.datetime.now() - start)
details = {
"target": target,
"args": args,
"kwargs": kwargs,
"tries": tries,
"elapsed": elapsed,
}
got_one_item = False
try:
async for item in target(*args, **kwargs):
got_one_item = True
yield item
except exception as e:
# If we've already fed a result out of this method,
# we can't pull it back. So don't try to pull back/retry
# the exception either.
if got_one_item:
raise
giveup_result = await giveup(e)
max_tries_exceeded = tries == max_tries_
max_time_exceeded = max_time_ is not None and elapsed >= max_time_
if giveup_result or max_tries_exceeded or max_time_exceeded:
await _call_handlers(on_giveup, **details)
raise
try:
seconds = _next_wait(
wait, giveup_result, jitter, elapsed, max_time_
)
except StopIteration: # pragma: no cover
await _call_handlers(on_giveup, **details)
raise e
await _call_handlers(on_backoff, **details, wait=seconds)
# Note: there is no convenient way to pass explicit event
# loop to decorator, so here we assume that either default
# thread event loop is set and correct (it mostly is
# by default), or Python >= 3.5.3 or Python >= 3.6 is used
# where loop.get_event_loop() in coroutine guaranteed to
# return correct value.
# See for details:
# <https://groups.google.com/forum/#!topic/python-tulip/yF9C-rFpiKk>
# <https://bugs.python.org/issue28613>
await asyncio.sleep(seconds)
else:
await _call_handlers(on_success, **details)
return
return retry
def on_exception_itr(
wait_gen,
exception,
max_tries=None,
max_time=None,
jitter=full_jitter,
giveup=lambda e: False,
on_success=None,
on_backoff=None,
on_giveup=None,
logger="backoff",
backoff_log_level=logging.INFO,
giveup_log_level=logging.ERROR,
**wait_gen_kwargs,
):
def decorate(target):
# change names because python 2.x doesn't have nonlocal
logger_ = _prepare_logger(logger)
on_success_ = _config_handlers(on_success)
on_backoff_ = _config_handlers(
on_backoff,
default_handler=_log_backoff,
logger=logger_,
log_level=backoff_log_level,
)
on_giveup_ = _config_handlers(
on_giveup,
default_handler=_log_giveup,
logger=logger_,
log_level=giveup_log_level,
)
return retry_exception_itr(
target,
wait_gen,
exception,
max_tries,
max_time,
jitter,
giveup,
on_success_,
on_backoff_,
on_giveup_,
wait_gen_kwargs,
)
# Return a function which decorates a target with a retry loop.
return decorate
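# Illustrative sketch: retrying a flaky async generator with exponential
# backoff. The generator, URL, and retried exception are made up for the
# example; backoff.expo is the standard exponential wait generator.
from backoff import expo  # noqa: E402

@on_exception_itr(expo, aiohttp.ClientError, max_tries=3)
async def _example_flaky_stream() -> AsyncIterator[int]:
    session = await default_client_session()
    async with session.get("https://example.invalid/status") as response:
        yield response.status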