improve type annotations

This commit is contained in:
Jan Petykiewicz 2020-04-18 15:38:52 -07:00
parent e9cf010f54
commit 7f0c46525e
3 changed files with 55 additions and 49 deletions

View File

@ -140,7 +140,7 @@ if _USE_NUMPY:
byte_arr = _read(stream, 1) byte_arr = _read(stream, 1)
return numpy.unpackbits(numpy.frombuffer(byte_arr, dtype=numpy.uint8)) return numpy.unpackbits(numpy.frombuffer(byte_arr, dtype=numpy.uint8))
def write_bool_byte(stream: io.BufferedIOBase, bits: Tuple[bool]) -> int: def write_bool_byte(stream: io.BufferedIOBase, bits: Tuple[Union[bool, int], ...]) -> int:
""" """
Pack 8 booleans into a byte, and write it to the stream. Pack 8 booleans into a byte, and write it to the stream.
@ -173,7 +173,7 @@ else:
bits = [bool((byte >> i) & 0x01) for i in reversed(range(8))] bits = [bool((byte >> i) & 0x01) for i in reversed(range(8))]
return bits return bits
def write_bool_byte(stream: io.BufferedIOBase, bits: Tuple[bool]) -> int: def write_bool_byte(stream: io.BufferedIOBase, bits: Tuple[Union[bool, int], ...]) -> int:
""" """
Pack 8 booleans into a byte, and write it to the stream. Pack 8 booleans into a byte, and write it to the stream.
@ -1611,6 +1611,7 @@ def write_point_list(stream: io.BufferedIOBase,
return size return size
# Try writing a bunch of Manhattan or Octangular deltas # Try writing a bunch of Manhattan or Octangular deltas
deltas: Union[List[ManhattanDelta], List[OctangularDelta], List[Delta]]
list_type = None list_type = None
try: try:
deltas = [ManhattanDelta(x, y) for x, y in points] deltas = [ManhattanDelta(x, y) for x, y in points]
@ -1721,6 +1722,7 @@ def read_property_value(stream: io.BufferedIOBase) -> property_value_t:
Raises: Raises:
InvalidDataError: if an invalid type is read. InvalidDataError: if an invalid type is read.
""" """
ref_type: Type
prop_type = read_uint(stream) prop_type = read_uint(stream)
if 0 <= prop_type <= 7: if 0 <= prop_type <= 7:
return read_real(stream, prop_type) return read_real(stream, prop_type)
@ -1964,20 +1966,20 @@ class OffsetTable:
layernames (OffsetEntry): Offset for LayerNames layernames (OffsetEntry): Offset for LayerNames
xnames (OffsetEntry): Offset for XNames xnames (OffsetEntry): Offset for XNames
""" """
cellnames: OffsetEntry = None cellnames: OffsetEntry
textstrings: OffsetEntry = None textstrings: OffsetEntry
propnames: OffsetEntry = None propnames: OffsetEntry
propstrings: OffsetEntry = None propstrings: OffsetEntry
layernames: OffsetEntry = None layernames: OffsetEntry
xnames: OffsetEntry = None xnames: OffsetEntry
def __init__(self, def __init__(self,
cellnames: OffsetEntry = None, cellnames: Optional[OffsetEntry] = None,
textstrings: OffsetEntry = None, textstrings: Optional[OffsetEntry] = None,
propnames: OffsetEntry = None, propnames: Optional[OffsetEntry] = None,
propstrings: OffsetEntry = None, propstrings: Optional[OffsetEntry] = None,
layernames: OffsetEntry = None, layernames: Optional[OffsetEntry] = None,
xnames: OffsetEntry = None): xnames: Optional[OffsetEntry] = None):
""" """
All parameters default to a non-strict entry with offset `0`. All parameters default to a non-strict entry with offset `0`.
@ -2204,5 +2206,5 @@ def read_magic_bytes(stream: io.BufferedIOBase):
magic = _read(stream, len(MAGIC_BYTES)) magic = _read(stream, len(MAGIC_BYTES))
if magic != MAGIC_BYTES: if magic != MAGIC_BYTES:
raise InvalidDataError('Could not read magic bytes, ' raise InvalidDataError('Could not read magic bytes, '
'found {} : {}'.format(magic, magic.decode())) 'found {!r} : {}'.format(magic, magic.decode()))

View File

@ -3,12 +3,12 @@ This module contains data structures and functions for reading from and
writing to whole OASIS layout files, and provides a few additional writing to whole OASIS layout files, and provides a few additional
abstractions for the data contained inside them. abstractions for the data contained inside them.
""" """
from typing import List, Dict, Union, Optional from typing import List, Dict, Union, Optional, Type
import io import io
import logging import logging
from . import records from . import records
from .records import Modals from .records import Modals, Record
from .basic import OffsetEntry, OffsetTable, NString, AString, real_t, Validation, \ from .basic import OffsetEntry, OffsetTable, NString, AString, real_t, Validation, \
read_magic_bytes, write_magic_bytes, read_uint, EOFError, \ read_magic_bytes, write_magic_bytes, read_uint, EOFError, \
InvalidDataError, InvalidRecordError InvalidDataError, InvalidRecordError
@ -29,7 +29,6 @@ class FileModals:
xname_implicit: Optional[bool] = None xname_implicit: Optional[bool] = None
textstring_implicit: Optional[bool] = None textstring_implicit: Optional[bool] = None
propstring_implicit: Optional[bool] = None propstring_implicit: Optional[bool] = None
cellname_implicit: Optional[bool] = None
within_cell: bool = False within_cell: bool = False
within_cblock: bool = False within_cblock: bool = False
@ -158,6 +157,8 @@ class OasisLayout:
logger.info('read_record of type {} at position 0x{:x}'.format(record_id, stream.tell())) logger.info('read_record of type {} at position 0x{:x}'.format(record_id, stream.tell()))
record: Record
# CBlock # CBlock
if record_id == 34: if record_id == 34:
if file_state.within_cblock: if file_state.within_cblock:
@ -451,7 +452,7 @@ class XName:
# Mapping from record id to record class. # Mapping from record id to record class.
_GEOMETRY = { _GEOMETRY: Dict[int, Type] = {
19: records.Text, 19: records.Text,
20: records.Rectangle, 20: records.Rectangle,
21: records.Polygon, 21: records.Polygon,

View File

@ -11,7 +11,7 @@ Higher-level code (e.g. monitoring for combinations of records with
in main.py instead. in main.py instead.
""" """
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import List, Dict, Tuple, Union, Optional, Sequence from typing import List, Dict, Tuple, Union, Optional, Sequence, Any
import copy import copy
import math import math
import zlib import zlib
@ -337,17 +337,17 @@ class Start(Record):
Attributes: Attributes:
version (AString): "1.0" version (AString): "1.0"
unit (real_t): positive real number, grid steps per micron unit (real_t): positive real number, grid steps per micron
offset_table (OffsetTable or None): If `None` then table must be offset_table (Optional[OffsetTable]): If `None` then table must be
placed in the `End` record) placed in the `End` record)
""" """
version: AString version: AString
unit: real_t unit: real_t
offset_table: OffsetTable = None offset_table: Optional[OffsetTable] = None
def __init__(self, def __init__(self,
unit: real_t, unit: real_t,
version: Union[AString, str] = None, version: Union[AString, str] = None,
offset_table: OffsetTable = None): offset_table: Optional[OffsetTable] = None):
""" """
Args Args
unit: Grid steps per micron (positive real number) unit: Grid steps per micron (positive real number)
@ -390,6 +390,7 @@ class Start(Record):
version = AString.read(stream) version = AString.read(stream)
unit = read_real(stream) unit = read_real(stream)
has_offset_table = read_uint(stream) == 0 has_offset_table = read_uint(stream) == 0
offset_table: Optional[OffsetTable]
if has_offset_table: if has_offset_table:
offset_table = OffsetTable.read(stream) offset_table = OffsetTable.read(stream)
else: else:
@ -448,7 +449,7 @@ class End(Record):
if record_id != 2: if record_id != 2:
raise InvalidDataError('Invalid record id for End {}'.format(record_id)) raise InvalidDataError('Invalid record id for End {}'.format(record_id))
if has_offset_table: if has_offset_table:
offset_table = OffsetTable.read(stream) offset_table: Optional[OffsetTable] = OffsetTable.read(stream)
else: else:
offset_table = None offset_table = None
_padding_string = read_bstring(stream) _padding_string = read_bstring(stream)
@ -630,7 +631,7 @@ class CellName(Record):
'{}'.format(record_id)) '{}'.format(record_id))
nstring = NString.read(stream) nstring = NString.read(stream)
if record_id == 4: if record_id == 4:
reference_number = read_uint(stream) reference_number: Optional[int] = read_uint(stream)
else: else:
reference_number = None reference_number = None
record = CellName(nstring, reference_number) record = CellName(nstring, reference_number)
@ -684,7 +685,7 @@ class PropName(Record):
'{}'.format(record_id)) '{}'.format(record_id))
nstring = NString.read(stream) nstring = NString.read(stream)
if record_id == 8: if record_id == 8:
reference_number = read_uint(stream) reference_number: Optional[int] = read_uint(stream)
else: else:
reference_number = None reference_number = None
record = PropName(nstring, reference_number) record = PropName(nstring, reference_number)
@ -739,7 +740,7 @@ class TextString(Record):
'{}'.format(record_id)) '{}'.format(record_id))
astring = AString.read(stream) astring = AString.read(stream)
if record_id == 6: if record_id == 6:
reference_number = read_uint(stream) reference_number: Optional[int] = read_uint(stream)
else: else:
reference_number = None reference_number = None
record = TextString(astring, reference_number) record = TextString(astring, reference_number)
@ -794,7 +795,7 @@ class PropString(Record):
'{}'.format(record_id)) '{}'.format(record_id))
astring = AString.read(stream) astring = AString.read(stream)
if record_id == 10: if record_id == 10:
reference_number = read_uint(stream) reference_number: Optional[int] = read_uint(stream)
else: else:
reference_number = None reference_number = None
record = PropString(astring, reference_number) record = PropString(astring, reference_number)
@ -936,6 +937,7 @@ class Property(Record):
s = 0x01 & (byte >> 0) s = 0x01 & (byte >> 0)
name = read_refname(stream, c, n) name = read_refname(stream, c, n)
values: Optional[List[property_value_t]]
if v == 0: if v == 0:
if u < 0x0f: if u < 0x0f:
value_count = u value_count = u
@ -1028,7 +1030,7 @@ class XName(Record):
attribute = read_uint(stream) attribute = read_uint(stream)
bstring = read_bstring(stream) bstring = read_bstring(stream)
if record_id == 31: if record_id == 31:
reference_number = read_uint(stream) reference_number: Optional[int] = read_uint(stream)
else: else:
reference_number = None reference_number = None
record = XName(attribute, bstring, reference_number) record = XName(attribute, bstring, reference_number)
@ -1113,11 +1115,11 @@ class XGeometry(Record):
def __init__(self, def __init__(self,
attribute: int, attribute: int,
bstring: bytes, bstring: bytes,
layer: int = None, layer: Optional[int] = None,
datatype: int = None, datatype: Optional[int] = None,
x: int = None, x: Optional[int] = None,
y: int = None, y: Optional[int] = None,
repetition: repetition_t = None): repetition: Optional[repetition_t] = None):
""" """
Args: Args:
attribute: Attribute number for this XGeometry. attribute: Attribute number for this XGeometry.
@ -1158,7 +1160,7 @@ class XGeometry(Record):
if z0 or z1 or z2: if z0 or z1 or z2:
raise InvalidDataError('Malformed XGeometry header') raise InvalidDataError('Malformed XGeometry header')
attribute = read_uint(stream) attribute = read_uint(stream)
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -1223,6 +1225,7 @@ class Cell(Record):
@staticmethod @staticmethod
def read(stream: io.BufferedIOBase, record_id: int) -> 'Cell': def read(stream: io.BufferedIOBase, record_id: int) -> 'Cell':
name: Union[int, NString]
if record_id == 13: if record_id == 13:
name = read_uint(stream) name = read_uint(stream)
elif record_id == 14: elif record_id == 14:
@ -1317,7 +1320,7 @@ class Placement(Record):
#CNXYRAAF (17) or CNXYRMAF (18) #CNXYRAAF (17) or CNXYRMAF (18)
c, n, x, y, r, ma0, ma1, flip = read_bool_byte(stream) c, n, x, y, r, ma0, ma1, flip = read_bool_byte(stream)
optional = {} optional: Dict[str, Any] = {}
name = read_refname(stream, c, n) name = read_refname(stream, c, n)
if record_id == 17: if record_id == 17:
aa = (ma0 << 1) | ma1 aa = (ma0 << 1) | ma1
@ -1451,7 +1454,7 @@ class Text(Record):
if z0: if z0:
raise InvalidDataError('Malformed Text header') raise InvalidDataError('Malformed Text header')
optional = {} optional: Dict[str, Any] = {}
string = read_refstring(stream, c, n) string = read_refstring(stream, c, n)
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
@ -1584,7 +1587,7 @@ class Rectangle(Record):
'{}'.format(record_id)) '{}'.format(record_id))
is_square, w, h, x, y, r, d, l = read_bool_byte(stream) is_square, w, h, x, y, r, d, l = read_bool_byte(stream)
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -1706,7 +1709,7 @@ class Polygon(Record):
if z0 or z1: if z0 or z1:
raise InvalidDataError('Invalid polygon header') raise InvalidDataError('Invalid polygon header')
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -1844,7 +1847,7 @@ class Path(Record):
'{}'.format(record_id)) '{}'.format(record_id))
e, w, p, x, y, r, d, l = read_bool_byte(stream) e, w, p, x, y, r, d, l = read_bool_byte(stream)
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -2035,7 +2038,7 @@ class Trapezoid(Record):
'{}'.format(record_id)) '{}'.format(record_id))
is_vertical, w, h, x, y, r, d, l = read_bool_byte(stream) is_vertical, w, h, x, y, r, d, l = read_bool_byte(stream)
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -2227,7 +2230,7 @@ class CTrapezoid(Record):
'{}'.format(record_id)) '{}'.format(record_id))
t, w, h, x, y, r, d, l = read_bool_byte(stream) t, w, h, x, y, r, d, l = read_bool_byte(stream)
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -2348,7 +2351,7 @@ class Circle(Record):
if z0 or z1: if z0 or z1:
raise InvalidDataError('Malformed circle header') raise InvalidDataError('Malformed circle header')
optional = {} optional: Dict[str, Any] = {}
if l: if l:
optional['layer'] = read_uint(stream) optional['layer'] = read_uint(stream)
if d: if d:
@ -2390,7 +2393,7 @@ class Circle(Record):
return size return size
def adjust_repetition(record: Record, modals: Modals): def adjust_repetition(record, modals: Modals):
""" """
Merge the record's repetition entry with the one in the modals Merge the record's repetition entry with the one in the modals
@ -2412,7 +2415,7 @@ def adjust_repetition(record: Record, modals: Modals):
modals.repetition = copy.copy(record.repetition) modals.repetition = copy.copy(record.repetition)
def adjust_field(record: Record, r_field: str, modals: Modals, m_field: str): def adjust_field(record, r_field: str, modals: Modals, m_field: str):
""" """
Merge `record.r_field` with `modals.m_field` Merge `record.r_field` with `modals.m_field`
@ -2436,7 +2439,7 @@ def adjust_field(record: Record, r_field: str, modals: Modals, m_field: str):
raise InvalidDataError('Unfillable field: {}'.format(m_field)) raise InvalidDataError('Unfillable field: {}'.format(m_field))
def adjust_coordinates(record: Record, modals: Modals, mx_field: str, my_field: str): def adjust_coordinates(record, modals: Modals, mx_field: str, my_field: str):
""" """
Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`, Merge `record.x` and `record.y` with `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`. taking into account the value of `modals.xy_relative`.
@ -2472,7 +2475,7 @@ def adjust_coordinates(record: Record, modals: Modals, mx_field: str, my_field:
# TODO: Clarify the docs on the dedup_* functions # TODO: Clarify the docs on the dedup_* functions
def dedup_repetition(record: Record, modals: Modals): def dedup_repetition(record, modals: Modals):
""" """
Deduplicate the record's repetition entry with the one in the modals. Deduplicate the record's repetition entry with the one in the modals.
Update the one in the modals if they are different. Update the one in the modals if they are different.
@ -2499,7 +2502,7 @@ def dedup_repetition(record: Record, modals: Modals):
modals.repetition = record.repetition modals.repetition = record.repetition
def dedup_field(record: Record, r_field: str, modals: Modals, m_field: str): def dedup_field(record, r_field: str, modals: Modals, m_field: str):
""" """
Deduplicate `record.r_field` using `modals.m_field` Deduplicate `record.r_field` using `modals.m_field`
Update the `modals.m_field` if they are different. Update the `modals.m_field` if they are different.
@ -2529,7 +2532,7 @@ def dedup_field(record: Record, r_field: str, modals: Modals, m_field: str):
raise InvalidDataError('Unfillable field') raise InvalidDataError('Unfillable field')
def dedup_coordinates(record: Record, modals: Modals, mx_field: str, my_field: str): def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str):
""" """
Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`, Deduplicate `record.x` and `record.y` using `modals.mx_field` and `modals.my_field`,
taking into account the value of `modals.xy_relative`. taking into account the value of `modals.xy_relative`.