From 6fda9917002dacd1ec114147318eb5b2daa1e12a Mon Sep 17 00:00:00 2001 From: jan Date: Sun, 15 Apr 2018 19:25:42 -0700 Subject: [PATCH] Rewrite/fix apply() implementation --- masque/pattern.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/masque/pattern.py b/masque/pattern.py index 1ee77da..5f9605b 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -14,7 +14,7 @@ import numpy from .subpattern import SubPattern from .shapes import Shape, Polygon from .utils import rotation_matrix_2d, vector2 - +from .error import PatternError __author__ = 'Jan Petykiewicz' @@ -105,33 +105,37 @@ class Pattern: return pat def apply(self, - func: Callable[['Pattern'], 'Pattern'] + func: Callable[['Pattern'], 'Pattern'], + memo: Dict[int, 'Pattern']=None, ) -> 'Pattern': """ Recursively apply func() to this pattern and any pattern it references. func() is expected to take and return a Pattern. - func() is first applied to the pattern as a whole, then the referenced patterns. + func() is first applied to the pattern as a whole, then any referenced patterns. It is only applied to any given pattern once, regardless of how many times it is referenced. :param func: Function which accepts a Pattern, and returns a pattern. + :param memo: Dictionary used to avoid re-running on multiply-referenced patterns. + Stores {id(pattern): func(pattern)} for patterns which have already been processed. + Default None (no already-processed patterns). :return: The result of applying func() to this pattern and all subpatterns. :raises: PatternError if called on a pattern containing a circular reference. """ - pat_map = {id(self): None} - pat = func(self) - pat_map[id(self)] = pat + if memo is None: + memo = {} - for subpat in pat.subpatterns: - ref_pat_id = id(subpat.pattern) - if ref_pat_id not in pat_map: - pat_map[ref_pat_id] = None - subpat.pattern = subpat.pattern.apply(func) - pat_map[ref_pat_id] = subpat.pattern - elif pat_map[ref_pat_id] is None: - raise PatternError('.apply() called on pattern with circular reference') - else: - subpat.pattern = pat_map[ref_pat_id] + pat_id = id(self) + if pat_id not in memo: + memo[pat_id] = None + pat = func(self) + for subpat in pat.subpatterns: + subpat.pattern = subpat.pattern.apply(func, memo) + memo[pat_id] = pat + elif memo[pat_id] is None: + raise PatternError('.apply() called on pattern with circular reference') + else: + pat = memo[pat_id] return pat def polygonize(self,