Improve type annotations based on mypy errors

This commit is contained in:
Jan Petykiewicz 2020-05-11 19:09:35 -07:00
commit 157df47884
13 changed files with 151 additions and 117 deletions

View file

@ -2,7 +2,8 @@
Base object for containing a lithography mask.
"""
from typing import List, Callable, Tuple, Dict, Union, Set
from typing import List, Callable, Tuple, Dict, Union, Set, Sequence, Optional, Type
from typing import MutableMapping, Iterable
import copy
import itertools
import pickle
@ -39,7 +40,7 @@ class Pattern:
labels: List[Label]
""" List of all labels in this Pattern. """
subpatterns: List[SubPattern or GridRepetition]
subpatterns: List[Union[SubPattern, GridRepetition]]
""" List of all objects referencing other patterns in this Pattern.
Examples are SubPattern (gdsii "instances") or GridRepetition (gdsii "arrays")
Multiple objects in this list may reference the same Pattern object
@ -54,9 +55,9 @@ class Pattern:
def __init__(self,
name: str = '',
shapes: List[Shape] = (),
labels: List[Label] = (),
subpatterns: List[SubPattern] = (),
shapes: Sequence[Shape] = (),
labels: Sequence[Label] = (),
subpatterns: Sequence[Union[SubPattern, GridRepetition]] = (),
locked: bool = False,
):
"""
@ -129,7 +130,7 @@ class Pattern:
def subset(self,
shapes_func: Callable[[Shape], bool] = None,
labels_func: Callable[[Label], bool] = None,
subpatterns_func: Callable[[SubPattern], bool] = None,
subpatterns_func: Callable[[Union[SubPattern, GridRepetition]], bool] = None,
recursive: bool = False,
) -> 'Pattern':
"""
@ -172,9 +173,9 @@ class Pattern:
return pat
def apply(self,
func: Callable[['Pattern'], 'Pattern'],
memo: Dict[int, 'Pattern'] = None,
) -> 'Pattern':
func: Callable[[Optional['Pattern']], Optional['Pattern']],
memo: Optional[Dict[int, Optional['Pattern']]] = None,
) -> Optional['Pattern']:
"""
Recursively apply func() to this pattern and any pattern it references.
func() is expected to take and return a Pattern.
@ -217,9 +218,9 @@ class Pattern:
def dfs(self,
visit_before: visitor_function_t = None,
visit_after: visitor_function_t = None,
transform: numpy.ndarray or bool or None = False ,
memo: Dict = None,
hierarchy: Tuple['Pattern'] = (),
transform: Union[numpy.ndarray, bool, None] = False,
memo: Optional[Dict] = None,
hierarchy: Tuple['Pattern', ...] = (),
) -> 'Pattern':
"""
Experimental convenience function.
@ -270,7 +271,7 @@ class Pattern:
pat = self
if visit_before is not None:
pat = visit_before(pat, hierarchy=hierarchy, memo=memo, transform=transform)
pat = visit_before(pat, hierarchy=hierarchy, memo=memo, transform=transform) # type: ignore
for subpattern in self.subpatterns:
if transform is not False:
@ -293,12 +294,12 @@ class Pattern:
hierarchy=hierarchy + (self,))
if visit_after is not None:
pat = visit_after(pat, hierarchy=hierarchy, memo=memo, transform=transform)
pat = visit_after(pat, hierarchy=hierarchy, memo=memo, transform=transform) # type: ignore
return pat
def polygonize(self,
poly_num_points: int = None,
poly_max_arclen: float = None,
poly_num_points: Optional[int] = None,
poly_max_arclen: Optional[float] = None,
) -> 'Pattern':
"""
Calls `.to_polygons(...)` on all the shapes in this Pattern and any referenced patterns,
@ -349,7 +350,7 @@ class Pattern:
def subpatternize(self,
recursive: bool = True,
norm_value: int = int(1e6),
exclude_types: Tuple[Shape] = (Polygon,)
exclude_types: Tuple[Type] = (Polygon,)
) -> 'Pattern':
"""
Iterates through this `Pattern` and all referenced `Pattern`s. Within each `Pattern`, it iterates
@ -387,7 +388,7 @@ class Pattern:
# Create a dict which uses the label tuple from `.normalized_form()` as a key, and which
# stores `(function_to_create_normalized_shape, [(index_in_shapes, values), ...])`, where
# values are the `(offset, scale, rotation, dose)` values as calculated by `.normalized_form()`
shape_table = defaultdict(lambda: [None, list()])
shape_table: MutableMapping[Tuple, List] = defaultdict(lambda: [None, list()])
for i, shape in enumerate(self.shapes):
if not any((isinstance(shape, t) for t in exclude_types)):
label, values, func = shape.normalized_form(norm_value)
@ -429,9 +430,9 @@ class Pattern:
is of the form `[[x0, y0], [x1, y1],...]`.
"""
pat = self.deepcopy().deepunlock().polygonize().flatten()
return [shape.vertices + shape.offset for shape in pat.shapes]
return [shape.vertices + shape.offset for shape in pat.shapes] # type: ignore # mypy can't figure out that shapes are all Polygons now
def referenced_patterns_by_id(self) -> Dict[int, 'Pattern']:
def referenced_patterns_by_id(self) -> Dict[int, Optional['Pattern']]:
"""
Create a dictionary with `{id(pat): pat}` for all Pattern objects referenced by this
Pattern (operates recursively on all referenced Patterns as well)
@ -447,7 +448,7 @@ class Pattern:
ids.update(subpat.pattern.referenced_patterns_by_id())
return ids
def referenced_patterns_by_name(self) -> List[Tuple[str, 'Pattern']]:
def referenced_patterns_by_name(self) -> List[Tuple[Optional[str], Optional['Pattern']]]:
"""
Create a list of `(pat.name, pat)` tuples for all Pattern objects referenced by this
Pattern (operates recursively on all referenced Patterns as well).
@ -507,7 +508,7 @@ class Pattern:
"""
subpatterns = copy.deepcopy(self.subpatterns)
self.subpatterns = []
shape_counts = {}
shape_counts: Dict[Tuple, int] = {}
for subpat in subpatterns:
if subpat.pattern is None:
continue
@ -839,7 +840,7 @@ class Pattern:
pyplot.show()
@staticmethod
def find_toplevel(patterns: List['Pattern']) -> List['Pattern']:
def find_toplevel(patterns: Iterable['Pattern']) -> List['Pattern']:
"""
Given a list of Pattern objects, return those that are not referenced by
any other pattern.
@ -863,7 +864,7 @@ class Pattern:
return memo
patterns = set(patterns)
not_toplevel = set()
not_toplevel: Set['Pattern'] = set()
for pattern in patterns:
not_toplevel |= get_children(pattern, not_toplevel)