Make flatten() work in-place on all subpatterns (avoid copies and repeated calls). Also fix a bug around identifier generation introduced in a5900f6ad.

This commit is contained in:
jan 2021-03-22 11:29:50 -07:00
parent 8ad4082f6d
commit 5b84a436a0

View File

@ -545,6 +545,8 @@ class Pattern(LockableImpl, AnnotatableImpl, Mirrorable, metaclass=AutoSlots):
def flatten(self: P) -> P: def flatten(self: P) -> P:
""" """
Removes all subpatterns and adds equivalent shapes. Removes all subpatterns and adds equivalent shapes.
Also flattens all subpatterns.
Modifies patterns in-place.
Shape/label identifiers are changed to represent their original position in the Shape/label identifiers are changed to represent their original position in the
pattern hierarchy: pattern hierarchy:
@ -555,23 +557,30 @@ class Pattern(LockableImpl, AnnotatableImpl, Mirrorable, metaclass=AutoSlots):
Returns: Returns:
self self
""" """
subpatterns = copy.deepcopy(self.subpatterns) def flatten_single(pat: P, flattened: Set[P]) -> P:
self.subpatterns = []
# Update identifiers so each shape has a unique one # Update identifiers so each shape has a unique one
for ss, shape in enumerate(self.shapes): for ss, shape in enumerate(pat.shapes):
shape.identifier = (ss,) + shape.identifier shape.identifier = (ss,) + shape.identifier
for ll, label in enumerate(self.labels): for ll, label in enumerate(pat.labels):
label.identifier = (ll,) + label.identifier label.identifier = (ll,) + label.identifier
for pp, subpat in enumerate(subpatterns): for pp, subpat in enumerate(pat.subpatterns):
if subpat.pattern is None: if subpat.pattern is None:
continue continue
subpat.pattern.flatten()
if subpat.pattern not in flattened:
flatten_single(subpat.pattern, flattened)
flattened.add(subpat.pattern)
p = subpat.as_pattern() p = subpat.as_pattern()
for item in chain(p.shapes, p.labels): for item in chain(p.shapes, p.labels):
item.identifier += (pp,) + item.identifier item.identifier = (pp,) + item.identifier
self.append(p) pat.append(p)
pat.subpatterns = []
return pat
flatten_single(self, set())
return self return self
def wrap_repeated_shapes(self: P, def wrap_repeated_shapes(self: P,