File size: 10,592 Bytes
a52f96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
"""
DistilBERT-based student agent with online learning and memory decay.

Uses DistilBERT for Multiple Choice to answer reading comprehension tasks.
Implements online learning (fine-tune on 1 example at a time).
"""

import torch
from torch.optim import AdamW
from transformers import (
    DistilBertForMultipleChoice,
    DistilBertTokenizer,
)
from typing import List, Dict
import numpy as np
from collections import defaultdict

from interfaces import StudentAgentInterface, StudentState, Task
from memory_decay import MemoryDecayModel


class StudentAgent(StudentAgentInterface):
    """
    DistilBERT-based student that learns reading comprehension.

    Features:
    - Online learning (1 example at a time)
    - Memory decay (delegated to ``MemoryDecayModel``)
    - Per-topic skill tracking
    - Gradient accumulation for stability

    If the pretrained model or tokenizer cannot be loaded, the agent falls
    back to a "dummy mode" (``self.model is None``) in which answers are
    random guesses and only the per-topic statistics are updated.
    """

    # Pretrained checkpoint used for both the model and the tokenizer.
    _MODEL_NAME = "distilbert-base-uncased"
    # Fallback number of answer options when a task lists no choices.
    _DEFAULT_NUM_CHOICES = 4
    # Accuracy floor used when mapping a skill value to an expected
    # multiple-choice accuracy (random guessing among 4 options).
    _CHANCE_ACCURACY = 0.25

    def __init__(
        self,
        learning_rate: float = 5e-5,
        retention_constant: float = 80.0,
        device: str = 'cpu',
        max_length: int = 256,
        gradient_accumulation_steps: int = 4,
        verbose: bool = True,
    ):
        """
        Args:
            learning_rate: LM fine-tuning learning rate
            retention_constant: Forgetting speed, passed to MemoryDecayModel
                (higher = slower forgetting)
            device: 'cpu' or 'cuda'
            max_length: Max tokens for passage + question + choice
            gradient_accumulation_steps: Number of learn() calls whose
                gradients are accumulated before one optimizer step
            verbose: Print model-loading progress messages
        """
        self.device = device
        self.max_length = max_length
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.verbose = verbose

        # Holds the exception raised while loading the model, if any, so a
        # silent fallback to dummy mode can still be diagnosed by callers.
        self.model_load_error = None

        # Load DistilBERT for multiple choice. Loading is best-effort: any
        # failure switches the agent to dummy mode instead of raising.
        try:
            if verbose:
                print("Loading DistilBERT model...", end=" ", flush=True)
            self.model = DistilBertForMultipleChoice.from_pretrained(
                self._MODEL_NAME
            ).to(self.device)

            self.tokenizer = DistilBertTokenizer.from_pretrained(
                self._MODEL_NAME
            )
            if verbose:
                print("✅")
        except Exception as e:
            self.model_load_error = e
            if verbose:
                print(f"⚠️ (Model unavailable, using dummy mode: {e})")
            self.model = None
            self.tokenizer = None

        # Optimizer (only meaningful when a real model is present)
        if self.model is not None:
            self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        else:
            self.optimizer = None

        # Memory decay model
        self.memory = MemoryDecayModel(retention_constant=retention_constant)

        # Track per-topic base skills (before forgetting)
        self.topic_base_skills: Dict[str, float] = {}

        # Track learning history
        self.topic_attempts: Dict[str, int] = defaultdict(int)
        self.topic_correct: Dict[str, int] = defaultdict(int)

        # Gradient accumulation counter (number of backward passes so far)
        self.grad_step = 0

        # Start in training mode
        if self.model is not None:
            self.model.train()

    def _random_guess(self, task: Task) -> int:
        """Return a uniformly random choice index for ``task``."""
        num_choices = len(task.choices) or self._DEFAULT_NUM_CHOICES
        return np.random.randint(0, num_choices)

    def answer(self, task: Task) -> int:
        """
        Predict answer without updating weights.

        The model's prediction is returned with probability
        ``0.25 + 0.75 * effective_skill`` (skill as reported by the memory
        model for the task's topic); otherwise a random choice is returned.

        Note: this leaves the model in eval mode.
        """
        if self.model is None:
            # Dummy model: random guessing
            return self._random_guess(task)

        self.model.eval()

        # Prepare inputs
        inputs = self._prepare_inputs(task)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            predicted_idx = torch.argmax(logits, dim=-1).item()

        # Apply memory decay to prediction:
        # if the student has forgotten, the prediction becomes more random.
        effective_skill = self.memory.get_effective_skill(task.topic)
        use_learned_prob = (
            self._CHANCE_ACCURACY
            + (1.0 - self._CHANCE_ACCURACY) * effective_skill
        )

        if np.random.random() < use_learned_prob:
            return predicted_idx
        return self._random_guess(task)

    def learn(self, task: Task) -> bool:
        """
        Fine-tune on a single task (online learning).

        Gradients are accumulated across calls; the optimizer steps once
        every ``gradient_accumulation_steps`` calls.

        Returns:
            True if the pre-update prediction was correct, False otherwise
        """
        if self.model is None:
            # Dummy learning: track statistics only
            predicted = self._random_guess(task)
            was_correct = (predicted == task.answer)
            self._update_stats(task, was_correct)
            return was_correct

        # Get prediction before learning
        predicted = self.answer(task)
        was_correct = (predicted == task.answer)

        # answer() switches the model to eval mode; re-enable training mode
        # so the update below runs with dropout active.
        self.model.train()

        # Prepare inputs with correct answer
        inputs = self._prepare_inputs(task)
        inputs['labels'] = torch.tensor([task.answer], device=self.device)

        # Forward pass
        outputs = self.model(**inputs)

        # Backward pass; scale the loss so accumulated gradients average
        # over the accumulation window instead of summing.
        loss = outputs.loss / self.gradient_accumulation_steps
        loss.backward()

        self.grad_step += 1

        # Update weights every N steps
        if self.grad_step % self.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Update statistics
        self._update_stats(task, was_correct)

        return was_correct

    def _update_stats(self, task: Task, was_correct: bool):
        """Update topic statistics and memory."""
        topic = task.topic
        self.topic_attempts[topic] += 1
        if was_correct:
            self.topic_correct[topic] += 1

        # Base skill = running accuracy on this topic (without forgetting)
        base_skill = self.topic_correct[topic] / self.topic_attempts[topic]
        self.topic_base_skills[topic] = base_skill

        # Update memory (record practice)
        self.memory.update_practice(topic, base_skill)

    def evaluate(self, eval_tasks: List[Task]) -> float:
        """
        Evaluate on held-out tasks without updating weights.

        Returns:
            Accuracy (0.0-1.0); 0.0 for an empty task list and a fixed
            0.25 in dummy mode.
        """
        if not eval_tasks:
            return 0.0

        if self.model is None:
            # Dummy evaluation: return chance accuracy
            return self._CHANCE_ACCURACY

        self.model.eval()

        correct = sum(
            1 for task in eval_tasks if self.answer(task) == task.answer
        )
        return correct / len(eval_tasks)

    def get_state(self) -> StudentState:
        """
        Get current state for teacher observation.

        Returns per-topic expected accuracies derived from the memory
        model's effective skill, plus attempt counts and time since
        practice for every topic seen so far.
        """
        topic_accuracies = {}
        time_since_practice = {}

        for topic in self.topic_base_skills:
            # Get effective skill (with forgetting)
            effective_skill = self.memory.get_effective_skill(topic)

            # Convert to expected accuracy on MCQ
            topic_accuracies[topic] = (
                self._CHANCE_ACCURACY
                + (1.0 - self._CHANCE_ACCURACY) * effective_skill
            )

            # Time since last practice
            time_since_practice[topic] = self.memory.get_time_since_practice(topic)

        return StudentState(
            topic_accuracies=topic_accuracies,
            topic_attempts=dict(self.topic_attempts),
            time_since_practice=time_since_practice,
            total_timesteps=sum(self.topic_attempts.values()),
            current_time=self.memory.current_time
        )

    def _prepare_inputs(self, task: Task) -> Dict[str, torch.Tensor]:
        """
        Prepare inputs for DistilBERT multiple choice model.

        Each choice is encoded as a sequence pair:
            [CLS] passage question [SEP] choice [SEP]
        One row is produced per entry in ``task.choices``.

        Returns:
            Dict with 'input_ids' and 'attention_mask' tensors of shape
            (1, num_choices, seq_length) on ``self.device``; an empty dict
            when no tokenizer is loaded.
        """
        if self.tokenizer is None:
            return {}

        # The context (passage + question) is shared by every choice.
        context = f"{task.passage} {task.question}"
        choices = [str(choice) for choice in task.choices]
        contexts = [context] * len(choices)

        # Tokenize as (context, choice) pairs. With a pair, truncation
        # removes tokens from the longer sequence first, so a long passage
        # is shortened instead of the answer choice being cut off the end.
        encoded = self.tokenizer(
            contexts,
            choices,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Reshape for multiple choice format:
        # (num_choices, seq_length) -> (batch_size=1, num_choices, seq_length)
        input_ids = encoded['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoded['attention_mask'].unsqueeze(0).to(self.device)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    def advance_time(self, delta: float = 1.0):
        """Advance time for memory decay."""
        self.memory.advance_time(delta)

    def save(self, path: str):
        """
        Save model checkpoint.

        Stores model and optimizer state, the per-topic statistics, the
        memory object and the gradient-accumulation counter. Does nothing
        (besides printing a warning) in dummy mode.
        """
        if self.model is None:
            print("⚠️ No model to save (using dummy model)")
            return

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer else None,
            'topic_base_skills': self.topic_base_skills,
            'topic_attempts': dict(self.topic_attempts),
            'topic_correct': dict(self.topic_correct),
            'memory': self.memory,
            'grad_step': self.grad_step
        }, path)
        print(f"💾 Saved checkpoint to {path}")

    def load(self, path: str):
        """
        Load model checkpoint.

        Model/optimizer weights are restored only when a real model is
        loaded; statistics and memory are restored in either mode.
        """
        # NOTE(review): the checkpoint contains the pickled `memory` object,
        # so loading unpickles arbitrary Python objects - only load trusted
        # checkpoints. Recent PyTorch releases may default `torch.load` to
        # `weights_only=True`, which would presumably reject that object;
        # confirm against the installed version.
        checkpoint = torch.load(path, map_location=self.device)

        if self.model is not None:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            if self.optimizer and checkpoint.get('optimizer_state_dict'):
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.topic_base_skills = checkpoint['topic_base_skills']
        self.topic_attempts = defaultdict(int, checkpoint['topic_attempts'])
        self.topic_correct = defaultdict(int, checkpoint['topic_correct'])
        self.memory = checkpoint['memory']
        self.grad_step = checkpoint.get('grad_step', 0)
        print(f"✅ Loaded checkpoint from {path}")