Rewrite/fix apply() implementation
This commit is contained in:
parent
082236b6fd
commit
6fda991700
@ -14,7 +14,7 @@ import numpy
|
|||||||
from .subpattern import SubPattern
|
from .subpattern import SubPattern
|
||||||
from .shapes import Shape, Polygon
|
from .shapes import Shape, Polygon
|
||||||
from .utils import rotation_matrix_2d, vector2
|
from .utils import rotation_matrix_2d, vector2
|
||||||
|
from .error import PatternError
|
||||||
|
|
||||||
__author__ = 'Jan Petykiewicz'
|
__author__ = 'Jan Petykiewicz'
|
||||||
|
|
||||||
@ -105,33 +105,37 @@ class Pattern:
|
|||||||
return pat
|
return pat
|
||||||
|
|
||||||
def apply(self,
|
def apply(self,
|
||||||
func: Callable[['Pattern'], 'Pattern']
|
func: Callable[['Pattern'], 'Pattern'],
|
||||||
|
memo: Dict[int, 'Pattern']=None,
|
||||||
) -> 'Pattern':
|
) -> 'Pattern':
|
||||||
"""
|
"""
|
||||||
Recursively apply func() to this pattern and any pattern it references.
|
Recursively apply func() to this pattern and any pattern it references.
|
||||||
func() is expected to take and return a Pattern.
|
func() is expected to take and return a Pattern.
|
||||||
func() is first applied to the pattern as a whole, then the referenced patterns.
|
func() is first applied to the pattern as a whole, then any referenced patterns.
|
||||||
It is only applied to any given pattern once, regardless of how many times it is
|
It is only applied to any given pattern once, regardless of how many times it is
|
||||||
referenced.
|
referenced.
|
||||||
|
|
||||||
:param func: Function which accepts a Pattern, and returns a pattern.
|
:param func: Function which accepts a Pattern, and returns a pattern.
|
||||||
|
:param memo: Dictionary used to avoid re-running on multiply-referenced patterns.
|
||||||
|
Stores {id(pattern): func(pattern)} for patterns which have already been processed.
|
||||||
|
Default None (no already-processed patterns).
|
||||||
:return: The result of applying func() to this pattern and all subpatterns.
|
:return: The result of applying func() to this pattern and all subpatterns.
|
||||||
:raises: PatternError if called on a pattern containing a circular reference.
|
:raises: PatternError if called on a pattern containing a circular reference.
|
||||||
"""
|
"""
|
||||||
pat_map = {id(self): None}
|
if memo is None:
|
||||||
pat = func(self)
|
memo = {}
|
||||||
pat_map[id(self)] = pat
|
|
||||||
|
|
||||||
|
pat_id = id(self)
|
||||||
|
if pat_id not in memo:
|
||||||
|
memo[pat_id] = None
|
||||||
|
pat = func(self)
|
||||||
for subpat in pat.subpatterns:
|
for subpat in pat.subpatterns:
|
||||||
ref_pat_id = id(subpat.pattern)
|
subpat.pattern = subpat.pattern.apply(func, memo)
|
||||||
if ref_pat_id not in pat_map:
|
memo[pat_id] = pat
|
||||||
pat_map[ref_pat_id] = None
|
elif memo[pat_id] is None:
|
||||||
subpat.pattern = subpat.pattern.apply(func)
|
|
||||||
pat_map[ref_pat_id] = subpat.pattern
|
|
||||||
elif pat_map[ref_pat_id] is None:
|
|
||||||
raise PatternError('.apply() called on pattern with circular reference')
|
raise PatternError('.apply() called on pattern with circular reference')
|
||||||
else:
|
else:
|
||||||
subpat.pattern = pat_map[ref_pat_id]
|
pat = memo[pat_id]
|
||||||
return pat
|
return pat
|
||||||
|
|
||||||
def polygonize(self,
|
def polygonize(self,
|
||||||
|
Loading…
Reference in New Issue
Block a user