rstar-tree compiles, but has some missing impls

This commit is contained in:
joe 2026-01-20 22:39:46 -08:00
parent 228c6ff976
commit d8f1ffd801
2 changed files with 169 additions and 550 deletions

View file

@ -1,8 +1,9 @@
use std::cmp::Ordering;
use bevy::prelude::*;
use bevy::{math::bounding::Aabb2d, prelude::*};
use ordered_float::OrderedFloat;
use spart::{geometry::BoundingVolume, rstar_tree::RStarTreeObject};
const POINT_RADIUS: f32 = f32::EPSILON * 16.0;
#[derive(Debug, Clone)]
pub struct Point {
@ -29,35 +30,8 @@ impl PartialOrd for Point {
}
}
impl RStarTreeObject for Point {
type B = PointBox;
fn mbr(&self) -> Self::B {
todo!()
}
}
#[derive(Debug, Clone)]
pub struct PointBox;
impl BoundingVolume for PointBox {
fn area(&self) -> f64 {
todo!()
}
fn union(&self, other: &Self) -> Self {
todo!()
}
fn intersects(&self, other: &Self) -> bool {
todo!()
}
fn overlap(&self, other: &Self) -> f64 {
todo!()
}
fn margin(&self) -> f64 {
todo!()
impl Point {
pub fn mbr(&self) -> Aabb2d {
Aabb2d::new(self.point, Vec2::splat(POINT_RADIUS))
}
}

View file

@ -1,11 +1,13 @@
use crate::geom::Point;
use bevy::math::bounding::Aabb2d;
use bevy::{
math::bounding::{Aabb2d, BoundingVolume, IntersectsVolume},
prelude::*,
};
use ordered_float::OrderedFloat;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::{cmp::Ordering, collections::BinaryHeap};
// Epsilon value for zero-sizes bounding boxes/cubes.
const EPSILON: f64 = 1e-10;
const EPSILON: f32 = 1e-10;
/// An entry in the R*tree, which can be either a leaf or a node.
#[derive(Debug, Clone)]
@ -52,27 +54,24 @@ impl Entry {
_ => None,
}
}
fn child(&self) -> Option<&Box> {
fn child(&self) -> Option<&TreeNode> {
match self {
Entry::Node { child, .. } => Some(child),
_ => None,
}
}
fn child_mut(&mut self) -> Option<&mut <Self as crate::rtree_common::EntryAccess>::Node> {
fn child_mut(&mut self) -> Option<&mut TreeNode> {
match self {
Entry::Node { child, .. } => Some(child),
_ => None,
}
}
fn set_mbr(&mut self, new_mbr: Self::BV) {
fn set_mbr(&mut self, new_mbr: Aabb2d) {
if let Entry::Node { mbr, .. } = self {
*mbr = new_mbr;
}
}
fn into_child(self) -> Option<Box<<Self as crate::rtree_common::EntryAccess>::Node>>
where
Self: Sized,
{
fn into_child(self) -> Option<TreeNode> {
match self {
Entry::Node { child, .. } => Some(child),
_ => None,
@ -80,44 +79,55 @@ impl Entry {
}
}
impl<T: RStarTreeObject> crate::rtree_common::NodeAccess for TreeNode<T> {
type Entry = Entry<T>;
impl TreeNode {
fn is_leaf(&self) -> bool {
self.is_leaf
}
fn entries(&self) -> &Vec<Self::Entry> {
fn entries(&self) -> &[Entry] {
&self.entries
}
fn entries_mut(&mut self) -> &mut Vec<Self::Entry> {
fn entries_mut(&mut self) -> &mut [Entry] {
&mut self.entries
}
fn range_search_bbox(&self, bbox: &Aabb2d) -> Vec<&Point> {
let mut result = Vec::new();
if self.is_leaf() {
for entry in self.entries() {
if let Some(obj) = entry.as_leaf_obj() {
if entry.mbr().intersects(bbox) {
result.push(obj);
}
}
}
} else {
for entry in self.entries() {
if let Some(child) = entry.child() {
if entry.mbr().intersects(bbox) {
result.extend_from_slice(&child.range_search_bbox(bbox));
}
}
}
}
result
}
fn mbr(&self) -> Option<Aabb2d> {
entries_mbr(self.entries())
}
}
impl<T: RStarTreeObject> RStarTree<T> {
/// Creates a new R*tree with the specified maximum number of entries per node.
///
/// # Arguments
///
/// * `max_entries` - The maximum number of entries allowed in a node.
///
/// # Errors
///
/// Returns `SpartError::InvalidCapacity` if `max_entries` is less than 2.
pub fn new(max_entries: usize) -> Result<Self, SpartError> {
if max_entries < 2 {
return Err(SpartError::InvalidCapacity {
capacity: max_entries,
});
}
info!("Creating new RStarTree with max_entries: {}", max_entries);
Ok(RStarTree {
impl RStarTree {
pub fn new(max_entries: usize) -> Self {
let max_entries = max_entries.max(2);
RStarTree {
root: TreeNode {
entries: Vec::new(),
is_leaf: true,
},
max_entries,
min_entries: (max_entries as f64 * 0.4).ceil() as usize,
})
min_entries: (max_entries as f32 * 0.4).ceil() as usize,
}
}
/// Inserts an object into the R*tree.
@ -125,12 +135,7 @@ impl<T: RStarTreeObject> RStarTree<T> {
/// # Arguments
///
/// * `object` - The object to insert.
pub fn insert(&mut self, object: T)
where
T: Clone,
T::B: BSPBounds,
{
info!("Inserting object into RStarTree: {:?}", object);
pub fn insert(&mut self, object: Point) {
let entry = Entry::Leaf {
mbr: object.mbr(),
object,
@ -138,14 +143,8 @@ impl<T: RStarTreeObject> RStarTree<T> {
self.insert_entry(entry, None);
}
fn insert_entry(&mut self, entry: Entry<T>, reinsert_from_level: Option<usize>)
where
T: Clone,
T::B: BSPBounds,
{
fn insert_entry(&mut self, entry: Entry, mut reinsert_level: Option<usize>) {
let mut to_insert = vec![(entry, 0)];
let mut reinsert_level = reinsert_from_level;
while let Some((item, level)) = to_insert.pop() {
let overflow = insert_recursive(
&mut self.root,
@ -160,6 +159,14 @@ impl<T: RStarTreeObject> RStarTree<T> {
if reinsert_level == Some(overflow_level) {
let old_entries = overflowed_node;
let (group1, group2) = split_entries(old_entries, self.max_entries);
let mut mbr1 = group1[0].mbr().clone();
for entry in group1.iter() {
mbr1 = mbr1.merge(entry.mbr());
}
let mut mbr2 = group2[0].mbr().clone();
for entry in group2.iter() {
mbr2 = mbr2.merge(entry.mbr());
}
let child1 = TreeNode {
entries: group1,
is_leaf: self.root.is_leaf,
@ -168,19 +175,16 @@ impl<T: RStarTreeObject> RStarTree<T> {
entries: group2,
is_leaf: self.root.is_leaf,
};
let mbr1 = common_compute_group_mbr(&child1.entries)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 = common_compute_group_mbr(&child2.entries)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
self.root.is_leaf = false;
self.root.entries.clear();
self.root.entries.push(Entry::Node {
mbr: mbr1,
child: Box::new(child1),
child: child1,
});
self.root.entries.push(Entry::Node {
mbr: mbr2,
child: Box::new(child2),
child: child2,
});
} else {
if reinsert_level.is_none() {
@ -209,11 +213,8 @@ impl<T: RStarTreeObject> RStarTree<T> {
/// # Returns
///
/// A vector of references to the objects whose minimum bounding volumes intersect the query.
pub fn range_search_bbox(&self, query: &T::B) -> Vec<&T> {
info!("Performing range search with query: {:?}", query);
let mut result = Vec::new();
common_search_node(&self.root, query, &mut result);
result
pub fn range_search_bbox(&self, query_bbox: &Aabb2d) -> Vec<&Point> {
self.root.range_search_bbox(query_bbox)
}
/// Inserts a bulk of objects into the R*-tree.
@ -221,20 +222,16 @@ impl<T: RStarTreeObject> RStarTree<T> {
/// # Arguments
///
/// * `objects` - The objects to insert.
pub fn insert_bulk(&mut self, objects: Vec<T>)
where
T: Clone,
T::B: BSPBounds,
{
pub fn insert_bulk(&mut self, objects: Vec<Point>) {
if objects.is_empty() {
return;
}
let mut entries: Vec<Entry<T>> = objects
let mut entries: Vec<Entry> = objects
.into_iter()
.map(|obj| Entry::Leaf {
mbr: obj.mbr(),
object: obj,
.map(|point| Entry::Leaf {
mbr: point.mbr(),
object: point,
})
.collect();
@ -243,15 +240,12 @@ impl<T: RStarTreeObject> RStarTree<T> {
let chunks = entries.chunks(self.max_entries);
for chunk in chunks {
let child_node = TreeNode {
let child = TreeNode {
entries: chunk.to_vec(),
is_leaf: self.root.is_leaf,
};
if let Some(mbr) = common_compute_group_mbr(&child_node.entries) {
new_level_entries.push(Entry::Node {
mbr,
child: Box::new(child_node),
});
if let Some(mbr) = child.mbr() {
new_level_entries.push(Entry::Node { mbr, child });
}
}
entries = new_level_entries;
@ -277,7 +271,7 @@ impl<T: RStarTreeObject> RStarTree<T> {
}
}
fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> usize {
fn choose_subtree(node: &TreeNode, entry: &Entry) -> usize {
let children_are_leaves = if let Some(Entry::Node { child, .. }) = node.entries.first() {
child.is_leaf
} else {
@ -296,23 +290,23 @@ fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> u
.entries
.iter()
.filter(|e| !std::ptr::eq(*e, a))
.map(|e| e.mbr().union(entry.mbr()).overlap(e.mbr()))
.sum::<f64>();
.map(|e| e.mbr().merge(entry.mbr()).overlap(e.mbr()))
.sum::<f32>();
let overlap_b = node
.entries
.iter()
.filter(|e| !std::ptr::eq(*e, b))
.map(|e| e.mbr().union(entry.mbr()).overlap(e.mbr()))
.sum::<f64>();
.map(|e| e.mbr().merge(entry.mbr()).overlap(e.mbr()))
.sum::<f32>();
let overlap_cmp = overlap_a.partial_cmp(&overlap_b).unwrap_or(Ordering::Equal);
if overlap_cmp != Ordering::Equal {
return overlap_cmp;
}
let enlargement_a = mbr_a.enlargement(entry.mbr());
let enlargement_b = mbr_b.enlargement(entry.mbr());
let enlargement_a = mbr_a.merge(entry.mbr()).visible_area() - mbr_a.visible_area();
let enlargement_b = mbr_b.merge(entry.mbr()).visible_area() - mbr_b.visible_area();
let enlargement_cmp = enlargement_a
.partial_cmp(&enlargement_b)
.unwrap_or(Ordering::Equal);
@ -321,8 +315,8 @@ fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> u
}
mbr_a
.area()
.partial_cmp(&mbr_b.area())
.visible_area()
.partial_cmp(&mbr_b.visible_area())
.unwrap_or(Ordering::Equal)
})
.map(|(i, _)| i)
@ -335,8 +329,8 @@ fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> u
let mbr_a = a.mbr();
let mbr_b = b.mbr();
let enlargement_a = mbr_a.enlargement(entry.mbr());
let enlargement_b = mbr_b.enlargement(entry.mbr());
let enlargement_a = mbr_a.merge(entry.mbr()).visible_area() - mbr_a.visible_area();
let enlargement_b = mbr_b.merge(entry.mbr()).visible_area() - mbr_b.visible_area();
let enlargement_cmp = enlargement_a
.partial_cmp(&enlargement_b)
@ -345,8 +339,8 @@ fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> u
return enlargement_cmp;
}
mbr_a
.area()
.partial_cmp(&mbr_b.area())
.visible_area()
.partial_cmp(&mbr_b.visible_area())
.unwrap_or(Ordering::Equal)
})
.map(|(i, _)| i)
@ -354,17 +348,14 @@ fn choose_subtree<T: RStarTreeObject>(node: &TreeNode<T>, entry: &Entry<T>) -> u
}
}
fn insert_recursive<T: RStarTreeObject + Clone>(
node: &mut TreeNode<T>,
entry: Entry<T>,
fn insert_recursive(
node: &mut TreeNode,
entry: Entry,
max_entries: usize,
level: usize,
reinsert_level: &mut Option<usize>,
to_insert_queue: &mut Vec<(Entry<T>, usize)>,
) -> Option<(Vec<Entry<T>>, usize)>
where
T::B: BSPBounds,
{
to_insert_queue: &mut Vec<(Entry, usize)>,
) -> Option<(Vec<Entry>, usize)> {
if node.is_leaf {
node.entries.push(entry);
} else {
@ -385,6 +376,11 @@ where
) {
if reinsert_level.is_some() && *reinsert_level == Some(overflow_level) {
let (g1, g2) = split_entries(overflow, max_entries);
let mbr1 = entries_mbr(&g1)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 = entries_mbr(&g2)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let child1 = TreeNode {
entries: g1,
is_leaf: child.is_leaf,
@ -393,17 +389,13 @@ where
entries: g2,
is_leaf: child.is_leaf,
};
let mbr1 = common_compute_group_mbr(&child1.entries)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 = common_compute_group_mbr(&child2.entries)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
node.entries[best_index] = Entry::Node {
mbr: mbr1,
child: Box::new(child1),
child: child1,
};
node.entries.push(Entry::Node {
mbr: mbr2,
child: Box::new(child2),
child: child2,
});
} else {
if reinsert_level.is_none() {
@ -422,18 +414,19 @@ where
}
}
}
if let Some(new_mbr) = common_compute_group_mbr(
if let Entry::Node { child, .. } = &node.entries[best_index] {
&child.entries
} else {
unreachable!()
},
) {
if let Entry::Node { mbr, .. } = &mut node.entries[best_index] {
let children = &mut node.entries_mut()[best_index];
let Entry::Node {
child: children,
mbr,
} = children
else {
return None;
};
if let Some(new_mbr) = entries_mbr(children.entries()) {
*mbr = new_mbr;
}
}
}
if node.entries.len() > max_entries {
return Some((std::mem::take(&mut node.entries), level));
@ -441,53 +434,29 @@ where
None
}
fn forced_reinsert<T: RStarTreeObject + Clone>(
node: &mut TreeNode<T>,
max_entries: usize,
) -> Vec<Entry<T>>
where
T::B: BSPBounds,
{
let node_mbr = if let Some(mbr) = common_compute_group_mbr(&node.entries) {
fn forced_reinsert(node: &mut TreeNode, max_entries: usize) -> Vec<Entry> {
let node_mbr = if let Some(mbr) = entries_mbr(&node.entries) {
mbr
} else {
return Vec::new();
};
let reinsert_count = (max_entries as f64 * 0.3).ceil() as usize;
let reinsert_count = (max_entries as f32 * 0.3).ceil() as usize;
node.entries.sort_by(|a, b| {
let center_a: Vec<f64> = (0..T::B::DIM)
.map(|d| {
a.mbr()
.center(d)
.unwrap_or_else(|_| unreachable!("dim valid"))
})
.collect();
let center_b: Vec<f64> = (0..T::B::DIM)
.map(|d| {
b.mbr()
.center(d)
.unwrap_or_else(|_| unreachable!("dim valid"))
})
.collect();
let node_center: Vec<f64> = (0..T::B::DIM)
.map(|d| {
node_mbr
.center(d)
.unwrap_or_else(|_| unreachable!("dim valid"))
})
.collect();
let center_a: Vec<f32> = (0..2).map(|d| a.mbr().center()[d]).collect();
let center_b: Vec<f32> = (0..2).map(|d| b.mbr().center()[d]).collect();
let node_center: Vec<f32> = (0..2).map(|d| node_mbr.center()[d]).collect();
let dist_a = center_a
.iter()
.zip(node_center.iter())
.map(|(ca, cb)| (ca - cb).powi(2))
.sum::<f64>();
.sum::<f32>();
let dist_b = center_b
.iter()
.zip(node_center.iter())
.map(|(ca, cb)| (ca - cb).powi(2))
.sum::<f64>();
.sum::<f32>();
dist_b.partial_cmp(&dist_a).unwrap_or(Ordering::Equal)
});
@ -495,38 +464,29 @@ where
node.entries.drain(0..reinsert_count).collect()
}
fn split_entries<T: RStarTreeObject + Clone>(
mut entries: Vec<Entry<T>>,
max_entries: usize,
) -> (Vec<Entry<T>>, Vec<Entry<T>>)
where
T::B: BSPBounds,
{
let min_entries = (max_entries as f64 * 0.4).ceil() as usize;
fn split_entries(mut entries: Vec<Entry>, max_entries: usize) -> (Vec<Entry>, Vec<Entry>) {
let min_entries = (max_entries as f32 * 0.4).ceil() as usize;
let mut best_axis = 0;
let mut best_split_index = 0;
let mut min_margin = f64::INFINITY;
let mut min_margin = f32::INFINITY;
for dim in 0..T::B::DIM {
for dim in 0..2 {
entries.sort_by(|a, b| {
let ca = a
.mbr()
.center(dim)
.unwrap_or_else(|_| unreachable!("dim valid"));
let cb = b
.mbr()
.center(dim)
.unwrap_or_else(|_| unreachable!("dim valid"));
let ca = a.mbr().center()[dim];
let cb = b.mbr().center()[dim];
ca.partial_cmp(&cb).unwrap_or(Ordering::Equal)
});
for k in min_entries..=entries.len() - min_entries {
let group1 = &entries[..k];
let group2 = &entries[k..];
let mbr1 = common_compute_group_mbr(group1)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 = common_compute_group_mbr(group2)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mut mbr1 = group1[0].mbr().clone();
for entry in group1 {
mbr1 = mbr1.merge(entry.mbr());
}
let mbr2 = group2[0].mbr().clone();
let margin = mbr1.margin() + mbr2.margin();
if margin < min_margin {
min_margin = margin;
@ -537,29 +497,23 @@ where
}
entries.sort_by(|a, b| {
let ca = a
.mbr()
.center(best_axis)
.unwrap_or_else(|_| unreachable!("dim valid"));
let cb = b
.mbr()
.center(best_axis)
.unwrap_or_else(|_| unreachable!("dim valid"));
let ca = a.mbr().center()[best_axis];
let cb = b.mbr().center()[best_axis];
ca.partial_cmp(&cb).unwrap_or(Ordering::Equal)
});
let mut best_overlap = f64::INFINITY;
let mut best_area = f64::INFINITY;
let mut best_overlap = f32::INFINITY;
let mut best_area = f32::INFINITY;
for k in min_entries..=entries.len() - min_entries {
let group1 = &entries[..k];
let group2 = &entries[k..];
let mbr1 = common_compute_group_mbr(group1)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 = common_compute_group_mbr(group2)
.unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr1 =
entries_mbr(group1).unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let mbr2 =
entries_mbr(group2).unwrap_or_else(|| unreachable!("non-empty group must have MBR"));
let overlap = mbr1.overlap(&mbr2);
let area = mbr1.area() + mbr2.area();
let area = mbr1.visible_area() + mbr2.visible_area();
if overlap < best_overlap {
best_overlap = overlap;
@ -575,334 +529,7 @@ where
(group1.to_vec(), group2.to_vec())
}
impl<T: RStarTreeObject> RStarTree<T>
where
T: PartialEq + Clone,
T::B: BSPBounds,
{
/// Deletes an object from the R*tree.
///
/// # Arguments
///
/// * `object` - The object to delete.
///
/// # Returns
///
/// `true` if at least one matching object was found and removed.
pub fn delete(&mut self, object: &T) -> bool {
info!("Attempting to delete object: {:?}", object);
let object_mbr = object.mbr();
let mut reinsert_list = Vec::new();
let deleted = common_delete_entry(
&mut self.root,
object,
&object_mbr,
self.min_entries,
&mut reinsert_list,
);
if deleted {
for entry in reinsert_list {
self.insert_entry(entry, None);
}
if !self.root.is_leaf && self.root.entries.len() == 1 {
if let Some(Entry::Node { child, .. }) = self.root.entries.pop() {
self.root = *child;
}
}
}
deleted
}
}
impl<T: std::fmt::Debug + Clone> RStarTreeObject for Point2D<T> {
type B = Rectangle;
fn mbr(&self) -> Self::B {
Rectangle {
x: self.x,
y: self.y,
width: EPSILON,
height: EPSILON,
}
}
}
impl<T: std::fmt::Debug + Clone> RStarTreeObject for Point3D<T> {
type B = Cube;
fn mbr(&self) -> Self::B {
Cube {
x: self.x,
y: self.y,
z: self.z,
width: EPSILON,
height: EPSILON,
depth: EPSILON,
}
}
}
impl<T: std::fmt::Debug + Clone> RStarTree<Point2D<T>> {
/// Performs a knearest neighbor search on an R*tree of 2D points.
///
/// # Arguments
///
/// * `query` - The 2D point to search near.
/// * `k` - The number of nearest neighbors to return.
///
/// # Returns
///
/// A vector of references to the k nearest 2D points.
///
/// # Note
///
/// The pruning logic for the search is based on Euclidean distance. Custom distance metrics
/// that are not compatible with Euclidean distance may lead to incorrect results or reduced
/// performance.
pub fn knn_search<M: DistanceMetric<Point2D<T>>>(
&self,
query: &Point2D<T>,
k: usize,
) -> Vec<&Point2D<T>> {
if k == 0 {
return Vec::new();
}
let mut heap: BinaryHeap<KnnCandidate<Entry<Point2D<T>>>> = BinaryHeap::new();
for entry in &self.root.entries {
let dist_sq = entry.mbr().min_distance(query).powi(2);
heap.push(KnnCandidate {
dist: dist_sq,
entry,
});
}
type OrdDist = OrderedFloat<f64>;
#[inline]
#[allow(non_snake_case)]
fn OrdDist(x: f64) -> OrderedFloat<f64> {
OrderedFloat(x)
}
struct HeapItem<'a, P> {
key: OrdDist,
idx: usize,
obj: &'a P,
}
impl<P> PartialEq for HeapItem<'_, P> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key && self.idx == other.idx
}
}
impl<P> Eq for HeapItem<'_, P> {}
impl<P> Ord for HeapItem<'_, P> {
fn cmp(&self, other: &Self) -> Ordering {
match self.key.cmp(&other.key) {
Ordering::Equal => self.idx.cmp(&other.idx),
ord => ord,
}
}
}
impl<P> PartialOrd for HeapItem<'_, P> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let mut results: BinaryHeap<HeapItem<Point2D<T>>> = BinaryHeap::new();
let mut counter: usize = 0;
while let Some(KnnCandidate { dist, entry }) = heap.pop() {
if results.len() >= k {
if let Some(worst_result) = results.peek() {
if dist > worst_result.key.0 {
break;
}
}
}
match entry {
Entry::Leaf { object, .. } => {
let d_sq = M::distance_sq(query, object);
if results.len() < k {
counter += 1;
results.push(HeapItem {
key: OrdDist(d_sq),
idx: counter,
obj: object,
});
} else if let Some(peek) = results.peek() {
if d_sq < peek.key.0 {
results.pop();
counter += 1;
results.push(HeapItem {
key: OrdDist(d_sq),
idx: counter,
obj: object,
});
}
}
}
Entry::Node { child, .. } => {
for child_entry in &child.entries {
let d_sq = child_entry.mbr().min_distance(query).powi(2);
if results.len() < k {
heap.push(KnnCandidate {
dist: d_sq,
entry: child_entry,
});
} else if let Some(peek) = results.peek() {
if d_sq < peek.key.0 {
heap.push(KnnCandidate {
dist: d_sq,
entry: child_entry,
});
}
}
}
}
}
}
let mut sorted_results = results.into_vec();
sorted_results.sort_by(|a, b| a.key.partial_cmp(&b.key).unwrap_or(Ordering::Equal));
sorted_results.into_iter().map(|r| r.obj).collect()
}
}
impl<T: std::fmt::Debug + Clone> RStarTree<Point3D<T>> {
/// Performs a knearest neighbor search on an R*tree of 3D points.
///
/// # Arguments
///
/// * `query` - The 3D point to search near.
/// * `k` - The number of nearest neighbors to return.
///
/// # Returns
///
/// A vector of references to the k nearest 3D points.
///
/// # Note
///
/// The pruning logic for the search is based on Euclidean distance. Custom distance metrics
/// that are not compatible with Euclidean distance may lead to incorrect results or reduced
/// performance.
pub fn knn_search<M: DistanceMetric<Point3D<T>>>(
&self,
query: &Point3D<T>,
k: usize,
) -> Vec<&Point3D<T>> {
if k == 0 {
return Vec::new();
}
let mut heap: BinaryHeap<KnnCandidate<Entry<Point3D<T>>>> = BinaryHeap::new();
for entry in &self.root.entries {
let dist_sq = entry.mbr().min_distance(query).powi(2);
heap.push(KnnCandidate {
dist: dist_sq,
entry,
});
}
type OrdDist = OrderedFloat<f64>;
#[inline]
#[allow(non_snake_case)]
fn OrdDist(x: f64) -> OrderedFloat<f64> {
OrderedFloat(x)
}
struct HeapItem<'a, P> {
key: OrdDist,
idx: usize,
obj: &'a P,
}
impl<P> PartialEq for HeapItem<'_, P> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key && self.idx == other.idx
}
}
impl<P> Eq for HeapItem<'_, P> {}
impl<P> Ord for HeapItem<'_, P> {
fn cmp(&self, other: &Self) -> Ordering {
match self.key.cmp(&other.key) {
Ordering::Equal => self.idx.cmp(&other.idx),
ord => ord,
}
}
}
impl<P> PartialOrd for HeapItem<'_, P> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let mut results: BinaryHeap<HeapItem<Point3D<T>>> = BinaryHeap::new();
let mut counter: usize = 0;
while let Some(KnnCandidate { dist, entry }) = heap.pop() {
if results.len() >= k {
if let Some(worst_result) = results.peek() {
if dist > worst_result.key.0 {
break;
}
}
}
match entry {
Entry::Leaf { object, .. } => {
let d_sq = M::distance_sq(query, object);
if results.len() < k {
counter += 1;
results.push(HeapItem {
key: OrdDist(d_sq),
idx: counter,
obj: object,
});
} else if let Some(peek) = results.peek() {
if d_sq < peek.key.0 {
results.pop();
counter += 1;
results.push(HeapItem {
key: OrdDist(d_sq),
idx: counter,
obj: object,
});
}
}
}
Entry::Node { child, .. } => {
for child_entry in &child.entries {
let d_sq = child_entry.mbr().min_distance(query).powi(2);
if results.len() < k {
heap.push(KnnCandidate {
dist: d_sq,
entry: child_entry,
});
} else if let Some(peek) = results.peek() {
if d_sq < peek.key.0 {
heap.push(KnnCandidate {
dist: d_sq,
entry: child_entry,
});
}
}
}
}
}
}
let mut sorted_results = results.into_vec();
sorted_results.sort_by(|a, b| a.key.partial_cmp(&b.key).unwrap_or(Ordering::Equal));
sorted_results.into_iter().map(|r| r.obj).collect()
}
}
impl<T> RStarTree<T>
where
T: RStarTreeObject + PartialEq + std::fmt::Debug,
T::B: BoundingVolumeFromPoint<T> + HasMinDistance<T> + Clone,
{
impl RStarTree {
/// Performs a range search on the R*tree using a query object and radius.
///
/// The query object is wrapped into a bounding volume using `from_point_radius`.
@ -915,18 +542,36 @@ where
/// # Returns
///
/// A vector of references to the objects within the given radius.
///
/// # Note
///
/// The pruning logic for the search is based on Euclidean distance. Custom distance metrics
/// that are not compatible with Euclidean distance may lead to incorrect results or reduced
/// performance.
pub fn range_search<M: DistanceMetric<T>>(&self, query: &T, radius: f64) -> Vec<&T> {
let query_volume = T::B::from_point_radius(query, radius);
let candidates = self.range_search_bbox(&query_volume);
pub fn range_search(&self, query_point: &Point, radius: f32) -> Vec<&Point> {
let query_bbox = Aabb2d::new(query_point.point, Vec2::splat(radius));
let r2 = radius.powi(2);
let candidates = self.range_search_bbox(&query_bbox);
candidates
.into_iter()
.filter(|object| M::distance_sq(query, object) <= radius * radius)
.filter(|other| query_point.point.distance_squared(other.point) <= r2)
.collect()
}
}
fn entries_mbr(entries: &[Entry]) -> Option<Aabb2d> {
let mut iter = entries.iter();
let first = iter.next()?.mbr().clone();
Some(iter.fold(first, |acc, entry| acc.merge(entry.mbr())))
}
trait Bvr: BoundingVolume {
fn margin(&self) -> f32;
fn overlap(&self, other: &Self) -> f32;
}
impl Bvr for Aabb2d {
fn margin(&self) -> f32 {
let Vec2 { x, y } = self.half_size();
2.0 * x * y
}
fn overlap(&self, other: &Aabb2d) -> f32 {
todo!()
}
}