flatten and simplify conditionals

This commit is contained in:
Jan Petykiewicz 2024-07-29 18:06:59 -07:00
parent 891007054f
commit dc9ed8e794
4 changed files with 93 additions and 113 deletions

View File

@ -950,7 +950,7 @@ class OctangularDelta:
sign = self.octangle & 0x02 > 0 sign = self.octangle & 0x02 > 0
xy[axis] = self.proj_mag * (1 - 2 * sign) xy[axis] = self.proj_mag * (1 - 2 * sign)
return xy return xy
else: else: # noqa: RET505
yn = (self.octangle & 0x02) > 0 yn = (self.octangle & 0x02) > 0
xyn = (self.octangle & 0x01) > 0 xyn = (self.octangle & 0x01) > 0
ys = 1 - 2 * yn ys = 1 - 2 * yn
@ -1097,10 +1097,9 @@ class Delta:
""" """
if self.x == 0 or self.y == 0 or abs(self.x) == abs(self.y): if self.x == 0 or self.y == 0 or abs(self.x) == abs(self.y):
return write_uint(stream, OctangularDelta(self.x, self.y).as_uint() << 1) return write_uint(stream, OctangularDelta(self.x, self.y).as_uint() << 1)
else: size = write_uint(stream, (encode_sint(self.x) << 1) | 0x01)
size = write_uint(stream, (encode_sint(self.x) << 1) | 0x01) size += write_uint(stream, encode_sint(self.y))
size += write_uint(stream, encode_sint(self.y)) return size
return size
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return hasattr(other, 'as_list') and self.as_list() == other.as_list() return hasattr(other, 'as_list') and self.as_list() == other.as_list()
@ -1125,12 +1124,11 @@ def read_repetition(stream: IO[bytes]) -> repetition_t:
rtype = read_uint(stream) rtype = read_uint(stream)
if rtype == 0: if rtype == 0:
return ReuseRepetition.read(stream, rtype) return ReuseRepetition.read(stream, rtype)
elif rtype in (1, 2, 3, 8, 9): if rtype in (1, 2, 3, 8, 9):
return GridRepetition.read(stream, rtype) return GridRepetition.read(stream, rtype)
elif rtype in (4, 5, 6, 7, 10, 11): if rtype in (4, 5, 6, 7, 10, 11):
return ArbitraryRepetition.read(stream, rtype) return ArbitraryRepetition.read(stream, rtype)
else: raise InvalidDataError(f'Unexpected repetition type: {rtype}')
raise InvalidDataError(f'Unexpected repetition type: {rtype}')
def write_repetition(stream: IO[bytes], repetition: repetition_t) -> int: def write_repetition(stream: IO[bytes], repetition: repetition_t) -> int:
@ -1311,7 +1309,7 @@ class GridRepetition:
size = write_uint(stream, 9) size = write_uint(stream, 9)
size += write_uint(stream, self.a_count - 2) size += write_uint(stream, self.a_count - 2)
size += Delta(*self.a_vector).write(stream) size += Delta(*self.a_vector).write(stream)
else: else: # noqa: PLR5501
if self.a_vector[1] == 0 and self.b_vector[0] == 0: if self.a_vector[1] == 0 and self.b_vector[0] == 0:
size = write_uint(stream, 1) size = write_uint(stream, 1)
size += write_uint(stream, self.a_count - 2) size += write_uint(stream, self.a_count - 2)
@ -1637,11 +1635,10 @@ def write_point_list(
h_first = False h_first = False
v_first = False v_first = False
break break
else: elif point[1] != previous[1] or point[0] == previous[0]:
if point[1] != previous[1] or point[0] == previous[0]: h_first = False
h_first = False v_first = False
v_first = False break
break
previous = point previous = point
# If one of h_first or v_first, write a bunch of 1-deltas # If one of h_first or v_first, write a bunch of 1-deltas
@ -1650,7 +1647,7 @@ def write_point_list(
size += write_uint(stream, len(points)) size += write_uint(stream, len(points))
size += sum(write_sint(stream, x + y) for x, y in points) size += sum(write_sint(stream, x + y) for x, y in points)
return size return size
elif v_first: if v_first:
size = write_uint(stream, 1) size = write_uint(stream, 1)
size += write_uint(stream, len(points)) size += write_uint(stream, len(points))
size += sum(write_sint(stream, x + y) for x, y in points) size += sum(write_sint(stream, x + y) for x, y in points)
@ -1773,30 +1770,29 @@ def read_property_value(stream: IO[bytes]) -> property_value_t:
prop_type = read_uint(stream) prop_type = read_uint(stream)
if 0 <= prop_type <= 7: if 0 <= prop_type <= 7:
return read_real(stream, prop_type) return read_real(stream, prop_type)
elif prop_type == 8: if prop_type == 8:
return read_uint(stream) return read_uint(stream)
elif prop_type == 9: if prop_type == 9:
return read_sint(stream) return read_sint(stream)
elif prop_type == 10: if prop_type == 10:
return AString.read(stream) return AString.read(stream)
elif prop_type == 11: if prop_type == 11:
return read_bstring(stream) return read_bstring(stream)
elif prop_type == 12: if prop_type == 12:
return NString.read(stream) return NString.read(stream)
elif prop_type == 13: if prop_type == 13:
ref_type = AString ref_type = AString
ref = read_uint(stream) ref = read_uint(stream)
return PropStringReference(ref, ref_type) return PropStringReference(ref, ref_type)
elif prop_type == 14: if prop_type == 14:
ref_type = bytes ref_type = bytes
ref = read_uint(stream) ref = read_uint(stream)
return PropStringReference(ref, ref_type) return PropStringReference(ref, ref_type)
elif prop_type == 15: if prop_type == 15:
ref_type = NString ref_type = NString
ref = read_uint(stream) ref = read_uint(stream)
return PropStringReference(ref, ref_type) return PropStringReference(ref, ref_type)
else: raise InvalidDataError(f'Invalid property type: {prop_type}')
raise InvalidDataError(f'Invalid property type: {prop_type}')
def write_property_value( def write_property_value(
@ -1883,17 +1879,16 @@ def read_interval(stream: IO[bytes]) -> tuple[int | None, int | None]:
interval_type = read_uint(stream) interval_type = read_uint(stream)
if interval_type == 0: if interval_type == 0:
return None, None return None, None
elif interval_type == 1: if interval_type == 1:
return None, read_uint(stream) return None, read_uint(stream)
elif interval_type == 2: if interval_type == 2:
return read_uint(stream), None return read_uint(stream), None
elif interval_type == 3: if interval_type == 3:
v = read_uint(stream) v = read_uint(stream)
return v, v return v, v
elif interval_type == 4: if interval_type == 4:
return read_uint(stream), read_uint(stream) return read_uint(stream), read_uint(stream)
else: raise InvalidDataError(f'Unrecognized interval type: {interval_type}')
raise InvalidDataError(f'Unrecognized interval type: {interval_type}')
def write_interval( def write_interval(
@ -1916,18 +1911,15 @@ def write_interval(
if min_bound is None: if min_bound is None:
if max_bound is None: if max_bound is None:
return write_uint(stream, 0) return write_uint(stream, 0)
else: return write_uint(stream, 1) + write_uint(stream, max_bound)
return write_uint(stream, 1) + write_uint(stream, max_bound) if max_bound is None:
else: return write_uint(stream, 2) + write_uint(stream, min_bound)
if max_bound is None: if min_bound == max_bound:
return write_uint(stream, 2) + write_uint(stream, min_bound) return write_uint(stream, 3) + write_uint(stream, min_bound)
elif min_bound == max_bound: size = write_uint(stream, 4)
return write_uint(stream, 3) + write_uint(stream, min_bound) size += write_uint(stream, min_bound)
else: size += write_uint(stream, max_bound)
size = write_uint(stream, 4) return size
size += write_uint(stream, min_bound)
size += write_uint(stream, max_bound)
return size
class OffsetEntry: class OffsetEntry:
@ -2190,9 +2182,7 @@ class Validation:
checksum_type = read_uint(stream) checksum_type = read_uint(stream)
if checksum_type == 0: if checksum_type == 0:
checksum = None checksum = None
elif checksum_type == 1: elif checksum_type in (1, 2):
checksum = read_u32(stream)
elif checksum_type == 2:
checksum = read_u32(stream) checksum = read_u32(stream)
else: else:
raise InvalidDataError('Invalid validation type!') raise InvalidDataError('Invalid validation type!')
@ -2214,14 +2204,13 @@ class Validation:
""" """
if self.checksum_type == 0: if self.checksum_type == 0:
return write_uint(stream, 0) return write_uint(stream, 0)
elif self.checksum is None: if self.checksum is None:
raise InvalidDataError(f'Checksum is empty but type is {self.checksum_type}') raise InvalidDataError(f'Checksum is empty but type is {self.checksum_type}')
elif self.checksum_type == 1: if self.checksum_type == 1:
return write_uint(stream, 1) + write_u32(stream, self.checksum) return write_uint(stream, 1) + write_u32(stream, self.checksum)
elif self.checksum_type == 2: if self.checksum_type == 2:
return write_uint(stream, 2) + write_u32(stream, self.checksum) return write_uint(stream, 2) + write_u32(stream, self.checksum)
else: raise InvalidDataError(f'Unrecognized checksum type: {self.checksum_type}')
raise InvalidDataError(f'Unrecognized checksum type: {self.checksum_type}')
def __repr__(self) -> str: def __repr__(self) -> str:
return f'Validation(type: {self.checksum_type} sum: {self.checksum})' return f'Validation(type: {self.checksum_type} sum: {self.checksum})'

View File

@ -193,8 +193,7 @@ class OasisLayout:
if record_id == 1: if record_id == 1:
if file_state.started: if file_state.started:
raise InvalidRecordError('Duplicate Start record') raise InvalidRecordError('Duplicate Start record')
else: file_state.started = True
file_state.started = True
if record_id == 2 and file_state.within_cblock: if record_id == 2 and file_state.within_cblock:
raise InvalidRecordError('End within CBlock') raise InvalidRecordError('End within CBlock')

View File

@ -274,10 +274,9 @@ def read_refname(
""" """
if not is_present: if not is_present:
return None return None
elif is_reference: if is_reference:
return read_uint(stream) return read_uint(stream)
else: return NString.read(stream)
return NString.read(stream)
def read_refstring( def read_refstring(
@ -299,10 +298,9 @@ def read_refstring(
""" """
if not is_present: if not is_present:
return None return None
elif is_reference: if is_reference:
return read_uint(stream) return read_uint(stream)
else: return AString.read(stream)
return AString.read(stream)
class Pad(Record): class Pad(Record):
@ -994,32 +992,33 @@ class Property(Record):
def write(self, stream: IO[bytes]) -> int: def write(self, stream: IO[bytes]) -> int:
if self.is_standard is None and self.values is None and self.name is None: if self.is_standard is None and self.values is None and self.name is None:
return write_uint(stream, 29) return write_uint(stream, 29)
if self.is_standard is None:
raise InvalidDataError('Property has value or name, but no is_standard flag!')
if self.values is not None:
value_count = len(self.values)
vv = 0
uu = 0x0f if value_count >= 0x0f else value_count
else: else:
if self.is_standard is None: vv = 1
raise InvalidDataError('Property has value or name, but no is_standard flag!') uu = 0
if self.values is not None:
value_count = len(self.values) cc = self.name is not None
vv = 0 nn = cc and isinstance(self.name, int)
uu = 0x0f if value_count >= 0x0f else value_count ss = self.is_standard
size = write_uint(stream, 28)
size += write_byte(stream, (uu << 4) | (vv << 3) | (cc << 2) | (nn << 1) | ss)
if cc:
if nn:
size += write_uint(stream, self.name) # type: ignore
else: else:
vv = 1 size += self.name.write(stream) # type: ignore
uu = 0 if not vv:
if uu == 0x0f:
cc = self.name is not None size += write_uint(stream, len(self.values)) # type: ignore
nn = cc and isinstance(self.name, int) size += sum(write_property_value(stream, pp) for pp in self.values) # type: ignore
ss = self.is_standard
size = write_uint(stream, 28)
size += write_byte(stream, (uu << 4) | (vv << 3) | (cc << 2) | (nn << 1) | ss)
if cc:
if nn:
size += write_uint(stream, self.name) # type: ignore
else:
size += self.name.write(stream) # type: ignore
if not vv:
if uu == 0x0f:
size += write_uint(stream, len(self.values)) # type: ignore
size += sum(write_property_value(stream, pp) for pp in self.values) # type: ignore
return size return size
@ -1736,9 +1735,8 @@ class Polygon(Record, GeometryMixin):
self.point_list = point_list self.point_list = point_list
self.properties = [] if properties is None else properties self.properties = [] if properties is None else properties
if point_list is not None: if point_list is not None and len(point_list) < 3:
if len(point_list) < 3: warn('Polygon with < 3 points', stacklevel=2)
warn('Polygon with < 3 points')
def get_point_list(self) -> point_list_t: def get_point_list(self) -> point_list_t:
return verify_modal(self.point_list) return verify_modal(self.point_list)
@ -1921,14 +1919,13 @@ class Path(Record, GeometryMixin):
def get_pathext(ext_scheme: int) -> pathextension_t | None: def get_pathext(ext_scheme: int) -> pathextension_t | None:
if ext_scheme == 0: if ext_scheme == 0:
return None return None
elif ext_scheme == 1: if ext_scheme == 1:
return PathExtensionScheme.Flush, None return PathExtensionScheme.Flush, None
elif ext_scheme == 2: if ext_scheme == 2:
return PathExtensionScheme.HalfWidth, None return PathExtensionScheme.HalfWidth, None
elif ext_scheme == 3: if ext_scheme == 3:
return PathExtensionScheme.Arbitrary, read_sint(stream) return PathExtensionScheme.Arbitrary, read_sint(stream)
else: raise InvalidDataError(f'Invalid ext_scheme: {ext_scheme}')
raise InvalidDataError(f'Invalid ext_scheme: {ext_scheme}')
optional['extension_start'] = get_pathext(scheme_start) optional['extension_start'] = get_pathext(scheme_start)
optional['extension_end'] = get_pathext(scheme_end) optional['extension_end'] = get_pathext(scheme_end)
@ -2066,9 +2063,8 @@ class Trapezoid(Record, GeometryMixin):
if self.is_vertical: if self.is_vertical:
if height is not None and delta_b - delta_a > height: if height is not None and delta_b - delta_a > height:
raise InvalidDataError(f'Trapezoid: h < delta_b - delta_a ({height} < {delta_b} - {delta_a})') raise InvalidDataError(f'Trapezoid: h < delta_b - delta_a ({height} < {delta_b} - {delta_a})')
else: elif width is not None and delta_b - delta_a > width:
if width is not None and delta_b - delta_a > width: raise InvalidDataError(f'Trapezoid: w < delta_b - delta_a ({width} < {delta_b} - {delta_a})')
raise InvalidDataError(f'Trapezoid: w < delta_b - delta_a ({width} < {delta_b} - {delta_a})')
def get_is_vertical(self) -> bool: def get_is_vertical(self) -> bool:
return verify_modal(self.is_vertical) return verify_modal(self.is_vertical)
@ -2392,7 +2388,7 @@ class CTrapezoid(Record, GeometryMixin):
raise InvalidDataError(f'CTrapezoid has spurious height entry: {height}') raise InvalidDataError(f'CTrapezoid has spurious height entry: {height}')
if width is not None and height is not None: if width is not None and height is not None:
if ctrapezoid_type in range(0, 4) and width < height: if ctrapezoid_type in range(0, 4) and width < height: # noqa: PIE808
raise InvalidDataError(f'CTrapezoid has width < height ({width} < {height})') raise InvalidDataError(f'CTrapezoid has width < height ({width} < {height})')
if ctrapezoid_type in range(4, 8) and width < 2 * height: if ctrapezoid_type in range(4, 8) and width < 2 * height:
raise InvalidDataError(f'CTrapezoid has width < 2*height ({width} < 2 * {height})') raise InvalidDataError(f'CTrapezoid has width < 2*height ({width} < 2 * {height})')
@ -2401,7 +2397,7 @@ class CTrapezoid(Record, GeometryMixin):
if ctrapezoid_type in range(12, 16) and 2 * width > height: if ctrapezoid_type in range(12, 16) and 2 * width > height:
raise InvalidDataError(f'CTrapezoid has 2*width > height ({width} > 2 * {height})') raise InvalidDataError(f'CTrapezoid has 2*width > height ({width} > 2 * {height})')
if ctrapezoid_type is not None and ctrapezoid_type not in range(0, 26): if ctrapezoid_type is not None and ctrapezoid_type not in range(0, 26): # noqa: PIE808
raise InvalidDataError(f'CTrapezoid has invalid type: {ctrapezoid_type}') raise InvalidDataError(f'CTrapezoid has invalid type: {ctrapezoid_type}')
@ -2532,8 +2528,7 @@ def adjust_repetition(record, modals: Modals) -> None:
if isinstance(record.repetition, ReuseRepetition): if isinstance(record.repetition, ReuseRepetition):
if modals.repetition is None: if modals.repetition is None:
raise InvalidDataError('Unfillable repetition') raise InvalidDataError('Unfillable repetition')
else: record.repetition = copy.copy(modals.repetition)
record.repetition = copy.copy(modals.repetition)
else: else:
modals.repetition = copy.copy(record.repetition) modals.repetition = copy.copy(record.repetition)
@ -2679,20 +2674,18 @@ def dedup_coordinates(record, modals: Modals, mx_field: str, my_field: str) -> N
if modals.xy_relative: if modals.xy_relative:
record.x -= mx record.x -= mx
setattr(modals, mx_field, record.x) setattr(modals, mx_field, record.x)
elif record.x == mx:
record.x = None
else: else:
if record.x == mx: setattr(modals, mx_field, record.x)
record.x = None
else:
setattr(modals, mx_field, record.x)
if record.y is not None: if record.y is not None:
my = getattr(modals, my_field) my = getattr(modals, my_field)
if modals.xy_relative: if modals.xy_relative:
record.y -= my record.y -= my
setattr(modals, my_field, record.y) setattr(modals, my_field, record.y)
elif record.y == my:
record.y = None
else: else:
if record.y == my: setattr(modals, my_field, record.y)
record.y = None
else:
setattr(modals, my_field, record.y)

View File

@ -135,13 +135,12 @@ def test_file_1() -> None:
assert gg.width == [250, None][is_ctrapz], msg assert gg.width == [250, None][is_ctrapz], msg
elif ct_type in range(22, 24) or ct_type == 25: elif ct_type in range(22, 24) or ct_type == 25:
assert gg.height == [100, None][is_ctrapz], msg assert gg.height == [100, None][is_ctrapz], msg
elif ct_type < 8 or 16 <= ct_type < 25 or ct_type >= 26:
assert gg.width == 250, msg
assert gg.height == 100, msg
else: else:
if ct_type < 8 or 16 <= ct_type < 25 or 26 <= ct_type: assert gg.width == 100, msg
assert gg.width == 250, msg assert gg.height == 250, msg
assert gg.height == 100, msg
else:
assert gg.width == 100, msg
assert gg.height == 250, msg
elif ii < 3 and ii % 2: elif ii < 3 and ii % 2:
assert gg.ctrapezoid_type == 24, msg assert gg.ctrapezoid_type == 24, msg
elif ii == 55: elif ii == 55: