# inire/inire/geometry/static_obstacle_index.py
#
# Listing metadata from the code viewer: 132 lines, 4.6 KiB, Python

from __future__ import annotations
from typing import TYPE_CHECKING
import numpy
import rtree
from shapely.strtree import STRtree
from inire.geometry.index_helpers import build_index_payload, is_axis_aligned_rect
if TYPE_CHECKING:
from shapely.geometry import Polygon
from inire.geometry.collision import RoutingWorld
class StaticObstacleIndex:
    """Index of static obstacles with lazily built, cached spatial trees.

    Every obstacle is stored twice: as the raw polygon passed to
    ``add_obstacle`` and as a dilated polygon.  The dilated bounds are
    inserted into an incremental ``rtree`` index right away, while the
    ``STRtree`` structures are only built on demand by the ``ensure_*``
    methods and are dropped again whenever an obstacle is added or removed.
    """

    __slots__ = (
        "engine",
        "index",
        "geometries",
        "dilated",
        "is_rect",
        "tree",
        "obj_ids",
        "bounds_array",
        "is_rect_array",
        "raw_tree",
        "raw_obj_ids",
        "net_specific_trees",
        "net_specific_is_rect",
        "net_specific_bounds",
        "id_counter",
        "version",
    )

    def __init__(self, engine: RoutingWorld) -> None:
        """Create an empty index.

        Args:
            engine: Owning routing world.  This class reads its
                ``clearance`` and ``metrics`` attributes.
        """
        self.engine = engine

        # Incremental R-tree over the bounds of the dilated obstacles.
        self.index = rtree.index.Index()

        # Per-obstacle records, all keyed by obstacle id.
        self.geometries: dict[int, Polygon] = {}
        self.dilated: dict[int, Polygon] = {}
        self.is_rect: dict[int, bool] = {}

        # Lazily built cache over the dilated obstacles (see ensure_tree).
        self.tree: STRtree | None = None
        self.obj_ids: list[int] = []
        self.bounds_array: numpy.ndarray | None = None
        self.is_rect_array: numpy.ndarray | None = None

        # Lazily built cache over the raw obstacles (see ensure_raw_tree).
        self.raw_tree: STRtree | None = None
        self.raw_obj_ids: list[int] = []

        # Per-net caches keyed by (rounded net width, rounded clearance).
        self.net_specific_trees: dict[tuple[float, float], STRtree] = {}
        self.net_specific_is_rect: dict[tuple[float, float], numpy.ndarray] = {}
        self.net_specific_bounds: dict[tuple[float, float], numpy.ndarray] = {}

        # Next id handed out by add_obstacle; ids are never reused.
        self.id_counter = 0
        # Incremented on every cache invalidation.
        self.version = 0

    def add_obstacle(self, polygon: Polygon, dilated_geometry: Polygon | None = None) -> int:
        """Register an obstacle and return its newly assigned id.

        Args:
            polygon: Raw obstacle outline.
            dilated_geometry: Pre-computed dilated outline.  When omitted,
                ``polygon`` is buffered by half of the engine clearance
                using mitre joins.

        Returns:
            The id under which the obstacle was stored.
        """
        obj_id = self.id_counter
        self.id_counter += 1

        if dilated_geometry is not None:
            dilated = dilated_geometry
        else:
            dilated = polygon.buffer(self.engine.clearance / 2.0, join_style="mitre")

        self.geometries[obj_id] = polygon
        self.dilated[obj_id] = dilated
        self.is_rect[obj_id] = is_axis_aligned_rect(dilated)
        self.index.insert(obj_id, dilated.bounds)

        self.invalidate_caches()
        return obj_id

    def remove_obstacle(self, obj_id: int) -> None:
        """Remove an obstacle; unknown ids are ignored without invalidating caches."""
        if obj_id not in self.geometries:
            return

        # The R-tree entry is deleted using the same bounds it was inserted with.
        bounds = self.dilated[obj_id].bounds
        self.index.delete(obj_id, bounds)

        del self.geometries[obj_id]
        del self.dilated[obj_id]
        del self.is_rect[obj_id]

        self.invalidate_caches()

    def invalidate_caches(self) -> None:
        """Drop every lazily built tree and array and bump ``version``."""
        self.tree = None
        self.bounds_array = None
        self.is_rect_array = None
        self.obj_ids = []

        self.raw_tree = None
        self.raw_obj_ids = []

        self.net_specific_trees.clear()
        self.net_specific_is_rect.clear()
        self.net_specific_bounds.clear()

        self.version += 1

    def ensure_tree(self) -> None:
        """Build the dilated-obstacle tree if it is missing.

        Does nothing when the tree already exists or when no obstacles are
        stored.  On a rebuild, ``obj_ids``, ``bounds_array`` and
        ``is_rect_array`` are refreshed together with ``tree``.
        """
        if self.tree is not None or not self.dilated:
            return

        if self.engine.metrics is not None:
            self.engine.metrics.total_static_tree_rebuilds += 1

        # NOTE(review): build_index_payload is assumed to return
        # (ids, geometries, bounds) in matching order; confirm in index_helpers.
        self.obj_ids, geometries, self.bounds_array = build_index_payload(self.dilated)
        self.tree = STRtree(geometries)
        # Explicit dtype keeps this array consistent with the per-net arrays.
        self.is_rect_array = numpy.array(
            [self.is_rect[obj_id] for obj_id in self.obj_ids],
            dtype=bool,
        )

    def ensure_net_tree(self, net_width: float) -> STRtree:
        """Return the tree of obstacles dilated for a net of ``net_width``.

        Results are cached under ``(round(net_width, 4),
        round(clearance, 4))``.  On a miss every raw obstacle is buffered by
        ``net_width / 2 + clearance`` with mitre joins, visiting obstacles in
        ascending id order, and the matching ``is_rect`` and bounds arrays
        are stored under the same key.

        Args:
            net_width: Width of the net being routed.

        Returns:
            The cached or freshly built ``STRtree``.
        """
        key = (round(net_width, 4), round(self.engine.clearance, 4))
        cached = self.net_specific_trees.get(key)
        if cached is not None:
            return cached

        if self.engine.metrics is not None:
            self.engine.metrics.total_static_net_tree_rebuilds += 1

        # NOTE(review): the full clearance is used here, whereas add_obstacle
        # dilates by half of it; confirm that the difference is intended.
        total_dilation = net_width / 2.0 + self.engine.clearance

        dilated_polygons = [
            self.geometries[obj_id].buffer(total_dilation, join_style="mitre")
            for obj_id in sorted(self.geometries)
        ]
        is_rect_array = numpy.array(
            [is_axis_aligned_rect(geom) for geom in dilated_polygons],
            dtype=bool,
        )
        # reshape keeps an (N, 4) layout even when there are no obstacles;
        # without it the empty case would produce a 1-D array of shape (0,).
        bounds_array = numpy.array(
            [geom.bounds for geom in dilated_polygons],
            dtype=numpy.float64,
        ).reshape(-1, 4)
        tree = STRtree(dilated_polygons)

        # Store all three together so the cache never holds a tree without
        # its companion arrays.
        self.net_specific_trees[key] = tree
        self.net_specific_is_rect[key] = is_rect_array
        self.net_specific_bounds[key] = bounds_array
        return tree

    def ensure_raw_tree(self) -> None:
        """Build the raw-obstacle tree if it is missing.

        Does nothing when the tree already exists or when no obstacles are
        stored.  ``raw_obj_ids`` is refreshed together with ``raw_tree``.
        """
        if self.raw_tree is not None or not self.geometries:
            return

        if self.engine.metrics is not None:
            self.engine.metrics.total_static_raw_tree_rebuilds += 1

        # The bounds returned by the helper are not needed for the raw tree.
        self.raw_obj_ids, geometries, _bounds_array = build_index_payload(self.geometries)
        self.raw_tree = STRtree(geometries)