from itertools import groupby, product
from operator import itemgetter
from typing import Iterator, List, Optional, Tuple, cast, final
from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, FunctionLike, Overloaded
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from returns.contrib.mypy._structures.args import FuncArg
from returns.contrib.mypy._typeops.transform_callable import (
Intermediate,
proper_type,
)
#: Raw material to build `_ArgTree`: one entry per way of partitioning
#: the function's arguments (or the not-yet-consumed rest of it), where
#: every entry is the ordered list of argument slices of that partition.
_RawArgTree = List[
    List[List[FuncArg]]
]
def analyze(ctx: FunctionContext) -> MypyType:
    """
    Returns proper type for curried functions.

    Args:
        ctx: function context of the ``curry(...)`` call.

    Returns:
        ``ctx.default_return_type`` when the first argument or the default
        return type is not a ``CallableType`` (or when there is no argument
        at all), otherwise the type produced by
        ``_CurryFunctionOverloads.build_overloads``.
    """
    # Bail out gracefully instead of raising `IndexError` inside the plugin
    # when `curry` is called without any arguments.
    if not ctx.arg_types or not ctx.arg_types[0]:
        return ctx.default_return_type

    function_def = ctx.arg_types[0][0]
    if not isinstance(function_def, CallableType):
        return ctx.default_return_type
    if not isinstance(ctx.default_return_type, CallableType):
        return ctx.default_return_type

    return _CurryFunctionOverloads(function_def, ctx).build_overloads()
@final
class _ArgTree:
    """
    Represents a node in tree of arguments.

    ``case`` is the callable built for this node (``None`` for the root
    node), and ``children`` are the nodes for the argument groups that may
    follow it.
    """

    def __init__(self, case: Optional[CallableType]) -> None:
        self.children: List['_ArgTree'] = list()
        self.case = case
@final
class _CurryFunctionOverloads:
    """
    Implementation of ``@curry`` decorator typings.

    Basically does just two things:

    1. Creates all possible ordered combinations of arguments
    2. Creates ``Overload`` instances for functions' return types

    """

    def __init__(self, original: CallableType, ctx: FunctionContext) -> None:
        """
        Saving the things we need.

        Args:
            original: original function that was passed to ``@curry``.
            ctx: function context.

        """
        self._original = original
        self._ctx = ctx
        self._overloads: List[CallableType] = []
        self._args = FuncArg.from_callable(self._original)

        # We need to get rid of generics here.
        # Because, otherwise `detach_callable` will add
        # unused variables to intermediate callables.
        # The `implementation_artifact` Any is also used as a placeholder
        # marker: `_build_overloads_from_argtree` checks for it to know
        # whether a node's return type has been filled in yet.
        self._default = cast(
            CallableType, self._ctx.default_return_type,
        ).copy_modified(
            ret_type=AnyType(TypeOfAny.implementation_artifact),
        )

    def build_overloads(self) -> MypyType:
        """
        Builds lots of possible overloads for a given function.

        Inside we try to represent all functions as sequence of arguments,
        grouped by the similar ones and returning one more overload instance.

        Returns:
            The original callable when it takes no arguments,
            the placeholder ``Any`` when it uses ``*args`` / ``**kwargs``,
            otherwise ``proper_type`` applied to all collected overloads.
        """
        if not self._args:  # There's nothing to do, function has 0 args.
            return self._original

        if any(arg.kind in {ARG_STAR, ARG_STAR2} for arg in self._args):
            # We don't support `*args` and `**kwargs`,
            # because it is very complex. It might be fixed in the future.
            return self._default.ret_type  # Any

        argtree = self._build_argtree(
            _ArgTree(None),  # starting from root node
            list(self._slices(self._args)),
        )
        self._build_overloads_from_argtree(argtree)
        return proper_type(self._overloads)

    def _build_argtree(
        self,
        node: _ArgTree,
        source: _RawArgTree,
    ) -> '_ArgTree':
        """
        Builds argument tree.

        Each argument can point to zero, one, or more other nodes.
        Arguments that have zero children are treated as bottom (last) ones.
        Arguments that have just one child are meant to be regular functions.
        Arguments that have more than one child are treated as overloads.

        Args:
            node: tree node to attach new children to.
            source: remaining slices of every partition that reaches ``node``.

        Returns:
            The same ``node``, populated in place.
        """
        def factory(
            args: _RawArgTree,
        ) -> Iterator[Tuple[List[FuncArg], _RawArgTree]]:
            if not args or not args[0]:
                return  # we have reached an end of arguments
            # `groupby` only merges *adjacent* equal keys. This relies on
            # `_slices` yielding partitions that share a leading slice
            # next to each other.
            yield from (
                (case, [group[1:] for group in grouped])
                for case, grouped in groupby(args, itemgetter(0))
            )

        for case, rest in factory(source):
            new_node = _ArgTree(
                Intermediate(self._default).with_signature(case),
            )
            node.children.append(new_node)
            self._build_argtree(source=rest, node=new_node)
        return node

    def _build_overloads_from_argtree(self, argtree: _ArgTree) -> None:
        """
        Generates functions from argument tree.

        Works bottom-up: children are processed before their parent.
        Leaf nodes get the original function's return type. Every non-root
        parent gets its children's callables as its return type, overloaded
        when there is more than one child. Children of the root are
        collected into ``self._overloads``.
        """
        for child in argtree.children:
            self._build_overloads_from_argtree(child)
            assert child.case  # mypy is not happy  # noqa: S101

            if not child.children:
                child.case = Intermediate(child.case).with_ret_type(
                    self._original.ret_type,
                )

            if argtree.case is not None:
                # We need to go backwards and to replace the return types
                # of the previous functions. Like so:
                # 1. `def x -> A`
                # 2. `def y -> A`
                # Will take `2` and apply its type to the previous function `1`.
                # Will result in `def x -> y -> A`
                # We also overloadify existing return types.
                # A still-placeholder `Any` (presumably inherited from
                # `self._default` via `with_signature`) means no child has
                # been attached yet, so the child replaces it outright.
                ret_type = argtree.case.ret_type
                temp_any = isinstance(
                    ret_type, AnyType,
                ) and ret_type.type_of_any == TypeOfAny.implementation_artifact
                argtree.case = Intermediate(argtree.case).with_ret_type(
                    child.case if temp_any else Overloaded(
                        [child.case, *cast(FunctionLike, ret_type).items],
                    ),
                )
            else:  # Root is reached, we need to save the result:
                self._overloads.append(child.case)

    def _slices(self, source: List[FuncArg]) -> Iterator[List[List[FuncArg]]]:
        """
        Generate all possible slices of a source list.

        Partitions sharing the same leading slice are yielded next to each
        other, which ``_build_argtree`` relies on because it uses ``groupby``.
        An empty ``source`` yields nothing.

        Example::

          _slices("AB") ->
            "A" "B"
            "AB"

          _slices("ABC") ->
            "A" "B" "C"
            "A" "BC"
            "AB" "C"
            "ABC"

        """
        if not source:
            # `product(..., repeat=-1)` would raise `ValueError`.
            return

        for doslice in product([True, False], repeat=len(source) - 1):
            slices = []
            start = 0
            for index, slicehere in enumerate(doslice, 1):
                if slicehere:
                    slices.append(source[start:index])
                    start = index
            slices.append(source[start:])
            yield slices