add subpatterns_by_id()

This commit is contained in:
Jan Petykiewicz 2020-05-23 19:39:03 -07:00
parent 1976c6e684
commit 07ee25e735

View File

@ -486,6 +486,32 @@ class Pattern:
pat_list = [(p.name if p is not None else None, p) for p in pats_by_id.values()] pat_list = [(p.name if p is not None else None, p) for p in pats_by_id.values()]
return pat_list return pat_list
def subpatterns_by_id(self,
include_none: bool = False,
recursive: bool = True,
) -> Dict[int, List[subpattern_t]]:
"""
Create a dictionary which maps `{id(referenced_pattern): [subpattern0, ...]}`
for all SubPattern objects referenced by this Pattern (by default, operates
recursively on all referenced Patterns as well).
Args:
include_none: If `True`, references to `None` will be included. Default `False`.
recursive: If `True`, operates recursively on all referenced patterns. Default `True`.
Returns:
Dictionary mapping each pattern id to a list of subpattern objects referencing the pattern.
"""
ids: Dict[int, List[subpattern_t]] = defaultdict(list)
for subpat in self.subpatterns:
pat = subpat.pattern
if include_none or pat is not None:
ids[id(pat)].append(subpat)
if recursive and pat is not None:
ids.update(pat.subpatterns_by_id(include_none=include_none))
return dict(ids)
def get_bounds(self) -> Union[numpy.ndarray, None]: def get_bounds(self) -> Union[numpy.ndarray, None]:
""" """
Return a `numpy.ndarray` containing `[[x_min, y_min], [x_max, y_max]]`, corresponding to the Return a `numpy.ndarray` containing `[[x_min, y_min], [x_max, y_max]]`, corresponding to the