| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | use std::collections::HashMap; |
| | use std::sync::Arc; |
| |
|
| | use crate::core::{Id, Point}; |
| | use crate::core::proximity::Proximity; |
| | use crate::ports::{Near, NearError, NearResult, SearchResult}; |
| |
|
| | |
| | pub struct FlatIndex { |
| | |
| | points: HashMap<Id, Point>, |
| |
|
| | |
| | dimensionality: usize, |
| |
|
| | |
| | proximity: Arc<dyn Proximity>, |
| |
|
| | |
| | |
| | higher_is_better: bool, |
| | } |
| |
|
| | impl FlatIndex { |
| | |
| | |
| | |
| | |
| | |
| | pub fn new( |
| | dimensionality: usize, |
| | proximity: Arc<dyn Proximity>, |
| | higher_is_better: bool, |
| | ) -> Self { |
| | Self { |
| | points: HashMap::new(), |
| | dimensionality, |
| | proximity, |
| | higher_is_better, |
| | } |
| | } |
| |
|
| | |
| | pub fn cosine(dimensionality: usize) -> Self { |
| | use crate::core::proximity::Cosine; |
| | Self::new(dimensionality, Arc::new(Cosine), true) |
| | } |
| |
|
| | |
| | pub fn euclidean(dimensionality: usize) -> Self { |
| | use crate::core::proximity::Euclidean; |
| | Self::new(dimensionality, Arc::new(Euclidean), false) |
| | } |
| |
|
| | |
| | fn sort_results(&self, results: &mut Vec<SearchResult>) { |
| | if self.higher_is_better { |
| | |
| | results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
| | } else { |
| | |
| | results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap()); |
| | } |
| | } |
| | } |
| |
|
| | impl Near for FlatIndex { |
| | fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
| | |
| | if query.dimensionality() != self.dimensionality { |
| | return Err(NearError::DimensionalityMismatch { |
| | expected: self.dimensionality, |
| | got: query.dimensionality(), |
| | }); |
| | } |
| |
|
| | |
| | let mut results: Vec<SearchResult> = self |
| | .points |
| | .iter() |
| | .map(|(id, point)| { |
| | let score = self.proximity.proximity(query, point); |
| | SearchResult::new(*id, score) |
| | }) |
| | .collect(); |
| |
|
| | |
| | self.sort_results(&mut results); |
| |
|
| | |
| | results.truncate(k); |
| |
|
| | Ok(results) |
| | } |
| |
|
| | fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> { |
| | |
| | if query.dimensionality() != self.dimensionality { |
| | return Err(NearError::DimensionalityMismatch { |
| | expected: self.dimensionality, |
| | got: query.dimensionality(), |
| | }); |
| | } |
| |
|
| | |
| | let mut results: Vec<SearchResult> = self |
| | .points |
| | .iter() |
| | .filter_map(|(id, point)| { |
| | let score = self.proximity.proximity(query, point); |
| | let within = if self.higher_is_better { |
| | score >= threshold |
| | } else { |
| | score <= threshold |
| | }; |
| | if within { |
| | Some(SearchResult::new(*id, score)) |
| | } else { |
| | None |
| | } |
| | }) |
| | .collect(); |
| |
|
| | |
| | self.sort_results(&mut results); |
| |
|
| | Ok(results) |
| | } |
| |
|
| | fn add(&mut self, id: Id, point: &Point) -> NearResult<()> { |
| | if point.dimensionality() != self.dimensionality { |
| | return Err(NearError::DimensionalityMismatch { |
| | expected: self.dimensionality, |
| | got: point.dimensionality(), |
| | }); |
| | } |
| |
|
| | self.points.insert(id, point.clone()); |
| | Ok(()) |
| | } |
| |
|
| | fn remove(&mut self, id: Id) -> NearResult<()> { |
| | self.points.remove(&id); |
| | Ok(()) |
| | } |
| |
|
| | fn rebuild(&mut self) -> NearResult<()> { |
| | |
| | Ok(()) |
| | } |
| |
|
| | fn is_ready(&self) -> bool { |
| | true |
| | } |
| |
|
| | fn len(&self) -> usize { |
| | self.points.len() |
| | } |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point score comparisons.
    const EPS: f32 = 1e-4;

    /// Builds a deterministic test id whose 16 bytes all equal `byte`.
    fn test_id(byte: u8) -> Id {
        Id::from_bytes([byte; 16])
    }

    /// Builds a 3-D cosine index holding the three unit axes (ids 1-3) and
    /// the normalized x/y diagonal (id 4).
    fn setup_index() -> FlatIndex {
        let mut index = FlatIndex::cosine(3);

        let points = vec![
            (test_id(1), Point::new(vec![1.0, 0.0, 0.0])),
            (test_id(2), Point::new(vec![0.0, 1.0, 0.0])),
            (test_id(3), Point::new(vec![0.0, 0.0, 1.0])),
            (test_id(4), Point::new(vec![0.7, 0.7, 0.0]).normalize()),
        ];

        for (id, point) in points {
            index.add(id, &point).unwrap();
        }

        index
    }

    #[test]
    fn test_flat_index_near() {
        let index = setup_index();

        // The x-axis query matches id 1 exactly; the x/y diagonal is the runner-up.
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, test_id(1));
        assert!((results[0].score - 1.0).abs() < EPS);
        assert_eq!(results[1].id, test_id(4));
        assert!(results[1].score < results[0].score);
    }

    #[test]
    fn test_flat_index_near_k_exceeds_len() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        let results = index.near(&query, 10).unwrap();

        // Asking for more than the index holds returns everything, best-first.
        assert_eq!(results.len(), 4);
        assert!(results.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_flat_index_near_k_zero() {
        let index = setup_index();
        let query = Point::new(vec![1.0, 0.0, 0.0]);

        assert!(index.near(&query, 0).unwrap().is_empty());
    }

    #[test]
    fn test_flat_index_within_cosine() {
        let index = setup_index();

        // Only the exact match and the diagonal reach a similarity of 0.5.
        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = index.within(&query, 0.5).unwrap();

        let ids: Vec<Id> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![test_id(1), test_id(4)]);
    }

    #[test]
    fn test_flat_index_euclidean() {
        let mut index = FlatIndex::euclidean(2);

        index.add(test_id(1), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(test_id(2), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(test_id(3), &Point::new(vec![5.0, 0.0])).unwrap();

        let query = Point::new(vec![0.0, 0.0]);
        let results = index.near(&query, 2).unwrap();

        // Distances rank ascending: the identical point first, then the nearest neighbour.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, test_id(1));
        assert!((results[0].score - 0.0).abs() < EPS);
        assert_eq!(results[1].id, test_id(2));
        assert!((results[1].score - 1.0).abs() < EPS);
    }

    #[test]
    fn test_flat_index_within_euclidean() {
        let mut index = FlatIndex::euclidean(2);

        index.add(test_id(1), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(test_id(2), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(test_id(3), &Point::new(vec![5.0, 0.0])).unwrap();

        // For a distance metric, `within` keeps scores at or below the threshold.
        let query = Point::new(vec![0.0, 0.0]);
        let results = index.within(&query, 2.0).unwrap();

        let ids: Vec<Id> = results.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![test_id(1), test_id(2)]);
    }

    #[test]
    fn test_flat_index_add_remove() {
        let mut index = FlatIndex::cosine(3);

        let id = test_id(1);
        let point = Point::new(vec![1.0, 0.0, 0.0]);

        index.add(id, &point).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);

        // Removing an id that is no longer present is not an error.
        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_dimensionality_check() {
        let mut index = FlatIndex::cosine(3);

        let wrong_dims = Point::new(vec![1.0, 0.0]);
        let result = index.add(Id::now(), &wrong_dims);

        match result {
            Err(NearError::DimensionalityMismatch { expected, got }) => {
                assert_eq!(expected, 3);
                assert_eq!(got, 2);
            }
            _ => panic!("Expected DimensionalityMismatch error"),
        }

        // A rejected point must not be stored.
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_query_dimensionality_check() {
        let index = setup_index();
        let wrong_dims = Point::new(vec![1.0, 0.0]);

        assert!(matches!(
            index.near(&wrong_dims, 1),
            Err(NearError::DimensionalityMismatch { expected: 3, got: 2 })
        ));
        assert!(matches!(
            index.within(&wrong_dims, 0.5),
            Err(NearError::DimensionalityMismatch { expected: 3, got: 2 })
        ));
    }

    #[test]
    fn test_flat_index_ready() {
        let index = FlatIndex::cosine(3);
        assert!(index.is_ready());
    }
}
| |
|