improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-28 23:15:11 -07:00
parent e7e42a2ef8
commit 59c94f7c17
3 changed files with 52 additions and 52 deletions

View File

@ -1,7 +1,8 @@
""" """
Generic record-level read/write functionality. Generic record-level read/write functionality.
""" """
from typing import Sequence, IO, TypeVar, ClassVar, Type from typing import IO, ClassVar, Self, Generic, TypeVar
from collections.abc import Sequence
import struct import struct
import io import io
from datetime import datetime from datetime import datetime
@ -17,6 +18,8 @@ from .basic import parse_ascii, pack_ascii, read
_RECORD_HEADER_FMT = struct.Struct('>HH') _RECORD_HEADER_FMT = struct.Struct('>HH')
II = TypeVar('II') # Input type
OO = TypeVar('OO') # Output type
def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int: def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int:
@ -53,30 +56,27 @@ def expect_record(stream: IO[bytes], tag: int) -> int:
return data_size return data_size
R = TypeVar('R', bound='Record') class Record(Generic[II, OO], metaclass=ABCMeta):
class Record(metaclass=ABCMeta):
tag: ClassVar[int] = -1 tag: ClassVar[int] = -1
expected_size: ClassVar[int | None] = None expected_size: ClassVar[int | None] = None
@classmethod @classmethod
def check_size(cls, size: int): def check_size(cls: type[Self], size: int) -> None:
if cls.expected_size is not None and size != cls.expected_size: if cls.expected_size is not None and size != cls.expected_size:
raise KlamathError(f'Expected size {cls.expected_size}, got {size}') raise KlamathError(f'Expected size {cls.expected_size}, got {size}')
@classmethod @classmethod
def check_data(cls, data): def check_data(cls: type[Self], data: II) -> None:
pass pass
@classmethod @classmethod
@abstractmethod @abstractmethod
def read_data(cls, stream: IO[bytes], size: int): def read_data(cls: type[Self], stream: IO[bytes], size: int) -> OO:
pass pass
@classmethod @classmethod
@abstractmethod @abstractmethod
def pack_data(cls, data) -> bytes: def pack_data(cls: type[Self], data: II) -> bytes:
pass pass
@staticmethod @staticmethod
@ -84,11 +84,11 @@ class Record(metaclass=ABCMeta):
return read_record_header(stream) return read_record_header(stream)
@classmethod @classmethod
def write_header(cls, stream: IO[bytes], data_size: int) -> int: def write_header(cls: type[Self], stream: IO[bytes], data_size: int) -> int:
return write_record_header(stream, data_size, cls.tag) return write_record_header(stream, data_size, cls.tag)
@classmethod @classmethod
def skip_past(cls, stream: IO[bytes]) -> bool: def skip_past(cls: type[Self], stream: IO[bytes]) -> bool:
""" """
Skip to the end of the next occurence of this record. Skip to the end of the next occurence of this record.
@ -110,7 +110,7 @@ class Record(metaclass=ABCMeta):
return True return True
@classmethod @classmethod
def skip_and_read(cls, stream: IO[bytes]): def skip_and_read(cls: type[Self], stream: IO[bytes]) -> OO:
size, tag = Record.read_header(stream) size, tag = Record.read_header(stream)
while tag != cls.tag: while tag != cls.tag:
stream.seek(size, io.SEEK_CUR) stream.seek(size, io.SEEK_CUR)
@ -119,90 +119,90 @@ class Record(metaclass=ABCMeta):
return data return data
@classmethod @classmethod
def read(cls: Type[R], stream: IO[bytes]): def read(cls: type[Self], stream: IO[bytes]) -> OO:
size = expect_record(stream, cls.tag) size = expect_record(stream, cls.tag)
data = cls.read_data(stream, size) data = cls.read_data(stream, size)
return data return data
@classmethod @classmethod
def write(cls, stream: IO[bytes], data) -> int: def write(cls: type[Self], stream: IO[bytes], data: II) -> int:
data_bytes = cls.pack_data(data) data_bytes = cls.pack_data(data)
b = cls.write_header(stream, len(data_bytes)) b = cls.write_header(stream, len(data_bytes))
b += stream.write(data_bytes) b += stream.write(data_bytes)
return b return b
class NoDataRecord(Record): class NoDataRecord(Record[None, None]):
expected_size: ClassVar[int | None] = 0 expected_size: ClassVar[int | None] = 0
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> None: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> None:
stream.read(size) stream.read(size)
@classmethod @classmethod
def pack_data(cls, data: None) -> bytes: def pack_data(cls: type[Self], data: None) -> bytes:
if data is not None: if data is not None:
raise KlamathError('?? Packing {data} into NoDataRecord??') raise KlamathError('?? Packing {data} into NoDataRecord??')
return b'' return b''
class BitArrayRecord(Record): class BitArrayRecord(Record[int, int]):
expected_size: ClassVar[int | None] = 2 expected_size: ClassVar[int | None] = 2
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> int: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> int: # noqa: ARG003 size unused
return parse_bitarray(read(stream, 2)) return parse_bitarray(read(stream, 2))
@classmethod @classmethod
def pack_data(cls, data: int) -> bytes: def pack_data(cls: type[Self], data: int) -> bytes:
return pack_bitarray(data) return pack_bitarray(data)
class Int2Record(Record): class Int2Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int16]]):
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int16]: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int16]:
return parse_int2(read(stream, size)) return parse_int2(read(stream, size))
@classmethod @classmethod
def pack_data(cls, data: Sequence[int]) -> bytes: def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
return pack_int2(data) return pack_int2(data)
class Int4Record(Record): class Int4Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int32]]):
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int32]: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int32]:
return parse_int4(read(stream, size)) return parse_int4(read(stream, size))
@classmethod @classmethod
def pack_data(cls, data: Sequence[int]) -> bytes: def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
return pack_int4(data) return pack_int4(data)
class Real8Record(Record): class Real8Record(Record[Sequence[float] | float, NDArray[numpy.float64]]):
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.float64]: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.float64]:
return parse_real8(read(stream, size)) return parse_real8(read(stream, size))
@classmethod @classmethod
def pack_data(cls, data: Sequence[int]) -> bytes: def pack_data(cls: type[Self], data: Sequence[float] | float) -> bytes:
return pack_real8(data) return pack_real8(data)
class ASCIIRecord(Record): class ASCIIRecord(Record[bytes, bytes]):
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> bytes: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> bytes:
return parse_ascii(read(stream, size)) return parse_ascii(read(stream, size))
@classmethod @classmethod
def pack_data(cls, data: bytes) -> bytes: def pack_data(cls: type[Self], data: bytes) -> bytes:
return pack_ascii(data) return pack_ascii(data)
class DateTimeRecord(Record): class DateTimeRecord(Record[Sequence[datetime], list[datetime]]):
@classmethod @classmethod
def read_data(cls, stream: IO[bytes], size: int) -> list[datetime]: def read_data(cls: type[Self], stream: IO[bytes], size: int) -> list[datetime]:
return parse_datetime(read(stream, size)) return parse_datetime(read(stream, size))
@classmethod @classmethod
def pack_data(cls, data: Sequence[datetime]) -> bytes: def pack_data(cls: type[Self], data: Sequence[datetime]) -> bytes:
return pack_datetime(data) return pack_datetime(data)

View File

@ -144,7 +144,7 @@ class REFLIBS(ASCIIRecord):
tag = 0x1f06 tag = 0x1f06
@classmethod @classmethod
def check_size(cls, size: int): def check_size(cls: type[Self], size: int) -> None:
if size != 0 and size % 44 != 0: if size != 0 and size % 44 != 0:
raise Exception(f'Expected size to be multiple of 44, got {size}') raise Exception(f'Expected size to be multiple of 44, got {size}')
@ -153,7 +153,7 @@ class FONTS(ASCIIRecord):
tag = 0x2006 tag = 0x2006
@classmethod @classmethod
def check_size(cls, size: int): def check_size(cls: type[Self], size: int) -> None:
if size != 0 and size % 44 != 0: if size != 0 and size % 44 != 0:
raise Exception(f'Expected size to be multiple of 44, got {size}') raise Exception(f'Expected size to be multiple of 44, got {size}')
@ -168,7 +168,7 @@ class GENERATIONS(Int2Record):
expected_size = 2 expected_size = 2
@classmethod @classmethod
def check_data(cls, data: Sequence[int]): def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
if len(data) != 1: if len(data) != 1:
raise Exception(f'Expected exactly one integer, got {data}') raise Exception(f'Expected exactly one integer, got {data}')
@ -177,7 +177,7 @@ class ATTRTABLE(ASCIIRecord):
tag = 0x2306 tag = 0x2306
@classmethod @classmethod
def check_size(cls, size: int): def check_size(cls: type[Self], size: int) -> None:
if size > 44: if size > 44:
raise Exception(f'Expected size <= 44, got {size}') raise Exception(f'Expected size <= 44, got {size}')
@ -266,7 +266,7 @@ class FORMAT(Int2Record):
expected_size = 2 expected_size = 2
@classmethod @classmethod
def check_data(cls, data: Sequence[int]): def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
if len(data) != 1: if len(data) != 1:
raise Exception(f'Expected exactly one integer, got {data}') raise Exception(f'Expected exactly one integer, got {data}')

View File

@ -12,7 +12,7 @@ from .basic import decode_real8, encode_real8, parse_datetime
from .basic import KlamathError from .basic import KlamathError
def test_parse_bitarray(): def test_parse_bitarray() -> None:
assert parse_bitarray(b'59') == 13625 assert parse_bitarray(b'59') == 13625
assert parse_bitarray(b'\0\0') == 0 assert parse_bitarray(b'\0\0') == 0
assert parse_bitarray(b'\xff\xff') == 65535 assert parse_bitarray(b'\xff\xff') == 65535
@ -26,7 +26,7 @@ def test_parse_bitarray():
parse_bitarray(b'') parse_bitarray(b'')
def test_parse_int2(): def test_parse_int2() -> None:
assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0)) assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0))
# odd length # odd length
@ -38,7 +38,7 @@ def test_parse_int2():
parse_int2(b'') parse_int2(b'')
def test_parse_int4(): def test_parse_int4() -> None:
assert_array_equal(parse_int4(b'4321'), (875770417,)) assert_array_equal(parse_int4(b'4321'), (875770417,))
# length % 4 != 0 # length % 4 != 0
@ -50,7 +50,7 @@ def test_parse_int4():
parse_int4(b'') parse_int4(b'')
def test_decode_real8(): def test_decode_real8() -> None:
# zeroes # zeroes
assert decode_real8(numpy.array([0x0])) == 0 assert decode_real8(numpy.array([0x0])) == 0
assert decode_real8(numpy.array([1 << 63])) == 0 # negative assert decode_real8(numpy.array([1 << 63])) == 0 # negative
@ -60,7 +60,7 @@ def test_decode_real8():
assert decode_real8(numpy.array([0xC120 << 48])) == -2.0 assert decode_real8(numpy.array([0xC120 << 48])) == -2.0
def test_parse_real8(): def test_parse_real8() -> None:
packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000) packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000)
assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0)) assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0))
@ -73,7 +73,7 @@ def test_parse_real8():
parse_real8(b'') parse_real8(b'')
def test_parse_ascii(): def test_parse_ascii() -> None:
# # empty data Now allowed! # # empty data Now allowed!
# with pytest.raises(KlamathError): # with pytest.raises(KlamathError):
# parse_ascii(b'') # parse_ascii(b'')
@ -82,40 +82,40 @@ def test_parse_ascii():
assert parse_ascii(b'12345\0') == b'12345' # strips trailing null byte assert parse_ascii(b'12345\0') == b'12345' # strips trailing null byte
def test_pack_bitarray(): def test_pack_bitarray() -> None:
packed = pack_bitarray(321) packed = pack_bitarray(321)
assert len(packed) == 2 assert len(packed) == 2
assert packed == struct.pack('>H', 321) assert packed == struct.pack('>H', 321)
def test_pack_int2(): def test_pack_int2() -> None:
packed = pack_int2((3, 2, 1)) packed = pack_int2((3, 2, 1))
assert len(packed) == 3 * 2 assert len(packed) == 3 * 2
assert packed == struct.pack('>3h', 3, 2, 1) assert packed == struct.pack('>3h', 3, 2, 1)
assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1) assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1)
def test_pack_int4(): def test_pack_int4() -> None:
packed = pack_int4((3, 2, 1)) packed = pack_int4((3, 2, 1))
assert len(packed) == 3 * 4 assert len(packed) == 3 * 4
assert packed == struct.pack('>3l', 3, 2, 1) assert packed == struct.pack('>3l', 3, 2, 1)
assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1) assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1)
def test_encode_real8(): def test_encode_real8() -> None:
assert encode_real8(numpy.array([0.0])) == 0 assert encode_real8(numpy.array([0.0])) == 0
arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12)) arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12))
assert_array_equal(decode_real8(encode_real8(arr)), arr) assert_array_equal(decode_real8(encode_real8(arr)), arr)
def test_pack_real8(): def test_pack_real8() -> None:
reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12) reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12)
packed = pack_real8(reals) packed = pack_real8(reals)
assert len(packed) == len(reals) * 8 assert len(packed) == len(reals) * 8
assert_array_equal(parse_real8(packed), reals) assert_array_equal(parse_real8(packed), reals)
def test_pack_ascii(): def test_pack_ascii() -> None:
assert pack_ascii(b'4321') == b'4321' assert pack_ascii(b'4321') == b'4321'
assert pack_ascii(b'321') == b'321\0' assert pack_ascii(b'321') == b'321\0'