improve type annotations

This commit is contained in:
Jan Petykiewicz 2024-07-28 23:15:11 -07:00
parent e7e42a2ef8
commit 59c94f7c17
3 changed files with 52 additions and 52 deletions

View File

@ -1,7 +1,8 @@
"""
Generic record-level read/write functionality.
"""
from typing import Sequence, IO, TypeVar, ClassVar, Type
from typing import IO, ClassVar, Self, Generic, TypeVar
from collections.abc import Sequence
import struct
import io
from datetime import datetime
@ -17,6 +18,8 @@ from .basic import parse_ascii, pack_ascii, read
_RECORD_HEADER_FMT = struct.Struct('>HH')
II = TypeVar('II') # Input type
OO = TypeVar('OO') # Output type
def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int:
@ -53,30 +56,27 @@ def expect_record(stream: IO[bytes], tag: int) -> int:
return data_size
R = TypeVar('R', bound='Record')
class Record(metaclass=ABCMeta):
class Record(Generic[II, OO], metaclass=ABCMeta):
tag: ClassVar[int] = -1
expected_size: ClassVar[int | None] = None
@classmethod
def check_size(cls, size: int):
def check_size(cls: type[Self], size: int) -> None:
if cls.expected_size is not None and size != cls.expected_size:
raise KlamathError(f'Expected size {cls.expected_size}, got {size}')
@classmethod
def check_data(cls, data):
def check_data(cls: type[Self], data: II) -> None:
pass
@classmethod
@abstractmethod
def read_data(cls, stream: IO[bytes], size: int):
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> OO:
pass
@classmethod
@abstractmethod
def pack_data(cls, data) -> bytes:
def pack_data(cls: type[Self], data: II) -> bytes:
pass
@staticmethod
@ -84,11 +84,11 @@ class Record(metaclass=ABCMeta):
return read_record_header(stream)
@classmethod
def write_header(cls, stream: IO[bytes], data_size: int) -> int:
def write_header(cls: type[Self], stream: IO[bytes], data_size: int) -> int:
return write_record_header(stream, data_size, cls.tag)
@classmethod
def skip_past(cls, stream: IO[bytes]) -> bool:
def skip_past(cls: type[Self], stream: IO[bytes]) -> bool:
"""
Skip to the end of the next occurence of this record.
@ -110,7 +110,7 @@ class Record(metaclass=ABCMeta):
return True
@classmethod
def skip_and_read(cls, stream: IO[bytes]):
def skip_and_read(cls: type[Self], stream: IO[bytes]) -> OO:
size, tag = Record.read_header(stream)
while tag != cls.tag:
stream.seek(size, io.SEEK_CUR)
@ -119,90 +119,90 @@ class Record(metaclass=ABCMeta):
return data
@classmethod
def read(cls: Type[R], stream: IO[bytes]):
def read(cls: type[Self], stream: IO[bytes]) -> OO:
size = expect_record(stream, cls.tag)
data = cls.read_data(stream, size)
return data
@classmethod
def write(cls, stream: IO[bytes], data) -> int:
def write(cls: type[Self], stream: IO[bytes], data: II) -> int:
data_bytes = cls.pack_data(data)
b = cls.write_header(stream, len(data_bytes))
b += stream.write(data_bytes)
return b
class NoDataRecord(Record):
class NoDataRecord(Record[None, None]):
expected_size: ClassVar[int | None] = 0
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> None:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> None:
stream.read(size)
@classmethod
def pack_data(cls, data: None) -> bytes:
def pack_data(cls: type[Self], data: None) -> bytes:
if data is not None:
raise KlamathError('?? Packing {data} into NoDataRecord??')
return b''
class BitArrayRecord(Record):
class BitArrayRecord(Record[int, int]):
expected_size: ClassVar[int | None] = 2
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> int:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> int: # noqa: ARG003 size unused
return parse_bitarray(read(stream, 2))
@classmethod
def pack_data(cls, data: int) -> bytes:
def pack_data(cls: type[Self], data: int) -> bytes:
return pack_bitarray(data)
class Int2Record(Record):
class Int2Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int16]]):
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int16]:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int16]:
return parse_int2(read(stream, size))
@classmethod
def pack_data(cls, data: Sequence[int]) -> bytes:
def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
return pack_int2(data)
class Int4Record(Record):
class Int4Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int32]]):
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int32]:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int32]:
return parse_int4(read(stream, size))
@classmethod
def pack_data(cls, data: Sequence[int]) -> bytes:
def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
return pack_int4(data)
class Real8Record(Record):
class Real8Record(Record[Sequence[float] | float, NDArray[numpy.float64]]):
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.float64]:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.float64]:
return parse_real8(read(stream, size))
@classmethod
def pack_data(cls, data: Sequence[int]) -> bytes:
def pack_data(cls: type[Self], data: Sequence[float] | float) -> bytes:
return pack_real8(data)
class ASCIIRecord(Record):
class ASCIIRecord(Record[bytes, bytes]):
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> bytes:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> bytes:
return parse_ascii(read(stream, size))
@classmethod
def pack_data(cls, data: bytes) -> bytes:
def pack_data(cls: type[Self], data: bytes) -> bytes:
return pack_ascii(data)
class DateTimeRecord(Record):
class DateTimeRecord(Record[Sequence[datetime], list[datetime]]):
@classmethod
def read_data(cls, stream: IO[bytes], size: int) -> list[datetime]:
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> list[datetime]:
return parse_datetime(read(stream, size))
@classmethod
def pack_data(cls, data: Sequence[datetime]) -> bytes:
def pack_data(cls: type[Self], data: Sequence[datetime]) -> bytes:
return pack_datetime(data)

View File

@ -144,7 +144,7 @@ class REFLIBS(ASCIIRecord):
tag = 0x1f06
@classmethod
def check_size(cls, size: int):
def check_size(cls: type[Self], size: int) -> None:
if size != 0 and size % 44 != 0:
raise Exception(f'Expected size to be multiple of 44, got {size}')
@ -153,7 +153,7 @@ class FONTS(ASCIIRecord):
tag = 0x2006
@classmethod
def check_size(cls, size: int):
def check_size(cls: type[Self], size: int) -> None:
if size != 0 and size % 44 != 0:
raise Exception(f'Expected size to be multiple of 44, got {size}')
@ -168,7 +168,7 @@ class GENERATIONS(Int2Record):
expected_size = 2
@classmethod
def check_data(cls, data: Sequence[int]):
def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
if len(data) != 1:
raise Exception(f'Expected exactly one integer, got {data}')
@ -177,7 +177,7 @@ class ATTRTABLE(ASCIIRecord):
tag = 0x2306
@classmethod
def check_size(cls, size: int):
def check_size(cls: type[Self], size: int) -> None:
if size > 44:
raise Exception(f'Expected size <= 44, got {size}')
@ -266,7 +266,7 @@ class FORMAT(Int2Record):
expected_size = 2
@classmethod
def check_data(cls, data: Sequence[int]):
def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
if len(data) != 1:
raise Exception(f'Expected exactly one integer, got {data}')

View File

@ -12,7 +12,7 @@ from .basic import decode_real8, encode_real8, parse_datetime
from .basic import KlamathError
def test_parse_bitarray():
def test_parse_bitarray() -> None:
assert parse_bitarray(b'59') == 13625
assert parse_bitarray(b'\0\0') == 0
assert parse_bitarray(b'\xff\xff') == 65535
@ -26,7 +26,7 @@ def test_parse_bitarray():
parse_bitarray(b'')
def test_parse_int2():
def test_parse_int2() -> None:
assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0))
# odd length
@ -38,7 +38,7 @@ def test_parse_int2():
parse_int2(b'')
def test_parse_int4():
def test_parse_int4() -> None:
assert_array_equal(parse_int4(b'4321'), (875770417,))
# length % 4 != 0
@ -50,7 +50,7 @@ def test_parse_int4():
parse_int4(b'')
def test_decode_real8():
def test_decode_real8() -> None:
# zeroes
assert decode_real8(numpy.array([0x0])) == 0
assert decode_real8(numpy.array([1 << 63])) == 0 # negative
@ -60,7 +60,7 @@ def test_decode_real8():
assert decode_real8(numpy.array([0xC120 << 48])) == -2.0
def test_parse_real8():
def test_parse_real8() -> None:
packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000)
assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0))
@ -73,7 +73,7 @@ def test_parse_real8():
parse_real8(b'')
def test_parse_ascii():
def test_parse_ascii() -> None:
# # empty data Now allowed!
# with pytest.raises(KlamathError):
# parse_ascii(b'')
@ -82,40 +82,40 @@ def test_parse_ascii():
assert parse_ascii(b'12345\0') == b'12345' # strips trailing null byte
def test_pack_bitarray():
def test_pack_bitarray() -> None:
packed = pack_bitarray(321)
assert len(packed) == 2
assert packed == struct.pack('>H', 321)
def test_pack_int2():
def test_pack_int2() -> None:
packed = pack_int2((3, 2, 1))
assert len(packed) == 3 * 2
assert packed == struct.pack('>3h', 3, 2, 1)
assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1)
def test_pack_int4():
def test_pack_int4() -> None:
packed = pack_int4((3, 2, 1))
assert len(packed) == 3 * 4
assert packed == struct.pack('>3l', 3, 2, 1)
assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1)
def test_encode_real8():
def test_encode_real8() -> None:
assert encode_real8(numpy.array([0.0])) == 0
arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12))
assert_array_equal(decode_real8(encode_real8(arr)), arr)
def test_pack_real8():
def test_pack_real8() -> None:
reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12)
packed = pack_real8(reals)
assert len(packed) == len(reals) * 8
assert_array_equal(parse_real8(packed), reals)
def test_pack_ascii():
def test_pack_ascii() -> None:
assert pack_ascii(b'4321') == b'4321'
assert pack_ascii(b'321') == b'321\0'