# inire/inire/geometry/dynamic_path_index.py
#
# 97 lines
# 3.4 KiB
# Python

from __future__ import annotations
from typing import TYPE_CHECKING
import numpy
import rtree
from shapely.strtree import STRtree
from inire.geometry.index_helpers import build_index_payload, iter_grid_cells
if TYPE_CHECKING:
from collections.abc import Sequence
from shapely.geometry import Polygon
from inire.geometry.collision import RoutingWorld
class DynamicPathIndex:
    """Spatial bookkeeping for the polygons of dynamically added paths.

    Every polygon added through :meth:`add_path` receives a fresh integer
    object id. The bounds of its dilated counterpart are inserted into an
    ``rtree`` index right away, whereas the ``STRtree`` (plus its companion
    arrays) and the cell grid are derived caches: they are dropped whenever
    objects are added or removed and rebuilt on demand by
    :meth:`ensure_tree` and :meth:`ensure_grid`.
    """

    __slots__ = (
        "engine",
        "index",
        "geometries",
        "dilated",
        "tree",
        "obj_ids",
        "grid",
        "id_counter",
        "net_ids_array",
        "bounds_array",
    )

    def __init__(self, engine: RoutingWorld) -> None:
        """Create an empty index bound to *engine*.

        The engine is consulted for ``metrics`` (may be ``None``) and for
        ``grid_cell_size``.
        """
        self.engine = engine
        # Incrementally maintained R-tree keyed by object id.
        self.index = rtree.index.Index()
        # obj_id -> (net_id, original polygon)
        self.geometries: dict[int, tuple[str, Polygon]] = {}
        # obj_id -> dilated polygon
        self.dilated: dict[int, Polygon] = {}
        # Lazily built query tree; ``None`` means "needs a rebuild".
        self.tree: STRtree | None = None
        # Arrays produced alongside ``tree`` by ``ensure_tree``.
        self.obj_ids: numpy.ndarray = numpy.array([], dtype=numpy.int32)
        # Lazily built mapping of grid cell -> object ids touching that cell.
        self.grid: dict[tuple[int, int], list[int]] = {}
        # Next object id to hand out; never reused.
        self.id_counter = 0
        self.net_ids_array = numpy.array([], dtype=object)
        self.bounds_array = numpy.array([], dtype=numpy.float64).reshape(0, 4)

    def invalidate_queries(self) -> None:
        """Drop the cached tree and grid so the next ``ensure_*`` call rebuilds them."""
        self.tree = None
        # Rebind instead of clearing so a previously handed-out dict is not mutated.
        self.grid = {}

    def ensure_tree(self) -> None:
        """Build the ``STRtree`` and its companion arrays if they are missing.

        Nothing happens when a tree already exists or when no dilated polygons
        are stored; in the latter case ``obj_ids``, ``bounds_array`` and
        ``net_ids_array`` keep whatever values they held before.
        """
        if self.tree is not None or not self.dilated:
            return

        metrics = self.engine.metrics
        if metrics is not None:
            metrics.total_dynamic_tree_rebuilds += 1

        ids, geometries, bounds_array = build_index_payload(self.dilated)
        self.tree = STRtree(geometries)
        self.obj_ids = numpy.array(ids, dtype=numpy.int32)
        self.bounds_array = bounds_array
        # tolist() yields native ints, avoiding one numpy scalar per dict lookup.
        net_ids = [self.geometries[obj_id][0] for obj_id in self.obj_ids.tolist()]
        self.net_ids_array = numpy.array(net_ids, dtype=object)

    def ensure_grid(self) -> None:
        """Populate the cell grid from the dilated polygons if it is empty.

        Returns immediately when the grid already has entries or when there
        are no dilated polygons to index.
        """
        if self.grid or not self.dilated:
            return

        metrics = self.engine.metrics
        if metrics is not None:
            metrics.total_dynamic_grid_rebuilds += 1

        cell_size = self.engine.grid_cell_size
        for obj_id, polygon in self.dilated.items():
            for cell in iter_grid_cells(polygon.bounds, cell_size):
                self.grid.setdefault(cell, []).append(obj_id)

    def add_path(self, net_id: str, geometry: Sequence[Polygon], dilated_geometry: Sequence[Polygon]) -> None:
        """Register the polygons of one path under *net_id*.

        Args:
            net_id: Identifier stored with every polygon of the path.
            geometry: Original polygons of the path.
            dilated_geometry: Dilated polygons, matched to ``geometry`` by
                position; entries beyond ``len(geometry)`` are ignored.

        Raises:
            IndexError: If ``dilated_geometry`` is shorter than ``geometry``.
                The check runs before any state is touched, so a failed call
                leaves the index, caches and metrics unchanged.
        """
        count = len(geometry)
        if count == 0:
            # Nothing to add: keep the cached tree/grid instead of forcing a rebuild.
            return
        if len(dilated_geometry) < count:
            raise IndexError(
                f"dilated_geometry has {len(dilated_geometry)} entries but geometry has {count}"
            )

        self.invalidate_queries()
        metrics = self.engine.metrics
        if metrics is not None:
            metrics.total_dynamic_path_objects_added += count

        # zip stops at the end of ``geometry`` because the length check above
        # guarantees ``dilated_geometry`` is at least as long.
        for polygon, dilated in zip(geometry, dilated_geometry):
            obj_id = self.id_counter
            self.id_counter += 1
            self.geometries[obj_id] = (net_id, polygon)
            self.dilated[obj_id] = dilated
            self.index.insert(obj_id, dilated.bounds)

    def remove_path(self, net_id: str) -> None:
        """Remove every stored object whose net id equals *net_id*."""
        to_remove = [
            obj_id
            for obj_id, (existing_net_id, _) in self.geometries.items()
            if existing_net_id == net_id
        ]
        self.remove_obj_ids(to_remove)

    def remove_obj_ids(self, obj_ids: list[int]) -> None:
        """Remove the given object ids from the R-tree and both dictionaries.

        An empty list is a no-op and leaves the cached tree/grid intact.

        Raises:
            KeyError: If an id is not currently stored.
        """
        if not obj_ids:
            return

        self.invalidate_queries()
        metrics = self.engine.metrics
        if metrics is not None:
            metrics.total_dynamic_path_objects_removed += len(obj_ids)

        for obj_id in obj_ids:
            # The delete call is given the same bounds that were used on insert.
            self.index.delete(obj_id, self.dilated[obj_id].bounds)
            del self.geometries[obj_id]
            del self.dilated[obj_id]