improve type annotations
This commit is contained in:
parent
e7e42a2ef8
commit
59c94f7c17
@ -1,7 +1,8 @@
|
||||
"""
|
||||
Generic record-level read/write functionality.
|
||||
"""
|
||||
from typing import Sequence, IO, TypeVar, ClassVar, Type
|
||||
from typing import IO, ClassVar, Self, Generic, TypeVar
|
||||
from collections.abc import Sequence
|
||||
import struct
|
||||
import io
|
||||
from datetime import datetime
|
||||
@ -17,6 +18,8 @@ from .basic import parse_ascii, pack_ascii, read
|
||||
|
||||
|
||||
_RECORD_HEADER_FMT = struct.Struct('>HH')
|
||||
II = TypeVar('II') # Input type
|
||||
OO = TypeVar('OO') # Output type
|
||||
|
||||
|
||||
def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int:
|
||||
@ -53,30 +56,27 @@ def expect_record(stream: IO[bytes], tag: int) -> int:
|
||||
return data_size
|
||||
|
||||
|
||||
R = TypeVar('R', bound='Record')
|
||||
|
||||
|
||||
class Record(metaclass=ABCMeta):
|
||||
class Record(Generic[II, OO], metaclass=ABCMeta):
|
||||
tag: ClassVar[int] = -1
|
||||
expected_size: ClassVar[int | None] = None
|
||||
|
||||
@classmethod
|
||||
def check_size(cls, size: int):
|
||||
def check_size(cls: type[Self], size: int) -> None:
|
||||
if cls.expected_size is not None and size != cls.expected_size:
|
||||
raise KlamathError(f'Expected size {cls.expected_size}, got {size}')
|
||||
|
||||
@classmethod
|
||||
def check_data(cls, data):
|
||||
def check_data(cls: type[Self], data: II) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int):
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> OO:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def pack_data(cls, data) -> bytes:
|
||||
def pack_data(cls: type[Self], data: II) -> bytes:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -84,11 +84,11 @@ class Record(metaclass=ABCMeta):
|
||||
return read_record_header(stream)
|
||||
|
||||
@classmethod
|
||||
def write_header(cls, stream: IO[bytes], data_size: int) -> int:
|
||||
def write_header(cls: type[Self], stream: IO[bytes], data_size: int) -> int:
|
||||
return write_record_header(stream, data_size, cls.tag)
|
||||
|
||||
@classmethod
|
||||
def skip_past(cls, stream: IO[bytes]) -> bool:
|
||||
def skip_past(cls: type[Self], stream: IO[bytes]) -> bool:
|
||||
"""
|
||||
Skip to the end of the next occurence of this record.
|
||||
|
||||
@ -110,7 +110,7 @@ class Record(metaclass=ABCMeta):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def skip_and_read(cls, stream: IO[bytes]):
|
||||
def skip_and_read(cls: type[Self], stream: IO[bytes]) -> OO:
|
||||
size, tag = Record.read_header(stream)
|
||||
while tag != cls.tag:
|
||||
stream.seek(size, io.SEEK_CUR)
|
||||
@ -119,90 +119,90 @@ class Record(metaclass=ABCMeta):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def read(cls: Type[R], stream: IO[bytes]):
|
||||
def read(cls: type[Self], stream: IO[bytes]) -> OO:
|
||||
size = expect_record(stream, cls.tag)
|
||||
data = cls.read_data(stream, size)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def write(cls, stream: IO[bytes], data) -> int:
|
||||
def write(cls: type[Self], stream: IO[bytes], data: II) -> int:
|
||||
data_bytes = cls.pack_data(data)
|
||||
b = cls.write_header(stream, len(data_bytes))
|
||||
b += stream.write(data_bytes)
|
||||
return b
|
||||
|
||||
|
||||
class NoDataRecord(Record):
|
||||
class NoDataRecord(Record[None, None]):
|
||||
expected_size: ClassVar[int | None] = 0
|
||||
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> None:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> None:
|
||||
stream.read(size)
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: None) -> bytes:
|
||||
def pack_data(cls: type[Self], data: None) -> bytes:
|
||||
if data is not None:
|
||||
raise KlamathError('?? Packing {data} into NoDataRecord??')
|
||||
return b''
|
||||
|
||||
|
||||
class BitArrayRecord(Record):
|
||||
class BitArrayRecord(Record[int, int]):
|
||||
expected_size: ClassVar[int | None] = 2
|
||||
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> int:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> int: # noqa: ARG003 size unused
|
||||
return parse_bitarray(read(stream, 2))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: int) -> bytes:
|
||||
def pack_data(cls: type[Self], data: int) -> bytes:
|
||||
return pack_bitarray(data)
|
||||
|
||||
|
||||
class Int2Record(Record):
|
||||
class Int2Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int16]]):
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int16]:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int16]:
|
||||
return parse_int2(read(stream, size))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: Sequence[int]) -> bytes:
|
||||
def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
|
||||
return pack_int2(data)
|
||||
|
||||
|
||||
class Int4Record(Record):
|
||||
class Int4Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int32]]):
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int32]:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int32]:
|
||||
return parse_int4(read(stream, size))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: Sequence[int]) -> bytes:
|
||||
def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes:
|
||||
return pack_int4(data)
|
||||
|
||||
|
||||
class Real8Record(Record):
|
||||
class Real8Record(Record[Sequence[float] | float, NDArray[numpy.float64]]):
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.float64]:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.float64]:
|
||||
return parse_real8(read(stream, size))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: Sequence[int]) -> bytes:
|
||||
def pack_data(cls: type[Self], data: Sequence[float] | float) -> bytes:
|
||||
return pack_real8(data)
|
||||
|
||||
|
||||
class ASCIIRecord(Record):
|
||||
class ASCIIRecord(Record[bytes, bytes]):
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> bytes:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> bytes:
|
||||
return parse_ascii(read(stream, size))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: bytes) -> bytes:
|
||||
def pack_data(cls: type[Self], data: bytes) -> bytes:
|
||||
return pack_ascii(data)
|
||||
|
||||
|
||||
class DateTimeRecord(Record):
|
||||
class DateTimeRecord(Record[Sequence[datetime], list[datetime]]):
|
||||
@classmethod
|
||||
def read_data(cls, stream: IO[bytes], size: int) -> list[datetime]:
|
||||
def read_data(cls: type[Self], stream: IO[bytes], size: int) -> list[datetime]:
|
||||
return parse_datetime(read(stream, size))
|
||||
|
||||
@classmethod
|
||||
def pack_data(cls, data: Sequence[datetime]) -> bytes:
|
||||
def pack_data(cls: type[Self], data: Sequence[datetime]) -> bytes:
|
||||
return pack_datetime(data)
|
||||
|
@ -144,7 +144,7 @@ class REFLIBS(ASCIIRecord):
|
||||
tag = 0x1f06
|
||||
|
||||
@classmethod
|
||||
def check_size(cls, size: int):
|
||||
def check_size(cls: type[Self], size: int) -> None:
|
||||
if size != 0 and size % 44 != 0:
|
||||
raise Exception(f'Expected size to be multiple of 44, got {size}')
|
||||
|
||||
@ -153,7 +153,7 @@ class FONTS(ASCIIRecord):
|
||||
tag = 0x2006
|
||||
|
||||
@classmethod
|
||||
def check_size(cls, size: int):
|
||||
def check_size(cls: type[Self], size: int) -> None:
|
||||
if size != 0 and size % 44 != 0:
|
||||
raise Exception(f'Expected size to be multiple of 44, got {size}')
|
||||
|
||||
@ -168,7 +168,7 @@ class GENERATIONS(Int2Record):
|
||||
expected_size = 2
|
||||
|
||||
@classmethod
|
||||
def check_data(cls, data: Sequence[int]):
|
||||
def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
|
||||
if len(data) != 1:
|
||||
raise Exception(f'Expected exactly one integer, got {data}')
|
||||
|
||||
@ -177,7 +177,7 @@ class ATTRTABLE(ASCIIRecord):
|
||||
tag = 0x2306
|
||||
|
||||
@classmethod
|
||||
def check_size(cls, size: int):
|
||||
def check_size(cls: type[Self], size: int) -> None:
|
||||
if size > 44:
|
||||
raise Exception(f'Expected size <= 44, got {size}')
|
||||
|
||||
@ -266,7 +266,7 @@ class FORMAT(Int2Record):
|
||||
expected_size = 2
|
||||
|
||||
@classmethod
|
||||
def check_data(cls, data: Sequence[int]):
|
||||
def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None:
|
||||
if len(data) != 1:
|
||||
raise Exception(f'Expected exactly one integer, got {data}')
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .basic import decode_real8, encode_real8, parse_datetime
|
||||
from .basic import KlamathError
|
||||
|
||||
|
||||
def test_parse_bitarray():
|
||||
def test_parse_bitarray() -> None:
|
||||
assert parse_bitarray(b'59') == 13625
|
||||
assert parse_bitarray(b'\0\0') == 0
|
||||
assert parse_bitarray(b'\xff\xff') == 65535
|
||||
@ -26,7 +26,7 @@ def test_parse_bitarray():
|
||||
parse_bitarray(b'')
|
||||
|
||||
|
||||
def test_parse_int2():
|
||||
def test_parse_int2() -> None:
|
||||
assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0))
|
||||
|
||||
# odd length
|
||||
@ -38,7 +38,7 @@ def test_parse_int2():
|
||||
parse_int2(b'')
|
||||
|
||||
|
||||
def test_parse_int4():
|
||||
def test_parse_int4() -> None:
|
||||
assert_array_equal(parse_int4(b'4321'), (875770417,))
|
||||
|
||||
# length % 4 != 0
|
||||
@ -50,7 +50,7 @@ def test_parse_int4():
|
||||
parse_int4(b'')
|
||||
|
||||
|
||||
def test_decode_real8():
|
||||
def test_decode_real8() -> None:
|
||||
# zeroes
|
||||
assert decode_real8(numpy.array([0x0])) == 0
|
||||
assert decode_real8(numpy.array([1 << 63])) == 0 # negative
|
||||
@ -60,7 +60,7 @@ def test_decode_real8():
|
||||
assert decode_real8(numpy.array([0xC120 << 48])) == -2.0
|
||||
|
||||
|
||||
def test_parse_real8():
|
||||
def test_parse_real8() -> None:
|
||||
packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000)
|
||||
assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0))
|
||||
|
||||
@ -73,7 +73,7 @@ def test_parse_real8():
|
||||
parse_real8(b'')
|
||||
|
||||
|
||||
def test_parse_ascii():
|
||||
def test_parse_ascii() -> None:
|
||||
# # empty data Now allowed!
|
||||
# with pytest.raises(KlamathError):
|
||||
# parse_ascii(b'')
|
||||
@ -82,40 +82,40 @@ def test_parse_ascii():
|
||||
assert parse_ascii(b'12345\0') == b'12345' # strips trailing null byte
|
||||
|
||||
|
||||
def test_pack_bitarray():
|
||||
def test_pack_bitarray() -> None:
|
||||
packed = pack_bitarray(321)
|
||||
assert len(packed) == 2
|
||||
assert packed == struct.pack('>H', 321)
|
||||
|
||||
|
||||
def test_pack_int2():
|
||||
def test_pack_int2() -> None:
|
||||
packed = pack_int2((3, 2, 1))
|
||||
assert len(packed) == 3 * 2
|
||||
assert packed == struct.pack('>3h', 3, 2, 1)
|
||||
assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1)
|
||||
|
||||
|
||||
def test_pack_int4():
|
||||
def test_pack_int4() -> None:
|
||||
packed = pack_int4((3, 2, 1))
|
||||
assert len(packed) == 3 * 4
|
||||
assert packed == struct.pack('>3l', 3, 2, 1)
|
||||
assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1)
|
||||
|
||||
|
||||
def test_encode_real8():
|
||||
def test_encode_real8() -> None:
|
||||
assert encode_real8(numpy.array([0.0])) == 0
|
||||
arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12))
|
||||
assert_array_equal(decode_real8(encode_real8(arr)), arr)
|
||||
|
||||
|
||||
def test_pack_real8():
|
||||
def test_pack_real8() -> None:
|
||||
reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12)
|
||||
packed = pack_real8(reals)
|
||||
assert len(packed) == len(reals) * 8
|
||||
assert_array_equal(parse_real8(packed), reals)
|
||||
|
||||
|
||||
def test_pack_ascii():
|
||||
def test_pack_ascii() -> None:
|
||||
assert pack_ascii(b'4321') == b'4321'
|
||||
assert pack_ascii(b'321') == b'321\0'
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user