diff --git a/masque/pattern.py b/masque/pattern.py index 9cd49c2..e3f8c8c 100644 --- a/masque/pattern.py +++ b/masque/pattern.py @@ -486,6 +486,32 @@ class Pattern: pat_list = [(p.name if p is not None else None, p) for p in pats_by_id.values()] return pat_list + def subpatterns_by_id(self, + include_none: bool = False, + recursive: bool = True, + ) -> Dict[int, List[subpattern_t]]: + """ + Create a dictionary which maps `{id(referenced_pattern): [subpattern0, ...]}` + for all SubPattern objects referenced by this Pattern (by default, operates + recursively on all referenced Patterns as well). + + Args: + include_none: If `True`, references to `None` will be included. Default `False`. + recursive: If `True`, operates recursively on all referenced patterns. Default `True`. + + Returns: + Dictionary mapping each pattern id to a list of subpattern objects referencing the pattern. + """ + ids: Dict[int, List[subpattern_t]] = defaultdict(list) + for subpat in self.subpatterns: + pat = subpat.pattern + if include_none or pat is not None: + ids[id(pat)].append(subpat) + if recursive and pat is not None: + ids.update(pat.subpatterns_by_id(include_none=include_none)) + return dict(ids) + + def get_bounds(self) -> Union[numpy.ndarray, None]: """ Return a `numpy.ndarray` containing `[[x_min, y_min], [x_max, y_max]]`, corresponding to the