| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | use crate::core::Point; |
| | use std::collections::VecDeque; |
| |
|
| | |
/// Hyper-parameters controlling how the router learns routing weights
/// from recorded feedback.
#[derive(Debug, Clone)]
pub struct LearnableRoutingConfig {
    /// Step size for gradient updates. A value of `0.0` disables
    /// learning entirely (see `LearnableRoutingConfig::disabled`).
    pub learning_rate: f32,

    /// Momentum coefficient for the exponential moving average of
    /// gradients (clamped to [0.0, 0.99] by `with_momentum`).
    pub momentum: f32,

    /// Regularization strength pulling weights back toward their
    /// neutral value of 1.0 (not toward 0 — see `update_weights`).
    pub weight_decay: f32,

    /// Maximum number of feedback samples retained in the buffer;
    /// older samples are evicted FIFO once this is exceeded.
    pub max_feedback_samples: usize,

    /// Minimum number of buffered samples required before any
    /// weight update is performed.
    pub min_samples_to_learn: usize,

    /// A weight update is attempted once every this many recorded
    /// samples (checked via `total_samples % update_frequency`).
    pub update_frequency: usize,

    /// If true, one learnable weight per vector dimension; otherwise
    /// a single scalar weight is applied to the whole dot product.
    pub per_dimension_weights: bool,
}
| |
|
| | impl Default for LearnableRoutingConfig { |
| | fn default() -> Self { |
| | Self { |
| | learning_rate: 0.01, |
| | momentum: 0.9, |
| | weight_decay: 0.001, |
| | max_feedback_samples: 1000, |
| | min_samples_to_learn: 50, |
| | update_frequency: 10, |
| | per_dimension_weights: true, |
| | } |
| | } |
| | } |
| |
|
| | impl LearnableRoutingConfig { |
| | pub fn new() -> Self { |
| | Self::default() |
| | } |
| |
|
| | pub fn with_learning_rate(mut self, lr: f32) -> Self { |
| | self.learning_rate = lr; |
| | self |
| | } |
| |
|
| | pub fn with_momentum(mut self, momentum: f32) -> Self { |
| | self.momentum = momentum.clamp(0.0, 0.99); |
| | self |
| | } |
| |
|
| | pub fn disabled() -> Self { |
| | Self { |
| | learning_rate: 0.0, |
| | ..Default::default() |
| | } |
| | } |
| | } |
| |
|
| | |
/// One recorded outcome of a routing decision, used as a training
/// sample for the learnable weights.
#[derive(Debug, Clone)]
pub struct RoutingFeedback {
    /// The query vector that was routed.
    pub query: Point,

    /// The centroid the router selected for that query.
    pub selected_centroid: Point,

    /// Outcome signal: +1.0 for success, -1.0 for failure, or a value
    /// in [-1, 1] derived from a relevance score (see
    /// `LearnableRouter::record_implicit`).
    pub reward: f32,

    /// Hierarchy level at which the routing decision was made.
    /// NOTE(review): stored but not read by `update_weights` in this
    /// file — presumably consumed elsewhere; confirm.
    pub level: usize,
}
| |
|
| | |
| | |
| | |
| | |
/// A router that scores query/centroid pairs with learnable weights,
/// updated online from recorded routing feedback via momentum SGD.
#[derive(Debug, Clone)]
pub struct LearnableRouter {
    /// Learning hyper-parameters.
    config: LearnableRoutingConfig,

    /// Learnable weights: one per dimension if
    /// `config.per_dimension_weights`, otherwise a single scalar.
    /// Initialized to 1.0 (plain cosine/dot behavior).
    weights: Vec<f32>,

    /// Momentum (EMA of gradients) buffer, same length as `weights`.
    momentum_buffer: Vec<f32>,

    /// FIFO buffer of recent feedback samples, capped at
    /// `config.max_feedback_samples`.
    feedback_buffer: VecDeque<RoutingFeedback>,

    /// Lifetime count of all feedback samples ever recorded
    /// (not capped; used to pace updates).
    total_samples: usize,

    /// Dimensionality of the vectors this router scores.
    dims: usize,
}
| |
|
| | impl LearnableRouter { |
| | |
| | pub fn new(dims: usize, config: LearnableRoutingConfig) -> Self { |
| | let weight_count = if config.per_dimension_weights { dims } else { 1 }; |
| |
|
| | Self { |
| | config, |
| | weights: vec![1.0; weight_count], |
| | momentum_buffer: vec![0.0; weight_count], |
| | feedback_buffer: VecDeque::new(), |
| | total_samples: 0, |
| | dims, |
| | } |
| | } |
| |
|
| | |
| | pub fn default_for_dims(dims: usize) -> Self { |
| | Self::new(dims, LearnableRoutingConfig::default()) |
| | } |
| |
|
| | |
| | pub fn is_learning_enabled(&self) -> bool { |
| | self.config.learning_rate > 0.0 |
| | } |
| |
|
| | |
| | pub fn weights(&self) -> &[f32] { |
| | &self.weights |
| | } |
| |
|
| | |
| | |
| | |
| | pub fn weighted_similarity(&self, query: &Point, centroid: &Point) -> f32 { |
| | if self.config.per_dimension_weights { |
| | |
| | query.dims().iter() |
| | .zip(centroid.dims().iter()) |
| | .zip(self.weights.iter()) |
| | .map(|((q, c), w)| w * q * c) |
| | .sum() |
| | } else { |
| | |
| | let dot: f32 = query.dims().iter() |
| | .zip(centroid.dims().iter()) |
| | .map(|(q, c)| q * c) |
| | .sum(); |
| | self.weights[0] * dot |
| | } |
| | } |
| |
|
| | |
| | pub fn record_feedback(&mut self, feedback: RoutingFeedback) { |
| | self.feedback_buffer.push_back(feedback); |
| | self.total_samples += 1; |
| |
|
| | |
| | while self.feedback_buffer.len() > self.config.max_feedback_samples { |
| | self.feedback_buffer.pop_front(); |
| | } |
| |
|
| | |
| | if self.should_update() { |
| | self.update_weights(); |
| | } |
| | } |
| |
|
| | |
| | fn should_update(&self) -> bool { |
| | self.config.learning_rate > 0.0 |
| | && self.feedback_buffer.len() >= self.config.min_samples_to_learn |
| | && self.total_samples % self.config.update_frequency == 0 |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | fn update_weights(&mut self) { |
| | if self.feedback_buffer.is_empty() { |
| | return; |
| | } |
| |
|
| | let lr = self.config.learning_rate; |
| | let momentum = self.config.momentum; |
| | let decay = self.config.weight_decay; |
| |
|
| | |
| | let mut gradient = vec![0.0f32; self.weights.len()]; |
| |
|
| | for feedback in &self.feedback_buffer { |
| | let reward = feedback.reward; |
| |
|
| | if self.config.per_dimension_weights { |
| | |
| | for ((&q, &c), g) in feedback.query.dims().iter() |
| | .zip(feedback.selected_centroid.dims().iter()) |
| | .zip(gradient.iter_mut()) |
| | { |
| | |
| | *g += reward * q * c; |
| | } |
| | } else { |
| | |
| | let dot: f32 = feedback.query.dims().iter() |
| | .zip(feedback.selected_centroid.dims().iter()) |
| | .map(|(q, c)| q * c) |
| | .sum(); |
| | gradient[0] += reward * dot; |
| | } |
| | } |
| |
|
| | |
| | let n = self.feedback_buffer.len() as f32; |
| | for g in gradient.iter_mut() { |
| | *g /= n; |
| | } |
| |
|
| | |
| | for (i, (w, g)) in self.weights.iter_mut().zip(gradient.iter()).enumerate() { |
| | |
| | self.momentum_buffer[i] = momentum * self.momentum_buffer[i] + (1.0 - momentum) * g; |
| |
|
| | |
| | *w += lr * self.momentum_buffer[i] - decay * (*w - 1.0); |
| |
|
| | |
| | *w = w.clamp(0.1, 10.0); |
| | } |
| | } |
| |
|
| | |
| | pub fn record_success(&mut self, query: &Point, selected_centroid: &Point, level: usize) { |
| | self.record_feedback(RoutingFeedback { |
| | query: query.clone(), |
| | selected_centroid: selected_centroid.clone(), |
| | reward: 1.0, |
| | level, |
| | }); |
| | } |
| |
|
| | |
| | pub fn record_failure(&mut self, query: &Point, selected_centroid: &Point, level: usize) { |
| | self.record_feedback(RoutingFeedback { |
| | query: query.clone(), |
| | selected_centroid: selected_centroid.clone(), |
| | reward: -1.0, |
| | level, |
| | }); |
| | } |
| |
|
| | |
| | pub fn record_implicit(&mut self, query: &Point, selected_centroid: &Point, level: usize, relevance_score: f32) { |
| | |
| | let reward = 2.0 * relevance_score - 1.0; |
| | self.record_feedback(RoutingFeedback { |
| | query: query.clone(), |
| | selected_centroid: selected_centroid.clone(), |
| | reward, |
| | level, |
| | }); |
| | } |
| |
|
| | |
| | pub fn stats(&self) -> RouterStats { |
| | RouterStats { |
| | total_samples: self.total_samples, |
| | buffer_size: self.feedback_buffer.len(), |
| | weight_mean: self.weights.iter().sum::<f32>() / self.weights.len() as f32, |
| | weight_std: { |
| | let mean = self.weights.iter().sum::<f32>() / self.weights.len() as f32; |
| | (self.weights.iter().map(|w| (w - mean).powi(2)).sum::<f32>() |
| | / self.weights.len() as f32).sqrt() |
| | }, |
| | weight_min: self.weights.iter().cloned().fold(f32::INFINITY, f32::min), |
| | weight_max: self.weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max), |
| | } |
| | } |
| |
|
| | |
| | pub fn reset_weights(&mut self) { |
| | for w in self.weights.iter_mut() { |
| | *w = 1.0; |
| | } |
| | for m in self.momentum_buffer.iter_mut() { |
| | *m = 0.0; |
| | } |
| | } |
| |
|
| | |
| | pub fn clear_feedback(&mut self) { |
| | self.feedback_buffer.clear(); |
| | } |
| |
|
| | |
| | pub fn dims(&self) -> usize { |
| | self.dims |
| | } |
| |
|
| | |
| | pub fn serialize_weights(&self) -> Vec<u8> { |
| | let mut bytes = Vec::with_capacity(self.weights.len() * 4); |
| | for w in &self.weights { |
| | bytes.extend_from_slice(&w.to_le_bytes()); |
| | } |
| | bytes |
| | } |
| |
|
| | |
| | pub fn deserialize_weights(&mut self, bytes: &[u8]) -> Result<(), &'static str> { |
| | if bytes.len() != self.weights.len() * 4 { |
| | return Err("Weight count mismatch"); |
| | } |
| |
|
| | for (i, chunk) in bytes.chunks(4).enumerate() { |
| | let arr: [u8; 4] = chunk.try_into().map_err(|_| "Invalid byte chunk")?; |
| | self.weights[i] = f32::from_le_bytes(arr); |
| | } |
| |
|
| | Ok(()) |
| | } |
| | } |
| |
|
| | |
/// Snapshot of router learning state, as returned by
/// `LearnableRouter::stats`.
#[derive(Debug, Clone)]
pub struct RouterStats {
    /// Lifetime count of all feedback samples ever recorded.
    pub total_samples: usize,
    /// Number of samples currently held in the feedback buffer.
    pub buffer_size: usize,
    /// Mean of the learned weights.
    pub weight_mean: f32,
    /// Population standard deviation of the learned weights.
    pub weight_std: f32,
    /// Smallest learned weight (f32::INFINITY if there are none).
    pub weight_min: f32,
    /// Largest learned weight (f32::NEG_INFINITY if there are none).
    pub weight_max: f32,
}
| |
|
| | |
| | |
| | |
| | pub fn compute_routing_score( |
| | router: &LearnableRouter, |
| | query: &Point, |
| | centroid: &Point, |
| | temporal_distance: f32, |
| | temporal_weight: f32, |
| | ) -> f32 { |
| | let semantic_sim = router.weighted_similarity(query, centroid); |
| |
|
| | |
| | let semantic_dist = 1.0 - semantic_sim; |
| |
|
| | |
| | semantic_dist * (1.0 - temporal_weight) + temporal_distance * temporal_weight |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: builds a Point from raw components and normalizes it.
    // NOTE(review): relies on the project `Point::new` / `normalize`
    // API — unit-length vectors make dot products equal cosine sim.
    fn make_point(v: Vec<f32>) -> Point {
        Point::new(v).normalize()
    }

    // Fresh router: correct dims, weight count, learning enabled,
    // and all weights at the neutral value 1.0.
    #[test]
    fn test_router_creation() {
        let router = LearnableRouter::default_for_dims(64);

        assert_eq!(router.dims(), 64);
        assert_eq!(router.weights().len(), 64);
        assert!(router.is_learning_enabled());

        // All weights initialize to exactly 1.0.
        for &w in router.weights() {
            assert!((w - 1.0).abs() < 1e-6);
        }
    }

    // With unit weights, weighted similarity must equal the plain
    // cosine (dot product of normalized vectors).
    #[test]
    fn test_weighted_similarity() {
        let router = LearnableRouter::default_for_dims(4);

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.8, 0.2, 0.0, 0.0]);

        let sim = router.weighted_similarity(&query, &centroid);

        // Reference value: unweighted dot product.
        let expected_cosine: f32 = query.dims().iter()
            .zip(centroid.dims().iter())
            .map(|(q, c)| q * c)
            .sum();

        assert!((sim - expected_cosine).abs() < 1e-5);
    }

    // Feedback samples are counted; updates fire every 5 samples
    // once 5 are buffered (per the config below).
    #[test]
    fn test_feedback_recording() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
            min_samples_to_learn: 5,
            update_frequency: 5,
            ..Default::default()
        });

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);

        // Record 10 positive samples; updates trigger at 5 and 10.
        for _ in 0..10 {
            router.record_success(&query, &centroid, 0);
        }

        let stats = router.stats();
        assert_eq!(stats.total_samples, 10);

        // Weight movement direction is inspected manually rather than
        // asserted (exact values depend on lr/momentum defaults).
        println!("Weights after positive feedback: {:?}", router.weights());
    }

    // Qualitative check of learning direction: positive feedback on a
    // well-aligned centroid vs negative feedback on a misaligned one.
    // Momentum and decay are zeroed so each update is a pure averaged
    // gradient step.
    #[test]
    fn test_learning_dynamics() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
            learning_rate: 0.1,
            min_samples_to_learn: 3,
            update_frequency: 3,
            momentum: 0.0,
            weight_decay: 0.0,
            ..Default::default()
        });

        // Query points along dimension 0.
        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        // Centroid nearly parallel to the query (good match).
        let centroid_good = make_point(vec![0.95, 0.05, 0.0, 0.0]);
        // Centroid orthogonal to the query (bad match).
        let centroid_bad = make_point(vec![0.0, 1.0, 0.0, 0.0]);

        // Positive phase: 6 successes (updates at samples 3 and 6).
        for _ in 0..6 {
            router.record_success(&query, &centroid_good, 0);
        }

        let weights_after_positive = router.weights().to_vec();

        // Negative phase: 6 failures on the orthogonal centroid.
        for _ in 0..6 {
            router.record_failure(&query, &centroid_bad, 0);
        }

        let weights_after_negative = router.weights().to_vec();

        println!("Initial weights: [1.0, 1.0, 1.0, 1.0]");
        println!("After positive: {:?}", weights_after_positive);
        println!("After negative: {:?}", weights_after_negative);

        // No hard assertions: this test documents/inspects dynamics;
        // exact trajectories depend on buffer contents at update time.
    }

    // A disabled config (zero learning rate) must never move weights,
    // no matter how much feedback arrives.
    #[test]
    fn test_disabled_learning() {
        let mut router = LearnableRouter::new(4, LearnableRoutingConfig::disabled());

        assert!(!router.is_learning_enabled());

        let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
        let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);

        // Heavy feedback volume — should all be inert.
        for _ in 0..100 {
            router.record_success(&query, &centroid, 0);
        }

        // Weights stay at the neutral 1.0.
        for &w in router.weights() {
            assert!((w - 1.0).abs() < 1e-6);
        }
    }

    // Round-trip: serialized weights restore exactly into a second
    // router of the same dimensionality.
    #[test]
    fn test_serialization() {
        let mut router = LearnableRouter::default_for_dims(4);

        // Give the weights distinct, non-default values.
        for (i, w) in router.weights.iter_mut().enumerate() {
            *w = (i as f32 + 1.0) * 0.5;
        }

        let bytes = router.serialize_weights();

        let mut router2 = LearnableRouter::default_for_dims(4);
        router2.deserialize_weights(&bytes).unwrap();

        for (w1, w2) in router.weights().iter().zip(router2.weights().iter()) {
            assert!((w1 - w2).abs() < 1e-6);
        }
    }
}
| |
|