Source code for deeprob.context

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

import contextlib
import contextvars

#: Thread-safe context variables, i.e. each thread has its own flags assignments.
#: When nothing has been set in the current context, the defaults below are returned.
_context_variables = contextvars.ContextVar(
    'context_variables',
    default=dict(check_dtype=True, check_spn=True),
)


def is_check_dtype_enabled() -> bool:
    """Tell whether the 'check_dtype' context flag is set in the current context.

    :return: True when Numpy arrays data types are to be checked (and cast when needed).
    """
    current_flags = _context_variables.get()
    return current_flags['check_dtype']
def is_check_spn_enabled() -> bool:
    """Tell whether the 'check_spn' context flag is set in the current context.

    :return: True when the SPNs structure properties are to be checked.
    """
    current_flags = _context_variables.get()
    return current_flags['check_spn']
class ContextState(contextlib.ContextDecorator):
    """
    Thread-safe Context State that overrides some flags during execution.

    It can be used both as a context manager (``with ContextState(...):``) and as a
    function decorator (``@ContextState(...)``).

    Current supported flags are the following:

    - check_dtype: bool = True, Whether to check (and cast when needed) Numpy arrays data types.
    - check_spn: bool = True, Whether to check the SPNs structure properties.

    Only the flags passed as keyword arguments are overridden. Every other flag keeps the
    value it has in the context that is active when the state is *entered*, not when the
    instance is created, so nested states compose correctly.

    The same instance may be entered more than once. This covers nested ``with`` blocks,
    recursive calls of a decorated function, and concurrent calls from different threads,
    because the reset tokens are kept per context rather than on the instance.
    """

    #: Per-context LIFO stack of the tokens returned by ``_context_variables.set``.
    #: It lives in a ContextVar (not on the instance) so that a shared instance, e.g. one
    #: created by a decorator, never resets a token that belongs to another thread/call.
    __tokens = contextvars.ContextVar('context_state_tokens', default=())

    def __init__(self, **kwargs):
        """
        Build the context state.

        :param kwargs: The flags to override, mapped to their new values.
        :raises ValueError: If one of the given flags is unknown.
        """
        known_flags = _context_variables.get()
        for flag in kwargs:
            if flag not in known_flags:
                raise ValueError("Cannot set an unknown flag called '{}', suitable flags are: {}".format(
                    flag, ', '.join(known_flags.keys())
                ))
        # Store only the overrides; they are merged onto the current flags in __enter__.
        self.__overrides = dict(kwargs)

    def __enter__(self):
        """
        Activate the overridden flags on top of the currently active ones.

        :return: The context state itself.
        """
        # Copy, never mutate, the current flags: the mapping may be the shared default.
        state = _context_variables.get().copy()
        state.update(self.__overrides)
        token = _context_variables.set(state)
        self.__tokens.set(self.__tokens.get() + (token,))
        return self

    def __exit__(self, *exc):
        """
        Restore the flags that were active before the matching __enter__.

        Exceptions raised inside the managed block are not suppressed.

        :raises RuntimeError: If called without a matching __enter__ in this context.
        """
        tokens = self.__tokens.get()
        if not tokens:
            raise RuntimeError("ContextState.__exit__() called without a matching __enter__()")
        self.__tokens.set(tokens[:-1])
        _context_variables.reset(tokens[-1])