improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-29 18:08:55 -07:00
parent 691ce03150
commit 8e451c64db
3 changed files with 30 additions and 20 deletions

View File

@ -344,7 +344,7 @@ def read_bstring(stream: IO[bytes]) -> bytes:
return _read(stream, length) return _read(stream, length)
def write_bstring(stream: IO[bytes], bstring: bytes): def write_bstring(stream: IO[bytes], bstring: bytes) -> int:
""" """
Write a binary string to the stream. Write a binary string to the stream.
See `read_bstring()` for format details. See `read_bstring()` for format details.
@ -1194,7 +1194,8 @@ class GridRepetition:
a_vector: Sequence[int], a_vector: Sequence[int],
a_count: int, a_count: int,
b_vector: Sequence[int] | None = None, b_vector: Sequence[int] | None = None,
b_count: int | None = None): b_count: int | None = None,
) -> None:
""" """
Args: Args:
a_vector: First lattice vector, of the form `[x, y]`. a_vector: First lattice vector, of the form `[x, y]`.
@ -1828,7 +1829,7 @@ def write_property_value(
else: else:
size = write_uint(stream, 8) size = write_uint(stream, 8)
size += write_uint(stream, value) size += write_uint(stream, value)
elif isinstance(value, (Fraction, float, int)): elif isinstance(value, Fraction | float | int):
size = write_real(stream, value, force_float32) size = write_real(stream, value, force_float32)
elif isinstance(value, AString): elif isinstance(value, AString):
size = write_uint(stream, 10) size = write_uint(stream, 10)
@ -2229,7 +2230,7 @@ def write_magic_bytes(stream: IO[bytes]) -> int:
return stream.write(MAGIC_BYTES) return stream.write(MAGIC_BYTES)
def read_magic_bytes(stream: IO[bytes]): def read_magic_bytes(stream: IO[bytes]) -> None:
""" """
Read the magic byte sequence from a stream. Read the magic byte sequence from a stream.
Raise an `InvalidDataError` if it was not found. Raise an `InvalidDataError` if it was not found.

View File

@ -421,7 +421,7 @@ class Cell:
placements: list[records.Placement] | None = None, placements: list[records.Placement] | None = None,
geometry: list[records.geometry_t] | None = None, geometry: list[records.geometry_t] | None = None,
) -> None: ) -> None:
self.name = name if isinstance(name, (NString, int)) else NString(name) self.name = name if isinstance(name, NString | int) else NString(name)
self.properties = [] if properties is None else properties self.properties = [] if properties is None else properties
self.placements = [] if placements is None else placements self.placements = [] if placements is None else placements
self.geometry = [] if geometry is None else geometry self.geometry = [] if geometry is None else geometry

View File

@ -10,7 +10,7 @@ Higher-level code (e.g. monitoring for combinations of records with
parse, or code for dealing with nested records in a CBlock) should live parse, or code for dealing with nested records in a CBlock) should live
in main.py instead. in main.py instead.
""" """
from typing import Any, TypeVar, IO, Union from typing import Any, TypeVar, IO, Union, Protocol
from collections.abc import Sequence from collections.abc import Sequence
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import copy import copy
@ -227,6 +227,15 @@ class Record(metaclass=ABCMeta):
return f'{self.__class__}: ' + pprint.pformat(self.__dict__) return f'{self.__class__}: ' + pprint.pformat(self.__dict__)
class HasRepetition(Protocol):
repetition: repetition_t | None
class HasXY(Protocol):
x: int | None
y: int | None
class GeometryMixin(metaclass=ABCMeta): class GeometryMixin(metaclass=ABCMeta):
""" """
Mixin defining common functions for geometry records Mixin defining common functions for geometry records
@ -465,10 +474,10 @@ class End(Record):
self.validation = validation self.validation = validation
self.offset_table = offset_table self.offset_table = offset_table
def merge_with_modals(self, modals: Modals): def merge_with_modals(self, modals: Modals) -> None:
pass pass
def deduplicate_with_modals(self, modals: Modals): def deduplicate_with_modals(self, modals: Modals) -> None:
pass pass
@staticmethod @staticmethod
@ -703,10 +712,10 @@ class PropName(Record):
self.nstring = NString(nstring) self.nstring = NString(nstring)
self.reference_number = reference_number self.reference_number = reference_number
def merge_with_modals(self, modals: Modals): def merge_with_modals(self, modals: Modals) -> None:
modals.reset() modals.reset()
def deduplicate_with_modals(self, modals: Modals): def deduplicate_with_modals(self, modals: Modals) -> None:
modals.reset() modals.reset()
@staticmethod @staticmethod
@ -931,7 +940,7 @@ class Property(Record):
is_standard: `True` if this is a standard property. `None` to use modal. is_standard: `True` if this is a standard property. `None` to use modal.
Default `None`. Default `None`.
""" """
if isinstance(name, (NString, int)) or name is None: if isinstance(name, NString | int) or name is None:
self.name = name self.name = name
else: else:
self.name = NString(name) self.name = NString(name)
@ -1255,7 +1264,7 @@ class Cell(Record):
Args: Args:
name: `NString`, or an int specifying a `CellName` reference number. name: `NString`, or an int specifying a `CellName` reference number.
""" """
self.name = name if isinstance(name, (int, NString)) else NString(name) self.name = name if isinstance(name, int | NString) else NString(name)
def merge_with_modals(self, modals: Modals) -> None: def merge_with_modals(self, modals: Modals) -> None:
modals.reset() modals.reset()
@ -1338,7 +1347,7 @@ class Placement(Record):
self.flip = flip self.flip = flip
self.magnification = magnification self.magnification = magnification
self.angle = angle self.angle = angle
if isinstance(name, (int, NString)) or name is None: if isinstance(name, int | NString) or name is None:
self.name = name self.name = name
else: else:
self.name = NString(name) self.name = NString(name)
@ -1474,7 +1483,7 @@ class Text(Record, GeometryMixin):
self.x = x self.x = x
self.y = y self.y = y
self.repetition = repetition self.repetition = repetition
if isinstance(string, (AString, int)) or string is None: if isinstance(string, int | AString) or string is None:
self.string = string self.string = string
else: else:
self.string = AString(string) self.string = AString(string)
@ -2512,7 +2521,7 @@ class Circle(Record, GeometryMixin):
return size return size
def adjust_repetition(record, modals: Modals) -> None: def adjust_repetition(record: HasRepetition, modals: Modals) -> None:
""" """
Merge the record's repetition entry with the one in the modals Merge the record's repetition entry with the one in the modals
@ -2533,7 +2542,7 @@ def adjust_repetition(record, modals: Modals) -> None:
modals.repetition = copy.copy(record.repetition) modals.repetition = copy.copy(record.repetition)
def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None: def adjust_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None:
""" """
Merge `record.r_field` with `modals.m_field` Merge `record.r_field` with `modals.m_field`
@ -2557,7 +2566,7 @@ def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None:
raise InvalidDataError(f'Unfillable field: {m_field}') raise InvalidDataError(f'Unfillable field: {m_field}')
def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None: def adjust_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None:
""" """
Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`, Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`. taking into account the value of `modals.xy_relative`.
@ -2591,7 +2600,7 @@ def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) ->
# TODO: Clarify the docs on the dedup_* functions # TODO: Clarify the docs on the dedup_* functions
def dedup_repetition(record, modals: Modals) -> None: def dedup_repetition(record: HasRepetition, modals: Modals) -> None:
""" """
Deduplicate the record's repetition entry with the one in the modals. Deduplicate the record's repetition entry with the one in the modals.
Update the one in the modals if they are different. Update the one in the modals if they are different.
@ -2618,7 +2627,7 @@ def dedup_repetition(record, modals: Modals) -> None:
modals.repetition = record.repetition modals.repetition = record.repetition
def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None: def dedup_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None:
""" """
Deduplicate `record.r_field` using `modals.m_field` Deduplicate `record.r_field` using `modals.m_field`
Update the `modals.m_field` if they are different. Update the `modals.m_field` if they are different.
@ -2651,7 +2660,7 @@ def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None:
raise InvalidDataError('Unfillable field') raise InvalidDataError('Unfillable field')
def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None: def dedup_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None:
""" """
Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`, Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`. taking into account the value of `modals.xy_relative`.