Rewrite/fix apply() implementation

This commit is contained in:
jan 2018-04-15 19:25:42 -07:00
parent 082236b6fd
commit 6fda991700

View File

@ -14,7 +14,7 @@ import numpy
from .subpattern import SubPattern from .subpattern import SubPattern
from .shapes import Shape, Polygon from .shapes import Shape, Polygon
from .utils import rotation_matrix_2d, vector2 from .utils import rotation_matrix_2d, vector2
from .error import PatternError
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'
@ -105,33 +105,37 @@ class Pattern:
return pat return pat
def apply(self, def apply(self,
func: Callable[['Pattern'], 'Pattern'] func: Callable[['Pattern'], 'Pattern'],
memo: Dict[int, 'Pattern']=None,
) -> 'Pattern': ) -> 'Pattern':
""" """
Recursively apply func() to this pattern and any pattern it references. Recursively apply func() to this pattern and any pattern it references.
func() is expected to take and return a Pattern. func() is expected to take and return a Pattern.
func() is first applied to the pattern as a whole, then the referenced patterns. func() is first applied to the pattern as a whole, then any referenced patterns.
It is only applied to any given pattern once, regardless of how many times it is It is only applied to any given pattern once, regardless of how many times it is
referenced. referenced.
:param func: Function which accepts a Pattern, and returns a pattern. :param func: Function which accepts a Pattern, and returns a pattern.
:param memo: Dictionary used to avoid re-running on multiply-referenced patterns.
Stores {id(pattern): func(pattern)} for patterns which have already been processed.
Default None (no already-processed patterns).
:return: The result of applying func() to this pattern and all subpatterns. :return: The result of applying func() to this pattern and all subpatterns.
:raises: PatternError if called on a pattern containing a circular reference. :raises: PatternError if called on a pattern containing a circular reference.
""" """
pat_map = {id(self): None} if memo is None:
pat = func(self) memo = {}
pat_map[id(self)] = pat
pat_id = id(self)
if pat_id not in memo:
memo[pat_id] = None
pat = func(self)
for subpat in pat.subpatterns: for subpat in pat.subpatterns:
ref_pat_id = id(subpat.pattern) subpat.pattern = subpat.pattern.apply(func, memo)
if ref_pat_id not in pat_map: memo[pat_id] = pat
pat_map[ref_pat_id] = None elif memo[pat_id] is None:
subpat.pattern = subpat.pattern.apply(func)
pat_map[ref_pat_id] = subpat.pattern
elif pat_map[ref_pat_id] is None:
raise PatternError('.apply() called on pattern with circular reference') raise PatternError('.apply() called on pattern with circular reference')
else: else:
subpat.pattern = pat_map[ref_pat_id] pat = memo[pat_id]
return pat return pat
def polygonize(self, def polygonize(self,