| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | use crate::core::Id; |
| |
|
| | |
/// Who (or what) produced a piece of conversation content.
///
/// Each variant has a canonical lowercase name (see [`Role::as_str`]) and a
/// one-byte encoding used by the binary serialization in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Canonical name `"system"`, byte `0`.
    System,
    /// Canonical name `"user"`, byte `1`.
    User,
    /// Canonical name `"assistant"`, byte `2`.
    Assistant,
    /// Canonical name `"tool"` (also parsed from `"function"`), byte `3`.
    Tool,
    /// Canonical name `"context"` (also parsed from `"retrieved"`), byte `4`.
    Context,
}

impl Role {
    /// Roles in wire order: the index of each entry is its byte encoding.
    const BY_BYTE: [Role; 5] = [
        Role::System,
        Role::User,
        Role::Assistant,
        Role::Tool,
        Role::Context,
    ];

    /// Returns the canonical lowercase name of this role.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
            Role::Context => "context",
        }
    }

    /// Parses a role name, ignoring case.
    ///
    /// Besides the canonical names, `"function"` maps to [`Role::Tool`] and
    /// `"retrieved"` maps to [`Role::Context`]. Returns `None` for anything
    /// else.
    pub fn from_str(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        let role = match lowered.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" | "function" => Role::Tool,
            "context" | "retrieved" => Role::Context,
            _ => return None,
        };
        Some(role)
    }

    /// Encodes this role as the single byte used by the binary format.
    fn to_byte(&self) -> u8 {
        match *self {
            Role::System => 0,
            Role::User => 1,
            Role::Assistant => 2,
            Role::Tool => 3,
            Role::Context => 4,
        }
    }

    /// Decodes a role byte; returns `None` for bytes outside `0..=4`.
    fn from_byte(b: u8) -> Option<Self> {
        Self::BY_BYTE.get(usize::from(b)).copied()
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// A key/value cache payload together with the header describing it.
///
/// This type only stores and (de)serializes the fields; none of the numeric
/// dimensions are validated against the length or contents of `data`.
#[derive(Debug, Clone)]
pub struct CompressedKV {
    /// Identifier of the model the cache belongs to.
    pub model_id: String,

    /// Layer count recorded in the header.
    pub num_layers: u32,

    /// Head count recorded in the header.
    pub num_heads: u32,

    /// Per-head dimension recorded in the header.
    pub head_dim: u32,

    /// Sequence length recorded in the header.
    pub seq_len: u32,

    /// Free-form label of the quantization scheme (`placeholder` uses `"none"`).
    pub quantization: String,

    /// Opaque payload bytes; their layout is not interpreted here.
    pub data: Vec<u8>,
}

impl CompressedKV {
    /// Returns the length of the payload in bytes (header fields excluded).
    pub fn size_bytes(&self) -> usize {
        self.data.len()
    }

    /// Builds an empty cache for `model_id`: all dimensions are zero, the
    /// quantization label is `"none"` and the payload is empty.
    pub fn placeholder(model_id: &str) -> Self {
        Self {
            model_id: model_id.to_string(),
            num_layers: 0,
            num_heads: 0,
            head_dim: 0,
            seq_len: 0,
            quantization: "none".to_string(),
            data: vec![],
        }
    }

    /// Serializes the cache into a little-endian byte buffer.
    ///
    /// Layout, in order:
    /// `u32` model-id length + UTF-8 bytes, four `u32` dimensions
    /// (`num_layers`, `num_heads`, `head_dim`, `seq_len`), `u32` quantization
    /// length + UTF-8 bytes, `u64` payload length + payload.
    ///
    /// String lengths are written as `u32`, so strings longer than
    /// `u32::MAX` bytes cannot be represented faithfully.
    pub fn to_bytes(&self) -> Vec<u8> {
        let model_bytes = self.model_id.as_bytes();
        let quant_bytes = self.quantization.as_bytes();

        // The exact output size is known up front, so allocate once.
        let total = 4 + model_bytes.len() + 16 + 4 + quant_bytes.len() + 8 + self.data.len();
        let mut bytes = Vec::with_capacity(total);

        bytes.extend_from_slice(&(model_bytes.len() as u32).to_le_bytes());
        bytes.extend_from_slice(model_bytes);

        for dim in [self.num_layers, self.num_heads, self.head_dim, self.seq_len].iter() {
            bytes.extend_from_slice(&dim.to_le_bytes());
        }

        bytes.extend_from_slice(&(quant_bytes.len() as u32).to_le_bytes());
        bytes.extend_from_slice(quant_bytes);

        bytes.extend_from_slice(&(self.data.len() as u64).to_le_bytes());
        bytes.extend_from_slice(&self.data);

        bytes
    }

    /// Parses a cache from the start of `data`.
    ///
    /// Returns the decoded value and the number of bytes consumed, so callers
    /// can continue parsing after it. Trailing bytes are ignored.
    ///
    /// Returns `None` when the input is truncated, when a string is not valid
    /// UTF-8, or when a declared length does not fit in the remaining input.
    /// Declared lengths are untrusted: all offset arithmetic is checked, so a
    /// hostile length yields `None` instead of overflowing or panicking.
    pub fn from_bytes(data: &[u8]) -> Option<(Self, usize)> {
        let mut offset = 0usize;

        let model_len = usize::try_from(Self::read_u32(data, &mut offset)?).ok()?;
        let model_id = Self::read_string(data, &mut offset, model_len)?;

        let num_layers = Self::read_u32(data, &mut offset)?;
        let num_heads = Self::read_u32(data, &mut offset)?;
        let head_dim = Self::read_u32(data, &mut offset)?;
        let seq_len = Self::read_u32(data, &mut offset)?;

        let quant_len = usize::try_from(Self::read_u32(data, &mut offset)?).ok()?;
        let quantization = Self::read_string(data, &mut offset, quant_len)?;

        // A u64 length may not fit in usize (32-bit targets) and may exceed
        // the input; both cases are rejected rather than truncated.
        let data_len = usize::try_from(Self::read_u64(data, &mut offset)?).ok()?;
        let kv_data = Self::take(data, &mut offset, data_len)?.to_vec();

        Some((
            Self {
                model_id,
                num_layers,
                num_heads,
                head_dim,
                seq_len,
                quantization,
                data: kv_data,
            },
            offset,
        ))
    }

    /// Returns the next `len` bytes and advances `offset`, or `None` if the
    /// range overflows or runs past the end of `data`.
    fn take<'a>(data: &'a [u8], offset: &mut usize, len: usize) -> Option<&'a [u8]> {
        let end = offset.checked_add(len)?;
        let chunk = data.get(*offset..end)?;
        *offset = end;
        Some(chunk)
    }

    /// Reads a little-endian `u32` at `offset`.
    fn read_u32(data: &[u8], offset: &mut usize) -> Option<u32> {
        let mut buf = [0u8; 4];
        buf.copy_from_slice(Self::take(data, offset, 4)?);
        Some(u32::from_le_bytes(buf))
    }

    /// Reads a little-endian `u64` at `offset`.
    fn read_u64(data: &[u8], offset: &mut usize) -> Option<u64> {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(Self::take(data, offset, 8)?);
        Some(u64::from_le_bytes(buf))
    }

    /// Reads `len` bytes at `offset` as an owned UTF-8 string.
    fn read_string(data: &[u8], offset: &mut usize, len: usize) -> Option<String> {
        let raw = Self::take(data, offset, len)?;
        String::from_utf8(raw.to_vec()).ok()
    }
}
| |
|
| | |
| | #[derive(Debug, Clone)] |
| | pub struct AttentionState { |
| | |
| | pub id: Id, |
| |
|
| | |
| | pub timestamp_ms: u64, |
| |
|
| | |
| | pub role: Role, |
| |
|
| | |
| | pub text: String, |
| |
|
| | |
| | pub embedding: Vec<f32>, |
| |
|
| | |
| | pub kv_cache: Option<CompressedKV>, |
| |
|
| | |
| | pub metadata: std::collections::HashMap<String, String>, |
| | } |
| |
|
| | impl AttentionState { |
| | |
| | pub fn new(role: Role, text: String, embedding: Vec<f32>) -> Self { |
| | Self { |
| | id: Id::now(), |
| | timestamp_ms: std::time::SystemTime::now() |
| | .duration_since(std::time::UNIX_EPOCH) |
| | .unwrap() |
| | .as_millis() as u64, |
| | role, |
| | text, |
| | embedding, |
| | kv_cache: None, |
| | metadata: std::collections::HashMap::new(), |
| | } |
| | } |
| |
|
| | |
| | pub fn with_kv_cache(mut self, kv: CompressedKV) -> Self { |
| | self.kv_cache = Some(kv); |
| | self |
| | } |
| |
|
| | |
| | pub fn with_metadata(mut self, key: &str, value: &str) -> Self { |
| | self.metadata.insert(key.to_string(), value.to_string()); |
| | self |
| | } |
| |
|
| | |
| | pub fn size_bytes(&self) -> usize { |
| | 16 + |
| | 8 + |
| | 1 + |
| | self.text.len() + |
| | self.embedding.len() * 4 + |
| | self.kv_cache.as_ref().map(|kv| kv.size_bytes()).unwrap_or(0) + |
| | self.metadata.iter().map(|(k, v)| k.len() + v.len() + 8).sum::<usize>() |
| | } |
| |
|
| | |
| | pub fn to_bytes(&self) -> Vec<u8> { |
| | let mut bytes = Vec::new(); |
| |
|
| | |
| | bytes.extend_from_slice(b"ATTN"); |
| | bytes.extend_from_slice(&1u32.to_le_bytes()); |
| |
|
| | |
| | bytes.extend_from_slice(self.id.as_bytes()); |
| |
|
| | |
| | bytes.extend_from_slice(&self.timestamp_ms.to_le_bytes()); |
| |
|
| | |
| | bytes.push(self.role.to_byte()); |
| |
|
| | |
| | let text_bytes = self.text.as_bytes(); |
| | bytes.extend_from_slice(&(text_bytes.len() as u32).to_le_bytes()); |
| | bytes.extend_from_slice(text_bytes); |
| |
|
| | |
| | bytes.extend_from_slice(&(self.embedding.len() as u32).to_le_bytes()); |
| | for &v in &self.embedding { |
| | bytes.extend_from_slice(&v.to_le_bytes()); |
| | } |
| |
|
| | |
| | if let Some(ref kv) = self.kv_cache { |
| | bytes.push(1); |
| | let kv_bytes = kv.to_bytes(); |
| | bytes.extend_from_slice(&(kv_bytes.len() as u64).to_le_bytes()); |
| | bytes.extend_from_slice(&kv_bytes); |
| | } else { |
| | bytes.push(0); |
| | } |
| |
|
| | |
| | bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes()); |
| | for (key, value) in &self.metadata { |
| | let key_bytes = key.as_bytes(); |
| | let value_bytes = value.as_bytes(); |
| | bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes()); |
| | bytes.extend_from_slice(key_bytes); |
| | bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes()); |
| | bytes.extend_from_slice(value_bytes); |
| | } |
| |
|
| | bytes |
| | } |
| |
|
| | |
| | pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
| | let mut offset = 0; |
| |
|
| | |
| | if data.len() < 8 { |
| | return Err(AttentionError::InvalidFormat("Too short".into())); |
| | } |
| | if &data[0..4] != b"ATTN" { |
| | return Err(AttentionError::InvalidMagic); |
| | } |
| | offset += 4; |
| |
|
| | |
| | let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
| | if version != 1 { |
| | return Err(AttentionError::UnsupportedVersion(version)); |
| | } |
| | offset += 4; |
| |
|
| | |
| | if data.len() < offset + 16 { |
| | return Err(AttentionError::InvalidFormat("Missing ID".into())); |
| | } |
| | let mut id_bytes = [0u8; 16]; |
| | id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| | let id = Id::from_bytes(id_bytes); |
| | offset += 16; |
| |
|
| | |
| | if data.len() < offset + 8 { |
| | return Err(AttentionError::InvalidFormat("Missing timestamp".into())); |
| | } |
| | let timestamp_ms = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); |
| | offset += 8; |
| |
|
| | |
| | if data.len() < offset + 1 { |
| | return Err(AttentionError::InvalidFormat("Missing role".into())); |
| | } |
| | let role = Role::from_byte(data[offset]) |
| | .ok_or_else(|| AttentionError::InvalidFormat("Invalid role".into()))?; |
| | offset += 1; |
| |
|
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing text length".into())); |
| | } |
| | let text_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | if data.len() < offset + text_len { |
| | return Err(AttentionError::InvalidFormat("Text truncated".into())); |
| | } |
| | let text = String::from_utf8(data[offset..offset + text_len].to_vec()) |
| | .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in text".into()))?; |
| | offset += text_len; |
| |
|
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing embedding length".into())); |
| | } |
| | let emb_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | if data.len() < offset + emb_len * 4 { |
| | return Err(AttentionError::InvalidFormat("Embedding truncated".into())); |
| | } |
| | let mut embedding = Vec::with_capacity(emb_len); |
| | for _ in 0..emb_len { |
| | embedding.push(f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap())); |
| | offset += 4; |
| | } |
| |
|
| | |
| | if data.len() < offset + 1 { |
| | return Err(AttentionError::InvalidFormat("Missing KV flag".into())); |
| | } |
| | let has_kv = data[offset] != 0; |
| | offset += 1; |
| |
|
| | let kv_cache = if has_kv { |
| | if data.len() < offset + 8 { |
| | return Err(AttentionError::InvalidFormat("Missing KV length".into())); |
| | } |
| | let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
| | offset += 8; |
| |
|
| | if data.len() < offset + kv_len { |
| | return Err(AttentionError::InvalidFormat("KV data truncated".into())); |
| | } |
| | let (kv, _) = CompressedKV::from_bytes(&data[offset..offset + kv_len]) |
| | .ok_or_else(|| AttentionError::InvalidFormat("Invalid KV cache".into()))?; |
| | offset += kv_len; |
| | Some(kv) |
| | } else { |
| | None |
| | }; |
| |
|
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing metadata count".into())); |
| | } |
| | let meta_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | let mut metadata = std::collections::HashMap::new(); |
| | for _ in 0..meta_count { |
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing key length".into())); |
| | } |
| | let key_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | if data.len() < offset + key_len { |
| | return Err(AttentionError::InvalidFormat("Key truncated".into())); |
| | } |
| | let key = String::from_utf8(data[offset..offset + key_len].to_vec()) |
| | .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in key".into()))?; |
| | offset += key_len; |
| |
|
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing value length".into())); |
| | } |
| | let value_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | if data.len() < offset + value_len { |
| | return Err(AttentionError::InvalidFormat("Value truncated".into())); |
| | } |
| | let value = String::from_utf8(data[offset..offset + value_len].to_vec()) |
| | .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in value".into()))?; |
| | offset += value_len; |
| |
|
| | metadata.insert(key, value); |
| | } |
| |
|
| | Ok(Self { |
| | id, |
| | timestamp_ms, |
| | role, |
| | text, |
| | embedding, |
| | kv_cache, |
| | metadata, |
| | }) |
| | } |
| | } |
| |
|
| | |
/// Errors produced while decoding the binary attention formats.
#[derive(Debug, Clone)]
pub enum AttentionError {
    /// The buffer does not start with the expected magic bytes.
    InvalidMagic,
    /// The header carries a format version this code does not handle.
    UnsupportedVersion(u32),
    /// The buffer is structurally invalid; the message says what was wrong.
    InvalidFormat(String),
}

impl std::fmt::Display for AttentionError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => f.write_str("Invalid magic bytes"),
            Self::UnsupportedVersion(found) => write!(f, "Unsupported version: {}", found),
            Self::InvalidFormat(detail) => write!(f, "Invalid format: {}", detail),
        }
    }
}

impl std::error::Error for AttentionError {}
| |
|
| | |
| | #[derive(Debug, Clone)] |
| | pub struct AttentionBatch { |
| | |
| | pub states: Vec<AttentionState>, |
| |
|
| | |
| | pub session_id: Option<Id>, |
| |
|
| | |
| | pub document_id: Option<Id>, |
| | } |
| |
|
| | impl AttentionBatch { |
| | pub fn new() -> Self { |
| | Self { |
| | states: Vec::new(), |
| | session_id: None, |
| | document_id: None, |
| | } |
| | } |
| |
|
| | pub fn with_session(mut self, session_id: Id) -> Self { |
| | self.session_id = Some(session_id); |
| | self |
| | } |
| |
|
| | pub fn with_document(mut self, document_id: Id) -> Self { |
| | self.document_id = Some(document_id); |
| | self |
| | } |
| |
|
| | pub fn add(&mut self, state: AttentionState) { |
| | self.states.push(state); |
| | } |
| |
|
| | |
| | pub fn size_bytes(&self) -> usize { |
| | self.states.iter().map(|s| s.size_bytes()).sum() |
| | } |
| |
|
| | |
| | pub fn to_bytes(&self) -> Vec<u8> { |
| | let mut bytes = Vec::new(); |
| |
|
| | |
| | bytes.extend_from_slice(b"ATNB"); |
| | bytes.extend_from_slice(&1u32.to_le_bytes()); |
| |
|
| | |
| | if let Some(sid) = self.session_id { |
| | bytes.push(1); |
| | bytes.extend_from_slice(sid.as_bytes()); |
| | } else { |
| | bytes.push(0); |
| | } |
| |
|
| | |
| | if let Some(did) = self.document_id { |
| | bytes.push(1); |
| | bytes.extend_from_slice(did.as_bytes()); |
| | } else { |
| | bytes.push(0); |
| | } |
| |
|
| | |
| | bytes.extend_from_slice(&(self.states.len() as u32).to_le_bytes()); |
| |
|
| | |
| | for state in &self.states { |
| | let state_bytes = state.to_bytes(); |
| | bytes.extend_from_slice(&(state_bytes.len() as u64).to_le_bytes()); |
| | bytes.extend_from_slice(&state_bytes); |
| | } |
| |
|
| | bytes |
| | } |
| |
|
| | |
| | pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
| | let mut offset = 0; |
| |
|
| | |
| | if data.len() < 8 { |
| | return Err(AttentionError::InvalidFormat("Too short".into())); |
| | } |
| | if &data[0..4] != b"ATNB" { |
| | return Err(AttentionError::InvalidMagic); |
| | } |
| | offset += 4; |
| |
|
| | |
| | let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
| | if version != 1 { |
| | return Err(AttentionError::UnsupportedVersion(version)); |
| | } |
| | offset += 4; |
| |
|
| | |
| | if data.len() < offset + 1 { |
| | return Err(AttentionError::InvalidFormat("Missing session flag".into())); |
| | } |
| | let has_session = data[offset] != 0; |
| | offset += 1; |
| |
|
| | let session_id = if has_session { |
| | if data.len() < offset + 16 { |
| | return Err(AttentionError::InvalidFormat("Missing session ID".into())); |
| | } |
| | let mut id_bytes = [0u8; 16]; |
| | id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| | offset += 16; |
| | Some(Id::from_bytes(id_bytes)) |
| | } else { |
| | None |
| | }; |
| |
|
| | |
| | if data.len() < offset + 1 { |
| | return Err(AttentionError::InvalidFormat("Missing document flag".into())); |
| | } |
| | let has_document = data[offset] != 0; |
| | offset += 1; |
| |
|
| | let document_id = if has_document { |
| | if data.len() < offset + 16 { |
| | return Err(AttentionError::InvalidFormat("Missing document ID".into())); |
| | } |
| | let mut id_bytes = [0u8; 16]; |
| | id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| | offset += 16; |
| | Some(Id::from_bytes(id_bytes)) |
| | } else { |
| | None |
| | }; |
| |
|
| | |
| | if data.len() < offset + 4 { |
| | return Err(AttentionError::InvalidFormat("Missing state count".into())); |
| | } |
| | let state_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| | offset += 4; |
| |
|
| | |
| | let mut states = Vec::with_capacity(state_count); |
| | for _ in 0..state_count { |
| | if data.len() < offset + 8 { |
| | return Err(AttentionError::InvalidFormat("Missing state length".into())); |
| | } |
| | let state_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
| | offset += 8; |
| |
|
| | if data.len() < offset + state_len { |
| | return Err(AttentionError::InvalidFormat("State truncated".into())); |
| | } |
| | let state = AttentionState::from_bytes(&data[offset..offset + state_len])?; |
| | offset += state_len; |
| | states.push(state); |
| | } |
| |
|
| | Ok(Self { |
| | states, |
| | session_id, |
| | document_id, |
| | }) |
| | } |
| | } |
| |
|
| | impl Default for AttentionBatch { |
| | fn default() -> Self { |
| | Self::new() |
| | } |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role must survive its one-byte encoding unchanged.
    #[test]
    fn test_role_roundtrip() {
        let all_roles = [
            Role::System,
            Role::User,
            Role::Assistant,
            Role::Tool,
            Role::Context,
        ];
        for role in all_roles {
            let decoded = Role::from_byte(role.to_byte()).unwrap();
            assert_eq!(role, decoded);
        }
    }

    /// A state without a KV cache keeps role, text, embedding and metadata.
    #[test]
    fn test_attention_state_roundtrip() {
        let text = "Hello, how are you?".to_string();
        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        let original = AttentionState::new(Role::User, text, embedding).with_metadata("turn", "1");

        let encoded = original.to_bytes();
        let decoded = AttentionState::from_bytes(&encoded).unwrap();

        assert_eq!(original.role, decoded.role);
        assert_eq!(original.text, decoded.text);
        assert_eq!(original.embedding, decoded.embedding);
        assert_eq!(original.metadata.get("turn"), decoded.metadata.get("turn"));
    }

    /// An attached KV cache is serialized with the state and restored.
    #[test]
    fn test_attention_state_with_kv() {
        let cache = CompressedKV {
            model_id: "llama-3-8b".to_string(),
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            seq_len: 10,
            quantization: "fp16".to_string(),
            data: vec![1, 2, 3, 4, 5],
        };
        let original = AttentionState::new(
            Role::Assistant,
            "I'm doing well!".to_string(),
            vec![0.5, 0.6, 0.7, 0.8],
        )
        .with_kv_cache(cache);

        let encoded = original.to_bytes();
        let decoded = AttentionState::from_bytes(&encoded).unwrap();

        assert!(decoded.kv_cache.is_some());
        let decoded_cache = decoded.kv_cache.unwrap();
        assert_eq!(decoded_cache.model_id, "llama-3-8b");
        assert_eq!(decoded_cache.num_layers, 32);
        assert_eq!(decoded_cache.data, vec![1, 2, 3, 4, 5]);
    }

    /// A batch keeps its states in order and remembers its session id.
    #[test]
    fn test_batch_roundtrip() {
        let mut batch = AttentionBatch::new().with_session(Id::now());

        let question = AttentionState::new(Role::User, "Question 1".to_string(), vec![0.1, 0.2]);
        batch.add(question);

        let answer = AttentionState::new(Role::Assistant, "Answer 1".to_string(), vec![0.3, 0.4]);
        batch.add(answer);

        let encoded = batch.to_bytes();
        let decoded = AttentionBatch::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.states.len(), 2);
        assert_eq!(decoded.states[0].text, "Question 1");
        assert_eq!(decoded.states[1].text, "Answer 1");
        assert!(decoded.session_id.is_some());
    }
}
| |
|