diff --git a/masque/error.py b/masque/error.py index 3bac722..0e46849 100644 --- a/masque/error.py +++ b/masque/error.py @@ -30,3 +30,10 @@ class PortError(MasqueError): Exception raised by builder-related functions """ pass + +class OneShotError(MasqueError): + """ + Exception raised when a function decorated with `@oneshot` is called more than once + """ + def __init__(self, func_name: str) -> None: + Exception.__init__(self, f'Function "{func_name}" with @oneshot was called more than once') diff --git a/masque/utils/__init__.py b/masque/utils/__init__.py index 1e04a35..4df9f9a 100644 --- a/masque/utils/__init__.py +++ b/masque/utils/__init__.py @@ -5,6 +5,7 @@ from .types import layer_t, annotations_t, SupportsBool from .array import is_scalar from .autoslots import AutoSlots from .deferreddict import DeferredDict +from .decorators import oneshot from .bitwise import get_bit, set_bit from .vertices import ( diff --git a/masque/utils/decorators.py b/masque/utils/decorators.py new file mode 100644 index 0000000..c212cdd --- /dev/null +++ b/masque/utils/decorators.py @@ -0,0 +1,21 @@ +from typing import Callable +from functools import wraps + +from ..error import OneShotError + + +def oneshot(func: Callable) -> Callable: + """ + Raises a OneShotError if the decorated function is called more than once + """ + expired = False + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal expired + if expired: + raise OneShotError(func.__name__) + expired = True + return func(*args, **kwargs) + + return wrapper