File size: 4,548 Bytes
a52f96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Memory decay model using Ebbinghaus forgetting curve.

Scientific basis: retention after time t is R(t) = exp(-t / τ),
where τ (tau) is the retention constant.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np


@dataclass
class MemoryRecord:
    """Snapshot of a topic captured at the moment it was last practiced.

    Attributes (documented here rather than inline):
        timestamp: Model time at which the practice session happened.
        base_skill: Skill level immediately after that session, before any
            forgetting is applied.
    """

    timestamp: float
    base_skill: float


class MemoryDecayModel:
    """
    Models realistic forgetting using the Ebbinghaus curve.

    Key features:
    - Track last practice time per topic
    - Compute retention factor based on time elapsed
    - Effective skill = base_skill × retention_factor
    """

    def __init__(self, retention_constant: float = 80.0):
        """
        Args:
            retention_constant (tau): Controls forgetting speed.
                Higher = slower forgetting.
                tau=80 means ~37% (1/e) retention after 80 time steps.

        Raises:
            ValueError: If retention_constant is not strictly positive.
                tau == 0 would divide by zero and tau < 0 would make
                retention grow over time instead of decay.
        """
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not retention_constant > 0:
            raise ValueError(
                f"retention_constant must be > 0, got {retention_constant!r}"
            )
        self.tau = retention_constant

        # Latest practice record per topic; re-practicing overwrites it.
        self.topic_memories: Dict[str, MemoryRecord] = {}

        # Current model time; moved forward by advance_time().
        self.current_time: float = 0.0

    def update_practice(self, topic: str, base_skill: float):
        """
        Record that the student just practiced a topic.

        Replaces any earlier record for the topic, stamping it with the
        current model time.

        Args:
            topic: Topic that was practiced.
            base_skill: Student's skill level after practice (0.0-1.0).
        """
        self.topic_memories[topic] = MemoryRecord(
            timestamp=self.current_time,
            base_skill=base_skill
        )

    def get_retention_factor(self, topic: str) -> float:
        """
        Compute the retention factor for a topic.

        Returns:
            Retention factor in [0.0, 1.0] from the Ebbinghaus curve,
            exp(-elapsed / tau). 1.0 means just practiced (or never
            practiced); the factor decays exponentially as time passes.
        """
        memory = self.topic_memories.get(topic)
        if memory is None:
            return 1.0  # First time seeing topic: nothing forgotten yet

        # Clamp at 0 so a rewound clock (negative elapsed time) cannot
        # push retention above 1.0.
        time_elapsed = max(self.current_time - memory.timestamp, 0.0)

        # Ebbinghaus forgetting curve
        return float(np.exp(-time_elapsed / self.tau))

    def get_effective_skill(self, topic: str) -> float:
        """
        Get the current effective skill, accounting for forgetting.

        Returns:
            base_skill × retention_factor, or 0.0 if the topic was never
            practiced.
        """
        memory = self.topic_memories.get(topic)
        if memory is None:
            return 0.0  # Never practiced

        return memory.base_skill * self.get_retention_factor(topic)

    def get_time_since_practice(self, topic: str) -> float:
        """Get raw time elapsed since last practice (inf if never practiced)."""
        if topic not in self.topic_memories:
            return float('inf')

        return self.current_time - self.topic_memories[topic].timestamp

    def advance_time(self, delta: float = 1.0):
        """Simulate time passing by adding `delta` to the model clock."""
        self.current_time += delta

    def get_all_topics(self) -> List[str]:
        """Get all topics that have been practiced, in insertion order."""
        return list(self.topic_memories.keys())

    def plot_forgetting_curves(self, topics: Optional[List[str]] = None,
                               save_path: str = 'forgetting_curves.png',
                               max_time: float = 200.0,
                               num_points: int = 100):
        """
        Plot the forgetting curve and mark where each topic currently sits.

        All topics share one retention constant, so they share a single
        curve R(t) = exp(-t / tau). Drawing one line per topic would only
        overdraw identical lines. Instead, the curve is drawn once and each
        practiced topic is marked at its current time since practice.

        Args:
            topics: Topics to mark. Defaults to every practiced topic.
                Topics that were never practiced get no marker.
            save_path: Output image path.
            max_time: Right end of the sampled time axis.
            num_points: Number of samples along the curve.
        """
        if topics is None:
            topics = self.get_all_topics()

        if not topics:
            print("⚠️ No topics to plot")
            return

        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        # The curve does not depend on the topic, so compute it once.
        time_range = np.linspace(0, max_time, num_points)
        retentions = np.exp(-time_range / self.tau)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(time_range, retentions, linewidth=2,
                    label=f'R(t) = exp(-t / {self.tau:g})')

            for topic in topics:
                elapsed = self.get_time_since_practice(topic)
                if not np.isfinite(elapsed):
                    continue  # Never practiced: no position on the curve
                ax.scatter([max(elapsed, 0.0)],
                           [self.get_retention_factor(topic)],
                           s=60, zorder=3, label=topic)

            ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5,
                       label='50% retention threshold')
            ax.set_xlabel('Time Since Practice', fontsize=12)
            ax.set_ylabel('Retention Factor', fontsize=12)
            ax.set_title('Ebbinghaus Forgetting Curves', fontsize=14)
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(save_path, dpi=150)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        print(f"📊 Saved forgetting curves to {save_path}")