Source code for returns.trampolines

from functools import wraps
from typing import Callable, Generic, TypeVar, Union, final

from typing_extensions import ParamSpec

# Type of the value produced once a call chain is fully evaluated:
# what ``Trampoline.__call__`` is annotated to return and what a function
# decorated with ``trampoline`` hands back to its caller.
_ReturnType = TypeVar('_ReturnType')
# Parameter specification of the wrapped callable; lets ``*args`` / ``**kwargs``
# be annotated as ``_FuncParams.args`` / ``_FuncParams.kwargs`` so the
# original signature is carried through the wrappers.
_FuncParams = ParamSpec('_FuncParams')


@final
class Trampoline(Generic[_ReturnType]):
    """
    A function call captured as data instead of being executed right away.

    An instance remembers a callable together with the positional and
    keyword arguments it should receive. Invoking the instance performs
    that call. This is the building block that lets a recursive call be
    returned as an object rather than made directly.
    """

    __slots__ = ('func', 'args', 'kwargs')

    def __init__(  # noqa: WPS451
        self,
        func: Callable[_FuncParams, _ReturnType],
        /,  # Positional-only, so a keyword named ``func`` still lands in `kwargs`.
        *args: _FuncParams.args,
        **kwargs: _FuncParams.kwargs,
    ) -> None:
        """
        Remember the callable and the arguments for a later call.

        When ``func`` exposes an ``_orig_func`` attribute, that attribute's
        value is stored in place of ``func`` itself; otherwise ``func`` is
        stored unchanged.
        """
        # Prefer the undecorated callable when one is attached to ``func``.
        undecorated = getattr(func, '_orig_func', func)
        self.func = undecorated
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> _ReturnType:
        """Perform the stored call and return whatever it produces."""
        target = self.func
        return target(*self.args, **self.kwargs)
def trampoline(
    func: Callable[_FuncParams, Union[_ReturnType, Trampoline[_ReturnType]]],
) -> Callable[_FuncParams, _ReturnType]:
    """
    Turn a recursive function into one driven by a plain loop.

    Instead of calling itself, the decorated function returns a
    :class:`returns.trampolines.Trampoline` describing the next call.
    The wrapper keeps invoking those objects in a ``while`` loop until
    something that is not a ``Trampoline`` comes back, and returns that
    value. Because each step is run from the loop rather than from inside
    the previous step, the recursion is not expressed as nested calls.

    Python has no TCO (tail call optimization), hence this helper.

    .. code:: python

      >>> from typing import Union
      >>> from returns.trampolines import Trampoline, trampoline

      >>> @trampoline
      ... def get_factorial(
      ...     for_number: int,
      ...     current_number: int = 0,
      ...     acc: int = 1,
      ... ) -> Union[int, Trampoline[int]]:
      ...     assert for_number >= 0
      ...     if for_number <= current_number:
      ...         return acc
      ...     return Trampoline(
      ...         get_factorial,
      ...         for_number,
      ...         current_number=current_number + 1,
      ...         acc=acc * (current_number + 1),
      ...     )

      >>> assert get_factorial(0) == 1
      >>> assert get_factorial(3) == 6
      >>> assert get_factorial(4) == 24

    See also:
        - eli.thegreenplace.net/2017/on-recursion-continuations-and-trampolines
        - https://en.wikipedia.org/wiki/Tail_call

    """

    @wraps(func)
    def decorator(
        *args: _FuncParams.args,
        **kwargs: _FuncParams.kwargs,
    ) -> _ReturnType:
        # First step: run the undecorated function with the caller's arguments.
        outcome = func(*args, **kwargs)
        while True:
            if not isinstance(outcome, Trampoline):
                # A non-``Trampoline`` value is the final answer.
                return outcome
            # Otherwise perform the deferred call and inspect its result.
            outcome = outcome()

    # Expose the raw function so ``Trampoline`` can store it instead of
    # this wrapper (it looks the attribute up via ``getattr``).
    decorator._orig_func = func  # type: ignore[attr-defined]  # noqa: WPS437
    return decorator