| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | use crate::core::{Id, Point}; |
| | use std::io::{self, Read, Write, Cursor}; |
| |
|
| | |
/// Four-byte signature written at the very start of every serialized HAT
/// blob; decoding rejects any input that does not begin with these bytes.
const MAGIC: &[u8; 4] = b"HAT\0";

/// On-disk format version stored in the header right after [`MAGIC`].
/// Decoding accepts only this exact value.
const VERSION: u32 = 1;
| |
|
| | |
/// Errors produced while encoding or decoding the HAT binary format.
#[derive(Debug)]
pub enum PersistError {
    /// The input does not start with the expected magic bytes.
    InvalidMagic,
    /// The header carries a format version this code does not understand;
    /// the payload is the version number that was found.
    UnsupportedVersion(u32),
    /// An underlying I/O operation failed (for example, the input ended
    /// before a complete field could be read).
    Io(io::Error),
    /// The bytes were readable but do not describe a valid structure; the
    /// payload is a human-readable explanation.
    Corrupted(String),
    /// A vector's length disagrees with the expected dimensionality.
    DimensionMismatch { expected: usize, found: usize },
}

impl std::fmt::Display for PersistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PersistError::InvalidMagic => write!(f, "Invalid HAT file magic bytes"),
            PersistError::UnsupportedVersion(v) => write!(f, "Unsupported HAT version: {}", v),
            PersistError::Io(e) => write!(f, "IO error: {}", e),
            PersistError::Corrupted(msg) => write!(f, "Data corruption: {}", msg),
            PersistError::DimensionMismatch { expected, found } => {
                write!(f, "Dimension mismatch: expected {}, found {}", expected, found)
            }
        }
    }
}

impl std::error::Error for PersistError {
    /// Exposes the wrapped [`io::Error`] as the cause of an `Io` failure so
    /// that error-chain reporters can walk down to the root cause. All other
    /// variants are leaf errors and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            PersistError::Io(e) => Some(e),
            PersistError::InvalidMagic
            | PersistError::UnsupportedVersion(_)
            | PersistError::Corrupted(_)
            | PersistError::DimensionMismatch { .. } => None,
        }
    }
}

/// Lets `?` lift any `io::Error` into [`PersistError::Io`].
impl From<io::Error> for PersistError {
    fn from(e: io::Error) -> Self {
        PersistError::Io(e)
    }
}
| |
|
| | |
/// Single-byte tag identifying the level of a container in the serialized
/// form. The explicit discriminants are the values stored in the byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LevelByte {
    Root = 0,
    Session = 1,
    Document = 2,
    Chunk = 3,
}

impl LevelByte {
    /// Converts a raw tag byte back into a level.
    ///
    /// Returns `None` when `v` is not the discriminant of any variant.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Search by discriminant so the mapping is derived from the enum
        // definition itself rather than from a second hand-written table.
        [
            LevelByte::Root,
            LevelByte::Session,
            LevelByte::Document,
            LevelByte::Chunk,
        ]
        .iter()
        .copied()
        .find(|level| *level as u8 == v)
    }
}
| |
|
| | |
| | #[derive(Debug, Clone)] |
| | pub struct SerializedContainer { |
| | pub id: Id, |
| | pub level: LevelByte, |
| | pub timestamp: u64, |
| | pub children: Vec<Id>, |
| | pub descendant_count: u64, |
| | pub centroid: Vec<f32>, |
| | pub accumulated_sum: Option<Vec<f32>>, |
| | } |
| |
|
| | |
| | #[derive(Debug, Clone)] |
| | pub struct SerializedHat { |
| | pub version: u32, |
| | pub dimensionality: u32, |
| | pub root_id: Option<Id>, |
| | pub containers: Vec<SerializedContainer>, |
| | pub active_session: Option<Id>, |
| | pub active_document: Option<Id>, |
| | pub router_weights: Option<Vec<f32>>, |
| | } |
| |
|
| | impl SerializedHat { |
| | |
| | pub fn to_bytes(&self) -> Result<Vec<u8>, PersistError> { |
| | let mut buf = Vec::new(); |
| |
|
| | |
| | buf.write_all(MAGIC)?; |
| | buf.write_all(&self.version.to_le_bytes())?; |
| | buf.write_all(&self.dimensionality.to_le_bytes())?; |
| | buf.write_all(&(self.containers.len() as u64).to_le_bytes())?; |
| |
|
| | |
| | if let Some(id) = &self.root_id { |
| | buf.write_all(id.as_bytes())?; |
| | } else { |
| | buf.write_all(&[0u8; 16])?; |
| | } |
| |
|
| | |
| | for container in &self.containers { |
| | |
| | buf.write_all(container.id.as_bytes())?; |
| |
|
| | |
| | buf.write_all(&[container.level as u8])?; |
| |
|
| | |
| | buf.write_all(&container.timestamp.to_le_bytes())?; |
| |
|
| | |
| | buf.write_all(&(container.children.len() as u32).to_le_bytes())?; |
| | for child_id in &container.children { |
| | buf.write_all(child_id.as_bytes())?; |
| | } |
| |
|
| | |
| | buf.write_all(&container.descendant_count.to_le_bytes())?; |
| |
|
| | |
| | for &v in &container.centroid { |
| | buf.write_all(&v.to_le_bytes())?; |
| | } |
| |
|
| | |
| | if let Some(sum) = &container.accumulated_sum { |
| | buf.write_all(&[1u8])?; |
| | for &v in sum { |
| | buf.write_all(&v.to_le_bytes())?; |
| | } |
| | } else { |
| | buf.write_all(&[0u8])?; |
| | } |
| | } |
| |
|
| | |
| | if let Some(id) = &self.active_session { |
| | buf.write_all(id.as_bytes())?; |
| | } else { |
| | buf.write_all(&[0u8; 16])?; |
| | } |
| |
|
| | if let Some(id) = &self.active_document { |
| | buf.write_all(id.as_bytes())?; |
| | } else { |
| | buf.write_all(&[0u8; 16])?; |
| | } |
| |
|
| | |
| | if let Some(weights) = &self.router_weights { |
| | buf.write_all(&[1u8])?; |
| | for &w in weights { |
| | buf.write_all(&w.to_le_bytes())?; |
| | } |
| | } else { |
| | buf.write_all(&[0u8])?; |
| | } |
| |
|
| | Ok(buf) |
| | } |
| |
|
| | |
| | pub fn from_bytes(data: &[u8]) -> Result<Self, PersistError> { |
| | let mut cursor = Cursor::new(data); |
| |
|
| | |
| | let mut magic = [0u8; 4]; |
| | cursor.read_exact(&mut magic)?; |
| | if &magic != MAGIC { |
| | return Err(PersistError::InvalidMagic); |
| | } |
| |
|
| | let mut version_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut version_bytes)?; |
| | let version = u32::from_le_bytes(version_bytes); |
| | if version != VERSION { |
| | return Err(PersistError::UnsupportedVersion(version)); |
| | } |
| |
|
| | let mut dims_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut dims_bytes)?; |
| | let dimensionality = u32::from_le_bytes(dims_bytes); |
| |
|
| | let mut count_bytes = [0u8; 8]; |
| | cursor.read_exact(&mut count_bytes)?; |
| | let container_count = u64::from_le_bytes(count_bytes); |
| |
|
| | let mut root_bytes = [0u8; 16]; |
| | cursor.read_exact(&mut root_bytes)?; |
| | let root_id = if root_bytes == [0u8; 16] { |
| | None |
| | } else { |
| | Some(Id::from_bytes(root_bytes)) |
| | }; |
| |
|
| | |
| | let mut containers = Vec::with_capacity(container_count as usize); |
| | for _ in 0..container_count { |
| | |
| | let mut id_bytes = [0u8; 16]; |
| | cursor.read_exact(&mut id_bytes)?; |
| | let id = Id::from_bytes(id_bytes); |
| |
|
| | |
| | let mut level_byte = [0u8; 1]; |
| | cursor.read_exact(&mut level_byte)?; |
| | let level = LevelByte::from_u8(level_byte[0]) |
| | .ok_or_else(|| PersistError::Corrupted(format!("Invalid level: {}", level_byte[0])))?; |
| |
|
| | |
| | let mut ts_bytes = [0u8; 8]; |
| | cursor.read_exact(&mut ts_bytes)?; |
| | let timestamp = u64::from_le_bytes(ts_bytes); |
| |
|
| | |
| | let mut child_count_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut child_count_bytes)?; |
| | let child_count = u32::from_le_bytes(child_count_bytes) as usize; |
| |
|
| | let mut children = Vec::with_capacity(child_count); |
| | for _ in 0..child_count { |
| | let mut child_bytes = [0u8; 16]; |
| | cursor.read_exact(&mut child_bytes)?; |
| | children.push(Id::from_bytes(child_bytes)); |
| | } |
| |
|
| | |
| | let mut desc_bytes = [0u8; 8]; |
| | cursor.read_exact(&mut desc_bytes)?; |
| | let descendant_count = u64::from_le_bytes(desc_bytes); |
| |
|
| | |
| | let mut centroid = Vec::with_capacity(dimensionality as usize); |
| | for _ in 0..dimensionality { |
| | let mut v_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut v_bytes)?; |
| | centroid.push(f32::from_le_bytes(v_bytes)); |
| | } |
| |
|
| | |
| | let mut has_sum = [0u8; 1]; |
| | cursor.read_exact(&mut has_sum)?; |
| | let accumulated_sum = if has_sum[0] == 1 { |
| | let mut sum = Vec::with_capacity(dimensionality as usize); |
| | for _ in 0..dimensionality { |
| | let mut v_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut v_bytes)?; |
| | sum.push(f32::from_le_bytes(v_bytes)); |
| | } |
| | Some(sum) |
| | } else { |
| | None |
| | }; |
| |
|
| | containers.push(SerializedContainer { |
| | id, |
| | level, |
| | timestamp, |
| | children, |
| | descendant_count, |
| | centroid, |
| | accumulated_sum, |
| | }); |
| | } |
| |
|
| | |
| | let mut active_session_bytes = [0u8; 16]; |
| | cursor.read_exact(&mut active_session_bytes)?; |
| | let active_session = if active_session_bytes == [0u8; 16] { |
| | None |
| | } else { |
| | Some(Id::from_bytes(active_session_bytes)) |
| | }; |
| |
|
| | let mut active_document_bytes = [0u8; 16]; |
| | cursor.read_exact(&mut active_document_bytes)?; |
| | let active_document = if active_document_bytes == [0u8; 16] { |
| | None |
| | } else { |
| | Some(Id::from_bytes(active_document_bytes)) |
| | }; |
| |
|
| | |
| | let router_weights = if cursor.position() < data.len() as u64 { |
| | let mut has_weights = [0u8; 1]; |
| | cursor.read_exact(&mut has_weights)?; |
| | if has_weights[0] == 1 { |
| | let mut weights = Vec::with_capacity(dimensionality as usize); |
| | for _ in 0..dimensionality { |
| | let mut w_bytes = [0u8; 4]; |
| | cursor.read_exact(&mut w_bytes)?; |
| | weights.push(f32::from_le_bytes(w_bytes)); |
| | } |
| | Some(weights) |
| | } else { |
| | None |
| | } |
| | } else { |
| | None |
| | }; |
| |
|
| | Ok(SerializedHat { |
| | version, |
| | dimensionality, |
| | root_id, |
| | containers, |
| | active_session, |
| | active_document, |
| | router_weights, |
| | }) |
| | } |
| | } |
| |
|
| | |
| | fn id_to_bytes(id: &Option<Id>) -> [u8; 16] { |
| | match id { |
| | Some(id) => *id.as_bytes(), |
| | None => [0u8; 16], |
| | } |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a two-container structure, decodes it again, and checks that
    /// every field survives, not merely the element counts.
    #[test]
    fn test_serialized_hat_roundtrip() {
        let original = SerializedHat {
            version: VERSION,
            dimensionality: 128,
            root_id: Some(Id::now()),
            containers: vec![
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Root,
                    timestamp: 1234567890,
                    children: vec![Id::now(), Id::now()],
                    descendant_count: 10,
                    centroid: vec![0.1; 128],
                    accumulated_sum: None,
                },
                SerializedContainer {
                    id: Id::now(),
                    level: LevelByte::Chunk,
                    timestamp: 1234567891,
                    children: vec![],
                    descendant_count: 1,
                    centroid: vec![0.5; 128],
                    accumulated_sum: Some(vec![0.5; 128]),
                },
            ],
            active_session: Some(Id::now()),
            active_document: None,
            router_weights: Some(vec![1.0; 128]),
        };

        let bytes = original.to_bytes().unwrap();
        let restored = SerializedHat::from_bytes(&bytes).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.dimensionality, original.dimensionality);

        // Identifiers are compared through their raw bytes so the test does
        // not depend on which comparison traits `Id` implements.
        assert!(restored.root_id.is_some());
        assert_eq!(id_to_bytes(&restored.root_id), id_to_bytes(&original.root_id));
        assert!(restored.active_session.is_some());
        assert_eq!(
            id_to_bytes(&restored.active_session),
            id_to_bytes(&original.active_session)
        );
        assert!(restored.active_document.is_none());

        assert_eq!(restored.containers.len(), original.containers.len());
        for (got, want) in restored.containers.iter().zip(&original.containers) {
            assert_eq!(got.id.as_bytes(), want.id.as_bytes());
            assert_eq!(got.level, want.level);
            assert_eq!(got.timestamp, want.timestamp);
            assert_eq!(got.children.len(), want.children.len());
            for (got_child, want_child) in got.children.iter().zip(&want.children) {
                assert_eq!(got_child.as_bytes(), want_child.as_bytes());
            }
            assert_eq!(got.descendant_count, want.descendant_count);
            assert_eq!(got.centroid, want.centroid);
            assert_eq!(got.accumulated_sum, want.accumulated_sum);
        }

        assert_eq!(restored.router_weights, original.router_weights);
    }

    #[test]
    fn test_invalid_magic() {
        let bad_data = b"BAD\0rest of data...";
        let result = SerializedHat::from_bytes(bad_data);
        assert!(matches!(result, Err(PersistError::InvalidMagic)));
    }

    /// A well-formed magic followed by any version other than `VERSION`
    /// must be rejected with the offending version number.
    #[test]
    fn test_unsupported_version() {
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&(VERSION + 1).to_le_bytes());
        let result = SerializedHat::from_bytes(&data);
        assert!(matches!(
            result,
            Err(PersistError::UnsupportedVersion(v)) if v == VERSION + 1
        ));
    }

    /// Input that stops in the middle of the header surfaces as an I/O error.
    #[test]
    fn test_truncated_input() {
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&VERSION.to_le_bytes());
        let result = SerializedHat::from_bytes(&data);
        assert!(matches!(result, Err(PersistError::Io(_))));
    }

    /// A container whose level byte is not a known tag is reported as corruption.
    #[test]
    fn test_invalid_level_is_corruption() {
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&VERSION.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes()); // dimensionality
        data.extend_from_slice(&1u64.to_le_bytes()); // container count
        data.extend_from_slice(&[0u8; 16]); // root id (none)
        data.extend_from_slice(&[1u8; 16]); // container id
        data.push(9); // invalid level tag
        let result = SerializedHat::from_bytes(&data);
        assert!(matches!(result, Err(PersistError::Corrupted(_))));
    }

    #[test]
    fn test_level_byte_conversion() {
        assert_eq!(LevelByte::from_u8(0), Some(LevelByte::Root));
        assert_eq!(LevelByte::from_u8(1), Some(LevelByte::Session));
        assert_eq!(LevelByte::from_u8(2), Some(LevelByte::Document));
        assert_eq!(LevelByte::from_u8(3), Some(LevelByte::Chunk));
        assert_eq!(LevelByte::from_u8(4), None);
        assert_eq!(LevelByte::from_u8(u8::MAX), None);
    }
}
| |
|