From dc9ed8e79464760dd6f3f0e553b7f124aa71a634 Mon Sep 17 00:00:00 2001 From: Jan Petykiewicz Date: Mon, 29 Jul 2024 18:06:59 -0700 Subject: [PATCH] flatten and simplify conditionals --- fatamorgana/basic.py | 93 +++++++++----------- fatamorgana/main.py | 3 +- fatamorgana/records.py | 99 ++++++++++------------ fatamorgana/test/test_files_ctrapezoids.py | 11 ++- 4 files changed, 93 insertions(+), 113 deletions(-) diff --git a/fatamorgana/basic.py b/fatamorgana/basic.py index 6010d23..66ba118 100644 --- a/fatamorgana/basic.py +++ b/fatamorgana/basic.py @@ -950,7 +950,7 @@ class OctangularDelta: sign = self.octangle & 0x02 > 0 xy[axis] = self.proj_mag * (1 - 2 * sign) return xy - else: + else: # noqa: RET505 yn = (self.octangle & 0x02) > 0 xyn = (self.octangle & 0x01) > 0 ys = 1 - 2 * yn @@ -1097,10 +1097,9 @@ class Delta: """ if self.x == 0 or self.y == 0 or abs(self.x) == abs(self.y): return write_uint(stream, OctangularDelta(self.x, self.y).as_uint() << 1) - else: - size = write_uint(stream, (encode_sint(self.x) << 1) | 0x01) - size += write_uint(stream, encode_sint(self.y)) - return size + size = write_uint(stream, (encode_sint(self.x) << 1) | 0x01) + size += write_uint(stream, encode_sint(self.y)) + return size def __eq__(self, other: Any) -> bool: return hasattr(other, 'as_list') and self.as_list() == other.as_list() @@ -1125,12 +1124,11 @@ def read_repetition(stream: IO[bytes]) -> repetition_t: rtype = read_uint(stream) if rtype == 0: return ReuseRepetition.read(stream, rtype) - elif rtype in (1, 2, 3, 8, 9): + if rtype in (1, 2, 3, 8, 9): return GridRepetition.read(stream, rtype) - elif rtype in (4, 5, 6, 7, 10, 11): + if rtype in (4, 5, 6, 7, 10, 11): return ArbitraryRepetition.read(stream, rtype) - else: - raise InvalidDataError(f'Unexpected repetition type: {rtype}') + raise InvalidDataError(f'Unexpected repetition type: {rtype}') def write_repetition(stream: IO[bytes], repetition: repetition_t) -> int: @@ -1311,7 +1309,7 @@ class GridRepetition: size = write_uint(stream, 9) size += write_uint(stream, self.a_count - 2) size += Delta(*self.a_vector).write(stream) - else: + else: # noqa: PLR5501 if self.a_vector[1] == 0 and self.b_vector[0] == 0: size = write_uint(stream, 1) size += write_uint(stream, self.a_count - 2) @@ -1637,11 +1635,10 @@ def write_point_list( h_first = False v_first = False break - else: - if point[1] != previous[1] or point[0] == previous[0]: - h_first = False - v_first = False - break + elif point[1] != previous[1] or point[0] == previous[0]: + h_first = False + v_first = False + break previous = point # If one of h_first or v_first, write a bunch of 1-deltas @@ -1650,7 +1647,7 @@ def write_point_list( size += write_uint(stream, len(points)) size += sum(write_sint(stream, x + y) for x, y in points) return size - elif v_first: + if v_first: size = write_uint(stream, 1) size += write_uint(stream, len(points)) size += sum(write_sint(stream, x + y) for x, y in points) @@ -1773,30 +1770,29 @@ def read_property_value(stream: IO[bytes]) -> property_value_t: prop_type = read_uint(stream) if 0 <= prop_type <= 7: return read_real(stream, prop_type) - elif prop_type == 8: + if prop_type == 8: return read_uint(stream) - elif prop_type == 9: + if prop_type == 9: return read_sint(stream) - elif prop_type == 10: + if prop_type == 10: return AString.read(stream) - elif prop_type == 11: + if prop_type == 11: return read_bstring(stream) - elif prop_type == 12: + if prop_type == 12: return NString.read(stream) - elif prop_type == 13: + if prop_type == 13: ref_type = AString ref = read_uint(stream) return PropStringReference(ref, ref_type) - elif prop_type == 14: + if prop_type == 14: ref_type = bytes ref = read_uint(stream) return PropStringReference(ref, ref_type) - elif prop_type == 15: + if prop_type == 15: ref_type = NString ref = read_uint(stream) return PropStringReference(ref, ref_type) - else: - raise InvalidDataError(f'Invalid property type: {prop_type}') + raise InvalidDataError(f'Invalid property type: {prop_type}') def write_property_value( @@ -1883,17 +1879,16 @@ def read_interval(stream: IO[bytes]) -> tuple[int | None, int | None]: interval_type = read_uint(stream) if interval_type == 0: return None, None - elif interval_type == 1: + if interval_type == 1: return None, read_uint(stream) - elif interval_type == 2: + if interval_type == 2: return read_uint(stream), None - elif interval_type == 3: + if interval_type == 3: v = read_uint(stream) return v, v - elif interval_type == 4: + if interval_type == 4: return read_uint(stream), read_uint(stream) - else: - raise InvalidDataError(f'Unrecognized interval type: {interval_type}') + raise InvalidDataError(f'Unrecognized interval type: {interval_type}') def write_interval( @@ -1916,18 +1911,15 @@ def write_interval( if min_bound is None: if max_bound is None: return write_uint(stream, 0) - else: - return write_uint(stream, 1) + write_uint(stream, max_bound) - else: - if max_bound is None: - return write_uint(stream, 2) + write_uint(stream, min_bound) - elif min_bound == max_bound: - return write_uint(stream, 3) + write_uint(stream, min_bound) - else: - size = write_uint(stream, 4) - size += write_uint(stream, min_bound) - size += write_uint(stream, max_bound) - return size + return write_uint(stream, 1) + write_uint(stream, max_bound) + if max_bound is None: + return write_uint(stream, 2) + write_uint(stream, min_bound) + if min_bound == max_bound: + return write_uint(stream, 3) + write_uint(stream, min_bound) + size = write_uint(stream, 4) + size += write_uint(stream, min_bound) + size += write_uint(stream, max_bound) + return size class OffsetEntry: @@ -2190,9 +2182,7 @@ class Validation: checksum_type = read_uint(stream) if checksum_type == 0: checksum = None - elif checksum_type == 1: - checksum = read_u32(stream) - elif checksum_type == 2: + elif checksum_type in (1, 2): checksum = read_u32(stream) else: raise InvalidDataError('Invalid validation type!') @@ -2214,14 +2204,13 @@ class Validation: """ if self.checksum_type == 0: return write_uint(stream, 0) - elif self.checksum is None: + if self.checksum is None: raise InvalidDataError(f'Checksum is empty but type is {self.checksum_type}') - elif self.checksum_type == 1: + if self.checksum_type == 1: return write_uint(stream, 1) + write_u32(stream, self.checksum) - elif self.checksum_type == 2: + if self.checksum_type == 2: return write_uint(stream, 2) + write_u32(stream, self.checksum) - else: - raise InvalidDataError(f'Unrecognized checksum type: {self.checksum_type}') + raise InvalidDataError(f'Unrecognized checksum type: {self.checksum_type}') def __repr__(self) -> str: return f'Validation(type: {self.checksum_type} sum: {self.checksum})' diff --git a/fatamorgana/main.py b/fatamorgana/main.py index 978d842..317378c 100644 --- a/fatamorgana/main.py +++ b/fatamorgana/main.py @@ -193,8 +193,7 @@ class OasisLayout: if record_id == 1: if file_state.started: raise InvalidRecordError('Duplicate Start record') - else: - file_state.started = True + file_state.started = True if record_id == 2 and file_state.within_cblock: raise InvalidRecordError('End within CBlock') diff --git a/fatamorgana/records.py b/fatamorgana/records.py index 87e4da8..0de9739 100644 --- a/fatamorgana/records.py +++ b/fatamorgana/records.py @@ -274,10 +274,9 @@ def read_refname( """ if not is_present: return None - elif is_reference: + if is_reference: return read_uint(stream) - else: - return NString.read(stream) + return NString.read(stream) def read_refstring( @@ -299,10 +298,9 @@ def read_refstring( """ if not is_present: return None - elif is_reference: + if is_reference: return read_uint(stream) - else: - return AString.read(stream) + return AString.read(stream) class Pad(Record): @@ -994,32 +992,33 @@ class Property(Record): def write(self, stream: IO[bytes]) -> int: if self.is_standard is None and self.values is None and self.name is None: return write_uint(stream, 29) + + if self.is_standard is None: + raise InvalidDataError('Property has value or name, but no is_standard flag!') + + if self.values is not None: + value_count = len(self.values) + vv = 0 + uu = 0x0f if value_count >= 0x0f else value_count else: - if self.is_standard is None: - raise InvalidDataError('Property has value or name, but no is_standard flag!') - if self.values is not None: - value_count = len(self.values) - vv = 0 - uu = 0x0f if value_count >= 0x0f else value_count + vv = 1 + uu = 0 + + cc = self.name is not None + nn = cc and isinstance(self.name, int) + ss = self.is_standard + + size = write_uint(stream, 28) + size += write_byte(stream, (uu << 4) | (vv << 3) | (cc << 2) | (nn << 1) | ss) + if cc: + if nn: + size += write_uint(stream, self.name) # type: ignore else: - vv = 1 - uu = 0 - - cc = self.name is not None - nn = cc and isinstance(self.name, int) - ss = self.is_standard - - size = write_uint(stream, 28) - size += write_byte(stream, (uu << 4) | (vv << 3) | (cc << 2) | (nn << 1) | ss) - if cc: - if nn: - size += write_uint(stream, self.name) # type: ignore - else: - size += self.name.write(stream) # type: ignore - if not vv: - if uu == 0x0f: - size += write_uint(stream, len(self.values)) # type: ignore - size += sum(write_property_value(stream, pp) for pp in self.values) # type: ignore + size += self.name.write(stream) # type: ignore + if not vv: + if uu == 0x0f: + size += write_uint(stream, len(self.values)) # type: ignore + size += sum(write_property_value(stream, pp) for pp in self.values) # type: ignore return size @@ -1736,9 +1735,8 @@ class Polygon(Record, GeometryMixin): self.point_list = point_list self.properties = [] if properties is None else properties - if point_list is not None: - if len(point_list) < 3: - warn('Polygon with < 3 points') + if point_list is not None and len(point_list) < 3: + warn('Polygon with < 3 points', stacklevel=2) def get_point_list(self) -> point_list_t: return verify_modal(self.point_list) @@ -1921,14 +1919,13 @@ class Path(Record, GeometryMixin): def get_pathext(ext_scheme: int) -> pathextension_t | None: if ext_scheme == 0: return None - elif ext_scheme == 1: + if ext_scheme == 1: return PathExtensionScheme.Flush, None - elif ext_scheme == 2: + if ext_scheme == 2: return PathExtensionScheme.HalfWidth, None - elif ext_scheme == 3: + if ext_scheme == 3: return PathExtensionScheme.Arbitrary, read_sint(stream) - else: - raise InvalidDataError(f'Invalid ext_scheme: {ext_scheme}') + raise InvalidDataError(f'Invalid ext_scheme: {ext_scheme}') optional['extension_start'] = get_pathext(scheme_start) optional['extension_end'] = get_pathext(scheme_end) @@ -2066,9 +2063,8 @@ class Trapezoid(Record, GeometryMixin): if self.is_vertical: if height is not None and delta_b - delta_a > height: raise InvalidDataError(f'Trapezoid: h < delta_b - delta_a ({height} < {delta_b} - {delta_a})') - else: - if width is not None and delta_b - delta_a > width: - raise InvalidDataError(f'Trapezoid: w < delta_b - delta_a ({width} < {delta_b} - {delta_a})') + elif width is not None and delta_b - delta_a > width: + raise InvalidDataError(f'Trapezoid: w < delta_b - delta_a ({width} < {delta_b} - {delta_a})') def get_is_vertical(self) -> bool: return verify_modal(self.is_vertical) @@ -2392,7 +2388,7 @@ class CTrapezoid(Record, GeometryMixin): raise InvalidDataError(f'CTrapezoid has spurious height entry: {height}') if width is not None and height is not None: - if ctrapezoid_type in range(0, 4) and width < height: + if ctrapezoid_type in range(0, 4) and width < height: # noqa: PIE808 raise InvalidDataError(f'CTrapezoid has width < height ({width} < {height})') if ctrapezoid_type in range(4, 8) and width < 2 * height: raise InvalidDataError(f'CTrapezoid has width < 2*height ({width} < 2 * {height})') @@ -2401,7 +2397,7 @@ class CTrapezoid(Record, GeometryMixin): if ctrapezoid_type in range(12, 16) and 2 * width > height: raise InvalidDataError(f'CTrapezoid has 2*width > height ({width} > 2 * {height})') - if ctrapezoid_type is not None and ctrapezoid_type not in range(0, 26): + if ctrapezoid_type is not None and ctrapezoid_type not in range(0, 26): # noqa: PIE808 raise InvalidDataError(f'CTrapezoid has invalid type: {ctrapezoid_type}') @@ -2532,8 +2528,7 @@ def adjust_repetition(record, modals: Modals) -> None: if isinstance(record.repetition, ReuseRepetition): if modals.repetition is None: raise InvalidDataError('Unfillable repetition') - else: - record.repetition = copy.copy(modals.repetition) + record.repetition = copy.copy(modals.repetition) else: modals.repetition = copy.copy(record.repetition) @@ -2679,20 +2674,18 @@ def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> N if modals.xy_relative: record.x -= mx setattr(modals, mx_field, record.x) + elif record.x == mx: + record.x = None else: - if record.x == mx: - record.x = None - else: - setattr(modals, mx_field, record.x) + setattr(modals, mx_field, record.x) if record.y is not None: my = getattr(modals, my_field) if modals.xy_relative: record.y -= my setattr(modals, my_field, record.y) + elif record.y == my: + record.y = None else: - if record.y == my: - record.y = None - else: - setattr(modals, my_field, record.y) + setattr(modals, my_field, record.y) diff --git a/fatamorgana/test/test_files_ctrapezoids.py b/fatamorgana/test/test_files_ctrapezoids.py index c831146..d0429bf 100644 --- a/fatamorgana/test/test_files_ctrapezoids.py +++ b/fatamorgana/test/test_files_ctrapezoids.py @@ -135,13 +135,12 @@ def test_file_1() -> None: assert gg.width == [250, None][is_ctrapz], msg elif ct_type in range(22, 24) or ct_type == 25: assert gg.height == [100, None][is_ctrapz], msg + elif ct_type < 8 or 16 <= ct_type < 25 or ct_type >= 26: + assert gg.width == 250, msg + assert gg.height == 100, msg else: - if ct_type < 8 or 16 <= ct_type < 25 or 26 <= ct_type: - assert gg.width == 250, msg - assert gg.height == 100, msg - else: - assert gg.width == 100, msg - assert gg.height == 250, msg + assert gg.width == 100, msg + assert gg.height == 250, msg elif ii < 3 and ii % 2: assert gg.ctrapezoid_type == 24, msg elif ii == 55: