Rewrite/fix apply() implementation

This commit is contained in:
jan 2018-04-15 19:25:42 -07:00
parent 082236b6fd
commit 6fda991700

View File

@ -14,7 +14,7 @@ import numpy
from .subpattern import SubPattern from .subpattern import SubPattern
from .shapes import Shape, Polygon from .shapes import Shape, Polygon
from .utils import rotation_matrix_2d, vector2 from .utils import rotation_matrix_2d, vector2
from .error import PatternError
__author__ = 'Jan Petykiewicz' __author__ = 'Jan Petykiewicz'
@ -105,33 +105,37 @@ class Pattern:
return pat return pat
def apply(self, def apply(self,
func: Callable[['Pattern'], 'Pattern'] func: Callable[['Pattern'], 'Pattern'],
memo: Dict[int, 'Pattern']=None,
) -> 'Pattern': ) -> 'Pattern':
""" """
Recursively apply func() to this pattern and any pattern it references. Recursively apply func() to this pattern and any pattern it references.
func() is expected to take and return a Pattern. func() is expected to take and return a Pattern.
func() is first applied to the pattern as a whole, then the referenced patterns. func() is first applied to the pattern as a whole, then any referenced patterns.
It is only applied to any given pattern once, regardless of how many times it is It is only applied to any given pattern once, regardless of how many times it is
referenced. referenced.
:param func: Function which accepts a Pattern, and returns a pattern. :param func: Function which accepts a Pattern, and returns a pattern.
:param memo: Dictionary used to avoid re-running on multiply-referenced patterns.
Stores {id(pattern): func(pattern)} for patterns which have already been processed.
Default None (no already-processed patterns).
:return: The result of applying func() to this pattern and all subpatterns. :return: The result of applying func() to this pattern and all subpatterns.
:raises: PatternError if called on a pattern containing a circular reference. :raises: PatternError if called on a pattern containing a circular reference.
""" """
pat_map = {id(self): None} if memo is None:
pat = func(self) memo = {}
pat_map[id(self)] = pat
for subpat in pat.subpatterns: pat_id = id(self)
ref_pat_id = id(subpat.pattern) if pat_id not in memo:
if ref_pat_id not in pat_map: memo[pat_id] = None
pat_map[ref_pat_id] = None pat = func(self)
subpat.pattern = subpat.pattern.apply(func) for subpat in pat.subpatterns:
pat_map[ref_pat_id] = subpat.pattern subpat.pattern = subpat.pattern.apply(func, memo)
elif pat_map[ref_pat_id] is None: memo[pat_id] = pat
raise PatternError('.apply() called on pattern with circular reference') elif memo[pat_id] is None:
else: raise PatternError('.apply() called on pattern with circular reference')
subpat.pattern = pat_map[ref_pat_id] else:
pat = memo[pat_id]
return pat return pat
def polygonize(self, def polygonize(self,