improve type annotations
This commit is contained in:
		
							parent
							
								
									e7e42a2ef8
								
							
						
					
					
						commit
						59c94f7c17
					
				| @ -1,7 +1,8 @@ | ||||
| """ | ||||
| Generic record-level read/write functionality. | ||||
| """ | ||||
| from typing import Sequence, IO, TypeVar, ClassVar, Type | ||||
| from typing import IO, ClassVar, Self, Generic, TypeVar | ||||
| from collections.abc import Sequence | ||||
| import struct | ||||
| import io | ||||
| from datetime import datetime | ||||
| @ -17,6 +18,8 @@ from .basic import parse_ascii, pack_ascii, read | ||||
| 
 | ||||
| 
 | ||||
| _RECORD_HEADER_FMT = struct.Struct('>HH') | ||||
| II = TypeVar('II')  # Input type | ||||
| OO = TypeVar('OO')  # Output type | ||||
| 
 | ||||
| 
 | ||||
| def write_record_header(stream: IO[bytes], data_size: int, tag: int) -> int: | ||||
| @ -53,30 +56,27 @@ def expect_record(stream: IO[bytes], tag: int) -> int: | ||||
|     return data_size | ||||
| 
 | ||||
| 
 | ||||
| R = TypeVar('R', bound='Record') | ||||
| 
 | ||||
| 
 | ||||
| class Record(metaclass=ABCMeta): | ||||
| class Record(Generic[II, OO], metaclass=ABCMeta): | ||||
|     tag: ClassVar[int] = -1 | ||||
|     expected_size: ClassVar[int | None] = None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_size(cls, size: int): | ||||
|     def check_size(cls: type[Self], size: int) -> None: | ||||
|         if cls.expected_size is not None and size != cls.expected_size: | ||||
|             raise KlamathError(f'Expected size {cls.expected_size}, got {size}') | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_data(cls, data): | ||||
|     def check_data(cls: type[Self], data: II) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     @abstractmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int): | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> OO: | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     @abstractmethod | ||||
|     def pack_data(cls, data) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: II) -> bytes: | ||||
|         pass | ||||
| 
 | ||||
|     @staticmethod | ||||
| @ -84,11 +84,11 @@ class Record(metaclass=ABCMeta): | ||||
|         return read_record_header(stream) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def write_header(cls, stream: IO[bytes], data_size: int) -> int: | ||||
|     def write_header(cls: type[Self], stream: IO[bytes], data_size: int) -> int: | ||||
|         return write_record_header(stream, data_size, cls.tag) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def skip_past(cls, stream: IO[bytes]) -> bool: | ||||
|     def skip_past(cls: type[Self], stream: IO[bytes]) -> bool: | ||||
|         """ | ||||
|         Skip to the end of the next occurence of this record. | ||||
| 
 | ||||
| @ -110,7 +110,7 @@ class Record(metaclass=ABCMeta): | ||||
|         return True | ||||
| 
 | ||||
|     @classmethod | ||||
|     def skip_and_read(cls, stream: IO[bytes]): | ||||
|     def skip_and_read(cls: type[Self], stream: IO[bytes]) -> OO: | ||||
|         size, tag = Record.read_header(stream) | ||||
|         while tag != cls.tag: | ||||
|             stream.seek(size, io.SEEK_CUR) | ||||
| @ -119,90 +119,90 @@ class Record(metaclass=ABCMeta): | ||||
|         return data | ||||
| 
 | ||||
|     @classmethod | ||||
|     def read(cls: Type[R], stream: IO[bytes]): | ||||
|     def read(cls: type[Self], stream: IO[bytes]) -> OO: | ||||
|         size = expect_record(stream, cls.tag) | ||||
|         data = cls.read_data(stream, size) | ||||
|         return data | ||||
| 
 | ||||
|     @classmethod | ||||
|     def write(cls, stream: IO[bytes], data) -> int: | ||||
|     def write(cls: type[Self], stream: IO[bytes], data: II) -> int: | ||||
|         data_bytes = cls.pack_data(data) | ||||
|         b = cls.write_header(stream, len(data_bytes)) | ||||
|         b += stream.write(data_bytes) | ||||
|         return b | ||||
| 
 | ||||
| 
 | ||||
| class NoDataRecord(Record): | ||||
| class NoDataRecord(Record[None, None]): | ||||
|     expected_size: ClassVar[int | None] = 0 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> None: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> None: | ||||
|         stream.read(size) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: None) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: None) -> bytes: | ||||
|         if data is not None: | ||||
|             raise KlamathError('?? Packing {data} into NoDataRecord??') | ||||
|         return b'' | ||||
| 
 | ||||
| 
 | ||||
| class BitArrayRecord(Record): | ||||
| class BitArrayRecord(Record[int, int]): | ||||
|     expected_size: ClassVar[int | None] = 2 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> int: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> int:        # noqa: ARG003  size unused | ||||
|         return parse_bitarray(read(stream, 2)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: int) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: int) -> bytes: | ||||
|         return pack_bitarray(data) | ||||
| 
 | ||||
| 
 | ||||
| class Int2Record(Record): | ||||
| class Int2Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int16]]): | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int16]: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int16]: | ||||
|         return parse_int2(read(stream, size)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: Sequence[int]) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: | ||||
|         return pack_int2(data) | ||||
| 
 | ||||
| 
 | ||||
| class Int4Record(Record): | ||||
| class Int4Record(Record[NDArray[numpy.integer] | Sequence[int] | int, NDArray[numpy.int32]]): | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.int32]: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.int32]: | ||||
|         return parse_int4(read(stream, size)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: Sequence[int]) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> bytes: | ||||
|         return pack_int4(data) | ||||
| 
 | ||||
| 
 | ||||
| class Real8Record(Record): | ||||
| class Real8Record(Record[Sequence[float] | float, NDArray[numpy.float64]]): | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> NDArray[numpy.float64]: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> NDArray[numpy.float64]: | ||||
|         return parse_real8(read(stream, size)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: Sequence[int]) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: Sequence[float] | float) -> bytes: | ||||
|         return pack_real8(data) | ||||
| 
 | ||||
| 
 | ||||
| class ASCIIRecord(Record): | ||||
| class ASCIIRecord(Record[bytes, bytes]): | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> bytes: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> bytes: | ||||
|         return parse_ascii(read(stream, size)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: bytes) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: bytes) -> bytes: | ||||
|         return pack_ascii(data) | ||||
| 
 | ||||
| 
 | ||||
| class DateTimeRecord(Record): | ||||
| class DateTimeRecord(Record[Sequence[datetime], list[datetime]]): | ||||
|     @classmethod | ||||
|     def read_data(cls, stream: IO[bytes], size: int) -> list[datetime]: | ||||
|     def read_data(cls: type[Self], stream: IO[bytes], size: int) -> list[datetime]: | ||||
|         return parse_datetime(read(stream, size)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pack_data(cls, data: Sequence[datetime]) -> bytes: | ||||
|     def pack_data(cls: type[Self], data: Sequence[datetime]) -> bytes: | ||||
|         return pack_datetime(data) | ||||
|  | ||||
| @ -144,7 +144,7 @@ class REFLIBS(ASCIIRecord): | ||||
|     tag = 0x1f06 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_size(cls, size: int): | ||||
|     def check_size(cls: type[Self], size: int) -> None: | ||||
|         if size != 0 and size % 44 != 0: | ||||
|             raise Exception(f'Expected size to be multiple of 44, got {size}') | ||||
| 
 | ||||
| @ -153,7 +153,7 @@ class FONTS(ASCIIRecord): | ||||
|     tag = 0x2006 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_size(cls, size: int): | ||||
|     def check_size(cls: type[Self], size: int) -> None: | ||||
|         if size != 0 and size % 44 != 0: | ||||
|             raise Exception(f'Expected size to be multiple of 44, got {size}') | ||||
| 
 | ||||
| @ -168,7 +168,7 @@ class GENERATIONS(Int2Record): | ||||
|     expected_size = 2 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_data(cls, data: Sequence[int]): | ||||
|     def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None: | ||||
|         if len(data) != 1: | ||||
|             raise Exception(f'Expected exactly one integer, got {data}') | ||||
| 
 | ||||
| @ -177,7 +177,7 @@ class ATTRTABLE(ASCIIRecord): | ||||
|     tag = 0x2306 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_size(cls, size: int): | ||||
|     def check_size(cls: type[Self], size: int) -> None: | ||||
|         if size > 44: | ||||
|             raise Exception(f'Expected size <= 44, got {size}') | ||||
| 
 | ||||
| @ -266,7 +266,7 @@ class FORMAT(Int2Record): | ||||
|     expected_size = 2 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def check_data(cls, data: Sequence[int]): | ||||
|     def check_data(cls: type[Self], data: NDArray[numpy.integer] | Sequence[int] | int) -> None: | ||||
|         if len(data) != 1: | ||||
|             raise Exception(f'Expected exactly one integer, got {data}') | ||||
| 
 | ||||
|  | ||||
| @ -12,7 +12,7 @@ from .basic import decode_real8, encode_real8, parse_datetime | ||||
| from .basic import KlamathError | ||||
| 
 | ||||
| 
 | ||||
| def test_parse_bitarray(): | ||||
| def test_parse_bitarray() -> None: | ||||
|     assert parse_bitarray(b'59') == 13625 | ||||
|     assert parse_bitarray(b'\0\0') == 0 | ||||
|     assert parse_bitarray(b'\xff\xff') == 65535 | ||||
| @ -26,7 +26,7 @@ def test_parse_bitarray(): | ||||
|         parse_bitarray(b'') | ||||
| 
 | ||||
| 
 | ||||
| def test_parse_int2(): | ||||
| def test_parse_int2() -> None: | ||||
|     assert_array_equal(parse_int2(b'59\xff\xff\0\0'), (13625, -1, 0)) | ||||
| 
 | ||||
|     # odd length | ||||
| @ -38,7 +38,7 @@ def test_parse_int2(): | ||||
|         parse_int2(b'') | ||||
| 
 | ||||
| 
 | ||||
| def test_parse_int4(): | ||||
| def test_parse_int4() -> None: | ||||
|     assert_array_equal(parse_int4(b'4321'), (875770417,)) | ||||
| 
 | ||||
|     # length % 4 != 0 | ||||
| @ -50,7 +50,7 @@ def test_parse_int4(): | ||||
|         parse_int4(b'') | ||||
| 
 | ||||
| 
 | ||||
| def test_decode_real8(): | ||||
| def test_decode_real8() -> None: | ||||
|     # zeroes | ||||
|     assert decode_real8(numpy.array([0x0])) == 0 | ||||
|     assert decode_real8(numpy.array([1 << 63])) == 0  # negative | ||||
| @ -60,7 +60,7 @@ def test_decode_real8(): | ||||
|     assert decode_real8(numpy.array([0xC120 << 48])) == -2.0 | ||||
| 
 | ||||
| 
 | ||||
| def test_parse_real8(): | ||||
| def test_parse_real8() -> None: | ||||
|     packed = struct.pack('>3Q', 0x0, 0x4110_0000_0000_0000, 0xC120_0000_0000_0000) | ||||
|     assert_array_equal(parse_real8(packed), (0.0, 1.0, -2.0)) | ||||
| 
 | ||||
| @ -73,7 +73,7 @@ def test_parse_real8(): | ||||
|         parse_real8(b'') | ||||
| 
 | ||||
| 
 | ||||
| def test_parse_ascii(): | ||||
| def test_parse_ascii() -> None: | ||||
|     # # empty data       Now allowed! | ||||
|     # with pytest.raises(KlamathError): | ||||
|     #     parse_ascii(b'') | ||||
| @ -82,40 +82,40 @@ def test_parse_ascii(): | ||||
|     assert parse_ascii(b'12345\0') == b'12345'  # strips trailing null byte | ||||
| 
 | ||||
| 
 | ||||
| def test_pack_bitarray(): | ||||
| def test_pack_bitarray() -> None: | ||||
|     packed = pack_bitarray(321) | ||||
|     assert len(packed) == 2 | ||||
|     assert packed == struct.pack('>H', 321) | ||||
| 
 | ||||
| 
 | ||||
| def test_pack_int2(): | ||||
| def test_pack_int2() -> None: | ||||
|     packed = pack_int2((3, 2, 1)) | ||||
|     assert len(packed) == 3 * 2 | ||||
|     assert packed == struct.pack('>3h', 3, 2, 1) | ||||
|     assert pack_int2([-3, 2, -1]) == struct.pack('>3h', -3, 2, -1) | ||||
| 
 | ||||
| 
 | ||||
| def test_pack_int4(): | ||||
| def test_pack_int4() -> None: | ||||
|     packed = pack_int4((3, 2, 1)) | ||||
|     assert len(packed) == 3 * 4 | ||||
|     assert packed == struct.pack('>3l', 3, 2, 1) | ||||
|     assert pack_int4([-3, 2, -1]) == struct.pack('>3l', -3, 2, -1) | ||||
| 
 | ||||
| 
 | ||||
| def test_encode_real8(): | ||||
| def test_encode_real8() -> None: | ||||
|     assert encode_real8(numpy.array([0.0])) == 0 | ||||
|     arr = numpy.array((1.0, -2.0, 1e-9, 1e-3, 1e-12)) | ||||
|     assert_array_equal(decode_real8(encode_real8(arr)), arr) | ||||
| 
 | ||||
| 
 | ||||
| def test_pack_real8(): | ||||
| def test_pack_real8() -> None: | ||||
|     reals = (0, 1, -1, 0.5, 1e-9, 1e-3, 1e-12) | ||||
|     packed = pack_real8(reals) | ||||
|     assert len(packed) == len(reals) * 8 | ||||
|     assert_array_equal(parse_real8(packed), reals) | ||||
| 
 | ||||
| 
 | ||||
| def test_pack_ascii(): | ||||
| def test_pack_ascii() -> None: | ||||
|     assert pack_ascii(b'4321') == b'4321' | ||||
|     assert pack_ascii(b'321') == b'321\0' | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user