From 59c94f7c1798f9c5677a57522587b5f13d107954 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Sun, 28 Jul 2024 23:15:11 -0700 Subject: [PATCH] improve type annotations --- klamath/record.py | 70 +++++++++++++++++++++---------------------- klamath/records.py | 10 +++---- klamath/test_basic.py | 24 +++++++-------- 3 files changed, 52 insertions(+), 52 deletions(-) diff --git a/klamath/record.py b/klamath/record.py index c0b4a6c..674e55c 100644 --- a/klamath/record.py +++ b/klamath/record.py @@ -1,7 +1,8 @@ """ Generic record-level read/write functionality. """ -from typing import Sequence, IO, TypeVar, ClassVar, Type +from typing import IO, ClassVar, Self, Generic, TypeVar +from collections.abc import Sequence import struct import io from datetime import datetime @@ -17,6 +18,8 @@ from .basic import parse_ascii, pack_ascii, read _RECORD_HEADER_FMT = struct.Struct('>HH') +II = TypeVar('II') # Input type +OO = TypeVar('OO') # Output type def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int: @@ -53,30 +56,27 @@ def expect_record(stream: IO[bytes], tag: int) -> int: return data_size -R = TypeVar('R', bound='Record') - - -class Record(metaclass=ABCMeta): +class Record(Generic[II, OO], metaclass=ABCMeta): tag: ClassVar[int] = -1 expected_size: ClassVar[int | None] = None @classmethod - def check_size(cls, size: int): + def check_size(cls: type[Self], size: int) -> None: if cls.expected_size is not None and size != cls.expected_size: raise KlamathError(f'Expected size {cls.expected_size}, got {size}') @classmethod - def check_data(cls, data): + def check_data(cls: type[Self], data: II) -> None: pass @classmethod @abstractmethod - def read_data(cls, stream: IO[bytes], size: int): + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> OO: pass @classmethod @abstractmethod - def pack_data(cls, data) -> bytes: + def pack_data(cls: type[Self], data: II) -> bytes: pass @staticmethod @@ -84,11 +84,11 @@ class Record(metaclass=ABCMeta): return read_record_header(stream) @classmethod - def write_header(cls, stream: IO[bytes], data_size: int) -> int: + def write_header(cls: type[Self], stream: IO[bytes], data_size: int) -> int: return write_record_header(stream, data_size, cls.tag) @classmethod - def skip_past(cls, stream: IO[bytes]) -> bool: + def skip_past(cls: type[Self], stream: IO[bytes]) -> bool: """ Skip to the end of the next occurence of this record. @@ -110,7 +110,7 @@ class Record(metaclass=ABCMeta): return True @classmethod - def skip_and_read(cls, stream: IO[bytes]): + def skip_and_read(cls: type[Self], stream: IO[bytes]) -> OO: size, tag = Record.read_header(stream) while tag != cls.tag: stream.seek(size, io.SEEK_CUR) @@ -119,90 +119,90 @@ class Record(metaclass=ABCMeta): return data @classmethod - def read(cls: Type[R], stream: IO[bytes]): + def read(cls: type[Self], stream: IO[bytes]) -> OO: size = expect_record(stream, cls.tag) data = cls.read_data(stream, size) return data @classmethod - def write(cls, stream: IO[bytes], data) -> int: + def write(cls: type[Self], stream: IO[bytes], data: II) -> int: data_bytes = cls.pack_data(data) b = cls.write_header(stream, len(data_bytes)) b += stream.write(data_bytes) return b -class NoDataRecord(Record): +class NoDataRecord(Record[None, None]): expected_size: ClassVar[int | None] = 0 @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> None: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> None: stream.read(size) @classmethod - def pack_data(cls, data: None) -> bytes: + def pack_data(cls: type[Self], data: None) -> bytes: if data is not None: raise KlamathError('?? Packing {data} into NoDataRecord??') return b'' -class BitArrayRecord(Record): +class BitArrayRecord(Record[int, int]): expected_size: ClassVar[int | None] = 2 @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> int: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> int: # noqa: ARG003 size unused return parse_bitarray(read(stream, 2)) @classmethod - def pack_data(cls, data: int) -> bytes: + def pack_data(cls: type[Self], data: int) -> bytes: return pack_bitarray(data) -class Int2Record(Record): +class Int2Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int16]]): @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int16]: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int16]: return parse_int2(read(stream, size)) @classmethod - def pack_data(cls, data: Sequence[int]) -> bytes: + def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: return pack_int2(data) -class Int4Record(Record): +class Int4Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int32]]): @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int32]: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int32]: return parse_int4(read(stream, size)) @classmethod - def pack_data(cls, data: Sequence[int]) -> bytes: + def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: return pack_int4(data) -class Real8Record(Record): +class Real8Record(Record[Sequence[float] | float, NDArray[numpy.float64]]): @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.float64]: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.float64]: return parse_real8(read(stream, size)) @classmethod - def pack_data(cls, data: Sequence[int]) -> bytes: + def pack_data(cls: type[Self], data: Sequence[float] | float) -> bytes: return pack_real8(data) -class ASCIIRecord(Record): +class ASCIIRecord(Record[bytes, bytes]): @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> bytes: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> bytes: return parse_ascii(read(stream, size)) @classmethod - def pack_data(cls, data: bytes) -> bytes: + def pack_data(cls: type[Self], data: bytes) -> bytes: return pack_ascii(data) -class DateTimeRecord(Record): +class DateTimeRecord(Record[Sequence[datetime], list[datetime]]): @classmethod - def read_data(cls, stream: IO[bytes], size: int) -> list[datetime]: + def read_data(cls: type[Self], stream: IO[bytes], size: int) -> list[datetime]: return parse_datetime(read(stream, size)) @classmethod - def pack_data(cls, data: Sequence[datetime]) -> bytes: + def pack_data(cls: type[Self], data: Sequence[datetime]) -> bytes: return pack_datetime(data) diff --git a/klamath/records.py b/klamath/records.py index 96fbc7f..9fe296d 100644 --- a/klamath/records.py +++ b/klamath/records.py @@ -144,7 +144,7 @@ class REFLIBS(ASCIIRecord): tag = 0x1f06 @classmethod - def check_size(cls, size: int): + def check_size(cls: type[Self], size: int) -> None: if size != 0 and size % 44 != 0: raise Exception(f'Expected size to be multiple of 44, got {size}') @@ -153,7 +153,7 @@ class FONTS(ASCIIRecord): tag = 0x2006 @classmethod - def check_size(cls, size: int): + def check_size(cls: type[Self], size: int) -> None: if size != 0 and size % 44 != 0: raise Exception(f'Expected size to be multiple of 44, got {size}') @@ -168,7 +168,7 @@ class GENERATIONS(Int2Record): expected_size = 2 @classmethod - def check_data(cls, data: Sequence[int]): + def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None: if len(data) != 1: raise Exception(f'Expected exactly one integer, got {data}') @@ -177,7 +177,7 @@ class ATTRTABLE(ASCIIRecord): tag = 0x2306 @classmethod - def check_size(cls, size: int): + def check_size(cls: type[Self], size: int) -> None: if size > 44: raise Exception(f'Expected size <= 44, got {size}') @@ -266,7 +266,7 @@ class FORMAT(Int2Record): expected_size = 2 @classmethod - def check_data(cls, data: Sequence[int]): + def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None: if len(data) != 1: raise Exception(f'Expected exactly one integer, got {data}') diff --git a/klamath/test_basic.py b/klamath/test_basic.py index b511cc6..4d686d9 100644 --- a/klamath/test_basic.py +++ b/klamath/test_basic.py @@ -12,7 +12,7 @@ from .basic import decode_real8, encode_real8, parse_datetime from .basic import KlamathError -def test_parse_bitarray(): +def test_parse_bitarray() -> None: assert parse_bitarray(b'59') == 13625 assert parse_bitarray(b'\0\0') == 0 assert parse_bitarray(b'\xff\xff') == 65535 @@ -26,7 +26,7 @@ def test_parse_bitarray(): parse_bitarray(b'') -def test_parse_int2(): +def test_parse_int2() -> None: assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0)) # odd length @@ -38,7 +38,7 @@ def test_parse_int2(): parse_int2(b'') -def test_parse_int4(): +def test_parse_int4() -> None: assert_array_equal(parse_int4(b'4321'), (875770417,)) # length % 4 != 0 @@ -50,7 +50,7 @@ def test_parse_int4(): parse_int4(b'') -def test_decode_real8(): +def test_decode_real8() -> None: # zeroes assert decode_real8(numpy.array([0x0])) == 0 assert decode_real8(numpy.array([1 << 63])) == 0 # negative @@ -60,7 +60,7 @@ def test_decode_real8(): assert decode_real8(numpy.array([0xC120 << 48])) == -2.0 -def test_parse_real8(): +def test_parse_real8() -> None: packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000) assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0)) @@ -73,7 +73,7 @@ def test_parse_real8(): parse_real8(b'') -def test_parse_ascii(): +def test_parse_ascii() -> None: # # empty data Now allowed! # with pytest.raises(KlamathError): # parse_ascii(b'') @@ -82,40 +82,40 @@ def test_parse_ascii(): assert parse_ascii(b'12345\0') == b'12345' # strips trailing null byte -def test_pack_bitarray(): +def test_pack_bitarray() -> None: packed = pack_bitarray(321) assert len(packed) == 2 assert packed == struct.pack('>H', 321) -def test_pack_int2(): +def test_pack_int2() -> None: packed = pack_int2((3, 2, 1)) assert len(packed) == 3 * 2 assert packed == struct.pack('>3h', 3, 2, 1) assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1) -def test_pack_int4(): +def test_pack_int4() -> None: packed = pack_int4((3, 2, 1)) assert len(packed) == 3 * 4 assert packed == struct.pack('>3l', 3, 2, 1) assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1) -def test_encode_real8(): +def test_encode_real8() -> None: assert encode_real8(numpy.array([0.0])) == 0 arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12)) assert_array_equal(decode_real8(encode_real8(arr)), arr) -def test_pack_real8(): +def test_pack_real8() -> None: reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12) packed = pack_real8(reals) assert len(packed) == len(reals) * 8 assert_array_equal(parse_real8(packed), reals) -def test_pack_ascii(): +def test_pack_ascii() -> None: assert pack_ascii(b'4321') == b'4321' assert pack_ascii(b'321') == b'321\0'