From 5b84a436a08e87867da7dedbb029c4e5cdb6b0ca Mon Sep 17 00:00:00 2001 From: jan Date: Mon, 22 Mar 2021 11:29:50 -0700 Subject: [PATCH] Make flatten() work in-place on all subpatterns (avoid copies and repeated calls). Also fix a bug around identifier generation introduced in a5900f6ad. --- masque/pattern.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/masque/pattern.py b/masque/pattern.py index aa585a3..1a54d48 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -545,6 +545,8 @@ class Pattern(LockableImpl, AnnotatableImpl, Mirrorable, metaclass=AutoSlots): def flatten(self: P) -> P: """ Removes all subpatterns and adds equivalent shapes. + Also flattens all subpatterns. + Modifies patterns in-place. Shape/label identifiers are changed to represent their original position in the pattern hierarchy: @@ -555,23 +557,30 @@ class Pattern(LockableImpl, AnnotatableImpl, Mirrorable, metaclass=AutoSlots): Returns: self """ - subpatterns = copy.deepcopy(self.subpatterns) - self.subpatterns = [] + def flatten_single(pat: P, flattened: Set[P]) -> P: + # Update identifiers so each shape has a unique one + for ss, shape in enumerate(pat.shapes): + shape.identifier = (ss,) + shape.identifier + for ll, label in enumerate(pat.labels): + label.identifier = (ll,) + label.identifier - # Update identifiers so each shape has a unique one - for ss, shape in enumerate(self.shapes): - shape.identifier = (ss,) + shape.identifier - for ll, label in enumerate(self.labels): - label.identifier = (ll,) + label.identifier + for pp, subpat in enumerate(pat.subpatterns): + if subpat.pattern is None: + continue - for pp, subpat in enumerate(subpatterns): - if subpat.pattern is None: - continue - subpat.pattern.flatten() - p = subpat.as_pattern() - for item in chain(p.shapes, p.labels): - item.identifier += (pp,) + item.identifier - self.append(p) + if subpat.pattern not in flattened: + flatten_single(subpat.pattern, flattened) + flattened.add(subpat.pattern) + + p = subpat.as_pattern() + for item in chain(p.shapes, p.labels): + item.identifier = (pp,) + item.identifier + pat.append(p) + + pat.subpatterns = [] + return pat + + flatten_single(self, set()) return self def wrap_repeated_shapes(self: P,