improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-29 18:08:55 -07:00
parent 691ce03150
commit 8e451c64db
3 changed files with 30 additions and 20 deletions

View File

@ -344,7 +344,7 @@ def read_bstring(stream: IO[bytes]) -> bytes:
return _read(stream, length)
def write_bstring(stream: IO[bytes], bstring: bytes):
def write_bstring(stream: IO[bytes], bstring: bytes) -> int:
"""
Write a binary string to the stream.
See `read_bstring()` for format details.
@ -1194,7 +1194,8 @@ class GridRepetition:
a_vector: Sequence[int],
a_count: int,
b_vector: Sequence[int] | None = None,
b_count: int | None = None):
b_count: int | None = None,
) -> None:
"""
Args:
a_vector: First lattice vector, of the form `[x, y]`.
@ -1828,7 +1829,7 @@ def write_property_value(
else:
size = write_uint(stream, 8)
size += write_uint(stream, value)
elif isinstance(value, (Fraction, float, int)):
elif isinstance(value, Fraction | float | int):
size = write_real(stream, value, force_float32)
elif isinstance(value, AString):
size = write_uint(stream, 10)
@ -2229,7 +2230,7 @@ def write_magic_bytes(stream: IO[bytes]) -> int:
return stream.write(MAGIC_BYTES)
def read_magic_bytes(stream: IO[bytes]):
def read_magic_bytes(stream: IO[bytes]) -> None:
"""
Read the magic byte sequence from a stream.
Raise an `InvalidDataError` if it was not found.

View File

@ -421,7 +421,7 @@ class Cell:
placements: list[records.Placement] | None = None,
geometry: list[records.geometry_t] | None = None,
) -> None:
self.name = name if isinstance(name, (NString, int)) else NString(name)
self.name = name if isinstance(name, NString | int) else NString(name)
self.properties = [] if properties is None else properties
self.placements = [] if placements is None else placements
self.geometry = [] if geometry is None else geometry

View File

@ -10,7 +10,7 @@ Higher-level code (e.g. monitoring for combinations of records with
parse, or code for dealing with nested records in a CBlock) should live
in main.py instead.
"""
from typing import Any, TypeVar, IO, Union
from typing import Any, TypeVar, IO, Union, Protocol
from collections.abc import Sequence
from abc import ABCMeta, abstractmethod
import copy
@ -227,6 +227,15 @@ class Record(metaclass=ABCMeta):
return f'{self.__class__}: ' + pprint.pformat(self.__dict__)
class HasRepetition(Protocol):
repetition: repetition_t | None
class HasXY(Protocol):
x: int | None
y: int | None
class GeometryMixin(metaclass=ABCMeta):
"""
Mixin defining common functions for geometry records
@ -465,10 +474,10 @@ class End(Record):
self.validation = validation
self.offset_table = offset_table
def merge_with_modals(self, modals: Modals):
def merge_with_modals(self, modals: Modals) -> None:
pass
def deduplicate_with_modals(self, modals: Modals):
def deduplicate_with_modals(self, modals: Modals) -> None:
pass
@staticmethod
@ -703,10 +712,10 @@ class PropName(Record):
self.nstring = NString(nstring)
self.reference_number = reference_number
def merge_with_modals(self, modals: Modals):
def merge_with_modals(self, modals: Modals) -> None:
modals.reset()
def deduplicate_with_modals(self, modals: Modals):
def deduplicate_with_modals(self, modals: Modals) -> None:
modals.reset()
@staticmethod
@ -931,7 +940,7 @@ class Property(Record):
is_standard: `True` if this is a standard property. `None` to use modal.
Default `None`.
"""
if isinstance(name, (NString, int)) or name is None:
if isinstance(name, NString | int) or name is None:
self.name = name
else:
self.name = NString(name)
@ -1255,7 +1264,7 @@ class Cell(Record):
Args:
name: `NString`, or an int specifying a `CellName` reference number.
"""
self.name = name if isinstance(name, (int, NString)) else NString(name)
self.name = name if isinstance(name, int | NString) else NString(name)
def merge_with_modals(self, modals: Modals) -> None:
modals.reset()
@ -1338,7 +1347,7 @@ class Placement(Record):
self.flip = flip
self.magnification = magnification
self.angle = angle
if isinstance(name, (int, NString)) or name is None:
if isinstance(name, int | NString) or name is None:
self.name = name
else:
self.name = NString(name)
@ -1474,7 +1483,7 @@ class Text(Record, GeometryMixin):
self.x = x
self.y = y
self.repetition = repetition
if isinstance(string, (AString, int)) or string is None:
if isinstance(string, int | AString) or string is None:
self.string = string
else:
self.string = AString(string)
@ -2512,7 +2521,7 @@ class Circle(Record, GeometryMixin):
return size
def adjust_repetition(record, modals: Modals) -> None:
def adjust_repetition(record: HasRepetition, modals: Modals) -> None:
"""
Merge the record's repetition entry with the one in the modals
@ -2533,7 +2542,7 @@ def adjust_repetition(record, modals: Modals) -> None:
modals.repetition = copy.copy(record.repetition)
def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None:
def adjust_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None:
"""
Merge `record.r_field` with `modals.m_field`
@ -2557,7 +2566,7 @@ def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None:
raise InvalidDataError(f'Unfillable field: {m_field}')
def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None:
def adjust_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None:
"""
Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`.
@ -2591,7 +2600,7 @@ def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) ->
# TODO: Clarify the docs on the dedup_* functions
def dedup_repetition(record, modals: Modals) -> None:
def dedup_repetition(record: HasRepetition, modals: Modals) -> None:
"""
Deduplicate the record's repetition entry with the one in the modals.
Update the one in the modals if they are different.
@ -2618,7 +2627,7 @@ def dedup_repetition(record, modals: Modals) -> None:
modals.repetition = record.repetition
def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None:
def dedup_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None:
"""
Deduplicate `record.r_field` using `modals.m_field`
Update the `modals.m_field` if they are different.
@ -2651,7 +2660,7 @@ def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None:
raise InvalidDataError('Unfillable field')
def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None:
def dedup_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None:
"""
Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`.