From 8e451c64db8f9671f9407466a81f7f81d442a4c8 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 18:08:55 -0700 Subject: [PATCH] improve type annotations --- fatamorgana/basic.py | 9 +++++---- fatamorgana/main.py | 2 +- fatamorgana/records.py | 39 ++++++++++++++++++++++++--------------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/fatamorgana/basic.py b/fatamorgana/basic.py index f72000d..fcd5d0b 100644 --- a/fatamorgana/basic.py +++ b/fatamorgana/basic.py @@ -344,7 +344,7 @@ def read_bstring(stream: IO[bytes]) -> bytes: return _read(stream, length) -def write_bstring(stream: IO[bytes], bstring: bytes): +def write_bstring(stream: IO[bytes], bstring: bytes) -> int: """ Write a binary string to the stream. See `read_bstring()` for format details. @@ -1194,7 +1194,8 @@ class GridRepetition: a_vector: Sequence[int], a_count: int, b_vector: Sequence[int] | None = None, - b_count: int | None = None): + b_count: int | None = None, + ) -> None: """ Args: a_vector: First lattice vector, of the form `[x, y]`. @@ -1828,7 +1829,7 @@ def write_property_value( else: size = write_uint(stream, 8) size += write_uint(stream, value) - elif isinstance(value, (Fraction, float, int)): + elif isinstance(value, Fraction | float | int): size = write_real(stream, value, force_float32) elif isinstance(value, AString): size = write_uint(stream, 10) @@ -2229,7 +2230,7 @@ def write_magic_bytes(stream: IO[bytes]) -> int: return stream.write(MAGIC_BYTES) -def read_magic_bytes(stream: IO[bytes]): +def read_magic_bytes(stream: IO[bytes]) -> None: """ Read the magic byte sequence from a stream. Raise an `InvalidDataError` if it was not found. diff --git a/fatamorgana/main.py b/fatamorgana/main.py index 317378c..74c37aa 100644 --- a/fatamorgana/main.py +++ b/fatamorgana/main.py @@ -421,7 +421,7 @@ class Cell: placements: list[records.Placement] | None = None, geometry: list[records.geometry_t] | None = None, ) -> None: - self.name = name if isinstance(name, (NString, int)) else NString(name) + self.name = name if isinstance(name, NString | int) else NString(name) self.properties = [] if properties is None else properties self.placements = [] if placements is None else placements self.geometry = [] if geometry is None else geometry diff --git a/fatamorgana/records.py b/fatamorgana/records.py index 1e16cbf..88863a7 100644 --- a/fatamorgana/records.py +++ b/fatamorgana/records.py @@ -10,7 +10,7 @@ Higher-level code (e.g. monitoring for combinations of records with parse, or code for dealing with nested records in a CBlock) should live in main.py instead. """ -from typing import Any, TypeVar, IO, Union +from typing import Any, TypeVar, IO, Union, Protocol from collections.abc import Sequence from abc import ABCMeta, abstractmethod import copy @@ -227,6 +227,15 @@ class Record(metaclass=ABCMeta): return f'{self.__class__}: ' + pprint.pformat(self.__dict__) +class HasRepetition(Protocol): + repetition: repetition_t | None + + +class HasXY(Protocol): + x: int | None + y: int | None + + class GeometryMixin(metaclass=ABCMeta): """ Mixin defining common functions for geometry records @@ -465,10 +474,10 @@ class End(Record): self.validation = validation self.offset_table = offset_table - def merge_with_modals(self, modals: Modals): + def merge_with_modals(self, modals: Modals) -> None: pass - def deduplicate_with_modals(self, modals: Modals): + def deduplicate_with_modals(self, modals: Modals) -> None: pass @staticmethod @@ -703,10 +712,10 @@ class PropName(Record): self.nstring = NString(nstring) self.reference_number = reference_number - def merge_with_modals(self, modals: Modals): + def merge_with_modals(self, modals: Modals) -> None: modals.reset() - def deduplicate_with_modals(self, modals: Modals): + def deduplicate_with_modals(self, modals: Modals) -> None: modals.reset() @staticmethod @@ -931,7 +940,7 @@ class Property(Record): is_standard: `True` if this is a standard property. `None` to use modal. Default `None`. """ - if isinstance(name, (NString, int)) or name is None: + if isinstance(name, NString | int) or name is None: self.name = name else: self.name = NString(name) @@ -1255,7 +1264,7 @@ class Cell(Record): Args: name: `NString`, or an int specifying a `CellName` reference number. """ - self.name = name if isinstance(name, (int, NString)) else NString(name) + self.name = name if isinstance(name, int | NString) else NString(name) def merge_with_modals(self, modals: Modals) -> None: modals.reset() @@ -1338,7 +1347,7 @@ class Placement(Record): self.flip = flip self.magnification = magnification self.angle = angle - if isinstance(name, (int, NString)) or name is None: + if isinstance(name, int | NString) or name is None: self.name = name else: self.name = NString(name) @@ -1474,7 +1483,7 @@ class Text(Record, GeometryMixin): self.x = x self.y = y self.repetition = repetition - if isinstance(string, (AString, int)) or string is None: + if isinstance(string, int | AString) or string is None: self.string = string else: self.string = AString(string) @@ -2512,7 +2521,7 @@ class Circle(Record, GeometryMixin): return size -def adjust_repetition(record, modals: Modals) -> None: +def adjust_repetition(record: HasRepetition, modals: Modals) -> None: """ Merge the record's repetition entry with the one in the modals @@ -2533,7 +2542,7 @@ def adjust_repetition(record, modals: Modals) -> None: modals.repetition = copy.copy(record.repetition) -def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None: +def adjust_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None: """ Merge `record.r_field` with `modals.m_field` @@ -2557,7 +2566,7 @@ def adjust_field(record, r_field: str, modals: Modals, m_field: str) -> None: raise InvalidDataError(f'Unfillable field: {m_field}') -def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None: +def adjust_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None: """ Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`, taking into account the value of `modals.xy_relative`. @@ -2591,7 +2600,7 @@ def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> # TODO: Clarify the docs on the dedup_* functions -def dedup_repetition(record, modals: Modals) -> None: +def dedup_repetition(record: HasRepetition, modals: Modals) -> None: """ Deduplicate the record's repetition entry with the one in the modals. Update the one in the modals if they are different. @@ -2618,7 +2627,7 @@ def dedup_repetition(record, modals: Modals) -> None: modals.repetition = record.repetition -def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None: +def dedup_field(record: Record, r_field: str, modals: Modals, m_field: str) -> None: """ Deduplicate `record.r_field` using `modals.m_field` Update the `modals.m_field` if they are different. @@ -2651,7 +2660,7 @@ def dedup_field(record, r_field: str, modals: Modals, m_field: str) -> None: raise InvalidDataError('Unfillable field') -def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> None: +def dedup_coordinates(record: HasXY, modals: Modals, mx_field: str, my_field: str) -> None: """ Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`, taking into account the value of `modals.xy_relative`.