| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | use pyo3::prelude::*; |
| | use pyo3::exceptions::{PyValueError, PyIOError}; |
| |
|
| | use crate::core::{Id, Point}; |
| | use crate::adapters::index::{HatIndex as RustHatIndex, HatConfig, ConsolidationConfig, Consolidate}; |
| | use crate::ports::Near; |
| |
|
| | |
| | #[pyclass(name = "SearchResult")] |
| | #[derive(Clone)] |
| | pub struct PySearchResult { |
| | |
| | #[pyo3(get)] |
| | pub id: String, |
| |
|
| | |
| | #[pyo3(get)] |
| | pub score: f32, |
| | } |
| |
|
| | #[pymethods] |
| | impl PySearchResult { |
| | fn __repr__(&self) -> String { |
| | format!("SearchResult(id='{}', score={:.4})", self.id, self.score) |
| | } |
| |
|
| | fn __str__(&self) -> String { |
| | format!("{}: {:.4}", self.id, self.score) |
| | } |
| | } |
| |
|
| | |
| | #[pyclass(name = "HatConfig")] |
| | #[derive(Clone)] |
| | pub struct PyHatConfig { |
| | inner: HatConfig, |
| | } |
| |
|
| | #[pymethods] |
| | impl PyHatConfig { |
| | #[new] |
| | fn new() -> Self { |
| | Self { inner: HatConfig::default() } |
| | } |
| |
|
| | |
| | fn with_beam_width(mut slf: PyRefMut<'_, Self>, width: usize) -> PyRefMut<'_, Self> { |
| | slf.inner.beam_width = width; |
| | slf |
| | } |
| |
|
| | |
| | fn with_temporal_weight(mut slf: PyRefMut<'_, Self>, weight: f32) -> PyRefMut<'_, Self> { |
| | slf.inner.temporal_weight = weight; |
| | slf |
| | } |
| |
|
| | |
| | fn with_propagation_threshold(mut slf: PyRefMut<'_, Self>, threshold: f32) -> PyRefMut<'_, Self> { |
| | slf.inner.propagation_threshold = threshold; |
| | slf |
| | } |
| |
|
| | fn __repr__(&self) -> String { |
| | format!( |
| | "HatConfig(beam_width={}, temporal_weight={:.2}, propagation_threshold={:.3})", |
| | self.inner.beam_width, self.inner.temporal_weight, self.inner.propagation_threshold |
| | ) |
| | } |
| | } |
| |
|
| | |
| | #[pyclass(name = "SessionSummary")] |
| | #[derive(Clone)] |
| | pub struct PySessionSummary { |
| | #[pyo3(get)] |
| | pub id: String, |
| |
|
| | #[pyo3(get)] |
| | pub score: f32, |
| |
|
| | #[pyo3(get)] |
| | pub chunk_count: usize, |
| |
|
| | #[pyo3(get)] |
| | pub timestamp_ms: u64, |
| | } |
| |
|
| | #[pymethods] |
| | impl PySessionSummary { |
| | fn __repr__(&self) -> String { |
| | format!( |
| | "SessionSummary(id='{}', score={:.4}, chunks={})", |
| | self.id, self.score, self.chunk_count |
| | ) |
| | } |
| | } |
| |
|
| | |
| | #[pyclass(name = "DocumentSummary")] |
| | #[derive(Clone)] |
| | pub struct PyDocumentSummary { |
| | #[pyo3(get)] |
| | pub id: String, |
| |
|
| | #[pyo3(get)] |
| | pub score: f32, |
| |
|
| | #[pyo3(get)] |
| | pub chunk_count: usize, |
| | } |
| |
|
| | #[pymethods] |
| | impl PyDocumentSummary { |
| | fn __repr__(&self) -> String { |
| | format!( |
| | "DocumentSummary(id='{}', score={:.4}, chunks={})", |
| | self.id, self.score, self.chunk_count |
| | ) |
| | } |
| | } |
| |
|
| | |
| | #[pyclass(name = "HatStats")] |
| | #[derive(Clone)] |
| | pub struct PyHatStats { |
| | #[pyo3(get)] |
| | pub global_count: usize, |
| |
|
| | #[pyo3(get)] |
| | pub session_count: usize, |
| |
|
| | #[pyo3(get)] |
| | pub document_count: usize, |
| |
|
| | #[pyo3(get)] |
| | pub chunk_count: usize, |
| | } |
| |
|
| | #[pymethods] |
| | impl PyHatStats { |
| | |
| | #[getter] |
| | fn total_points(&self) -> usize { |
| | self.chunk_count |
| | } |
| |
|
| | fn __repr__(&self) -> String { |
| | format!( |
| | "HatStats(points={}, sessions={}, documents={}, chunks={})", |
| | self.chunk_count, self.session_count, self.document_count, self.chunk_count |
| | ) |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | #[pyclass(name = "HatIndex")] |
| | pub struct PyHatIndex { |
| | inner: RustHatIndex, |
| | } |
| |
|
| | #[pymethods] |
| | impl PyHatIndex { |
| | |
| | |
| | |
| | |
| | #[staticmethod] |
| | fn cosine(dimensionality: usize) -> Self { |
| | Self { |
| | inner: RustHatIndex::cosine(dimensionality), |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | #[staticmethod] |
| | fn with_config(dimensionality: usize, config: &PyHatConfig) -> Self { |
| | Self { |
| | inner: RustHatIndex::cosine(dimensionality).with_config(config.inner.clone()), |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | fn add(&mut self, embedding: Vec<f32>) -> PyResult<String> { |
| | let point = Point::new(embedding); |
| | let id = Id::now(); |
| |
|
| | self.inner.add(id, &point) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(format!("{}", id)) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | fn add_with_id(&mut self, id_hex: &str, embedding: Vec<f32>) -> PyResult<()> { |
| | let id = parse_id_hex(id_hex)?; |
| | let point = Point::new(embedding); |
| |
|
| | self.inner.add(id, &point) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | fn near(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> { |
| | let point = Point::new(query); |
| |
|
| | let results = self.inner.near(&point, k) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(results.into_iter().map(|r| PySearchResult { |
| | id: format!("{}", r.id), |
| | score: r.score, |
| | }).collect()) |
| | } |
| |
|
| | |
| | |
| | |
| | fn new_session(&mut self) { |
| | self.inner.new_session(); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | fn new_document(&mut self) { |
| | self.inner.new_document(); |
| | } |
| |
|
| | |
| | fn stats(&self) -> PyHatStats { |
| | let s = self.inner.stats(); |
| | PyHatStats { |
| | global_count: s.global_count, |
| | session_count: s.session_count, |
| | document_count: s.document_count, |
| | chunk_count: s.chunk_count, |
| | } |
| | } |
| |
|
| | |
| | fn __len__(&self) -> usize { |
| | self.inner.len() |
| | } |
| |
|
| | |
| | fn is_empty(&self) -> bool { |
| | self.inner.is_empty() |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | fn remove(&mut self, id_hex: &str) -> PyResult<()> { |
| | let id = parse_id_hex(id_hex)?; |
| |
|
| | self.inner.remove(id) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | fn near_sessions(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySessionSummary>> { |
| | let point = Point::new(query); |
| |
|
| | let results = self.inner.near_sessions(&point, k) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(results.into_iter().map(|s| PySessionSummary { |
| | id: format!("{}", s.id), |
| | score: s.score, |
| | chunk_count: s.chunk_count, |
| | timestamp_ms: s.timestamp, |
| | }).collect()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | fn near_documents(&self, session_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PyDocumentSummary>> { |
| | let sid = parse_id_hex(session_id)?; |
| | let point = Point::new(query); |
| |
|
| | let results = self.inner.near_documents(sid, &point, k) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(results.into_iter().map(|d| PyDocumentSummary { |
| | id: format!("{}", d.id), |
| | score: d.score, |
| | chunk_count: d.chunk_count, |
| | }).collect()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | fn near_in_document(&self, doc_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> { |
| | let did = parse_id_hex(doc_id)?; |
| | let point = Point::new(query); |
| |
|
| | let results = self.inner.near_in_document(did, &point, k) |
| | .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(results.into_iter().map(|r| PySearchResult { |
| | id: format!("{}", r.id), |
| | score: r.score, |
| | }).collect()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | fn consolidate(&mut self) { |
| | self.inner.consolidate(ConsolidationConfig::light()); |
| | } |
| |
|
| | |
| | fn consolidate_full(&mut self) { |
| | self.inner.consolidate(ConsolidationConfig::full()); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | fn save(&self, path: &str) -> PyResult<()> { |
| | self.inner.save_to_file(std::path::Path::new(path)) |
| | .map_err(|e| PyIOError::new_err(format!("{}", e))) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | #[staticmethod] |
| | fn load(path: &str) -> PyResult<Self> { |
| | let inner = RustHatIndex::load_from_file(std::path::Path::new(path)) |
| | .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(Self { inner }) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> { |
| | let data = self.inner.to_bytes() |
| | .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
| | Ok(pyo3::types::PyBytes::new_bound(py, &data)) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | #[staticmethod] |
| | fn from_bytes(data: &[u8]) -> PyResult<Self> { |
| | let inner = RustHatIndex::from_bytes(data) |
| | .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
| |
|
| | Ok(Self { inner }) |
| | } |
| |
|
| | fn __repr__(&self) -> String { |
| | let stats = self.inner.stats(); |
| | format!( |
| | "HatIndex(points={}, sessions={})", |
| | stats.chunk_count, stats.session_count |
| | ) |
| | } |
| | } |
| |
|
| | |
| | fn parse_id_hex(hex: &str) -> PyResult<Id> { |
| | if hex.len() != 32 { |
| | return Err(PyValueError::new_err( |
| | format!("ID must be 32 hex characters, got {}", hex.len()) |
| | )); |
| | } |
| |
|
| | let mut bytes = [0u8; 16]; |
| | for (i, chunk) in hex.as_bytes().chunks(2).enumerate() { |
| | let high = hex_char_to_nibble(chunk[0])?; |
| | let low = hex_char_to_nibble(chunk[1])?; |
| | bytes[i] = (high << 4) | low; |
| | } |
| |
|
| | Ok(Id::from_bytes(bytes)) |
| | } |
| |
|
| | fn hex_char_to_nibble(c: u8) -> PyResult<u8> { |
| | match c { |
| | b'0'..=b'9' => Ok(c - b'0'), |
| | b'a'..=b'f' => Ok(c - b'a' + 10), |
| | b'A'..=b'F' => Ok(c - b'A' + 10), |
| | _ => Err(PyValueError::new_err(format!("Invalid hex character: {}", c as char))), |
| | } |
| | } |
| |
|
| | |
| | #[pymodule] |
| | fn arms_hat(m: &Bound<'_, PyModule>) -> PyResult<()> { |
| | m.add_class::<PyHatIndex>()?; |
| | m.add_class::<PyHatConfig>()?; |
| | m.add_class::<PySearchResult>()?; |
| | m.add_class::<PySessionSummary>()?; |
| | m.add_class::<PyDocumentSummary>()?; |
| | m.add_class::<PyHatStats>()?; |
| |
|
| | |
| | m.add("__doc__", "ARMS-HAT: Hierarchical Attention Tree for AI memory retrieval")?; |
| | m.add("__version__", env!("CARGO_PKG_VERSION"))?; |
| |
|
| | Ok(()) |
| | } |
| |
|