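"""Prompt dispatcher.

Maps each prompt name to its sampling preset, system prompt builder, and
Pydantic input/output schemas, and exposes run_prompt(), which validates the
caller's payload, serves cached results when available, calls the Qwen
adapter, and validates the returned JSON before caching it.

Typical use (a sketch; the exact payload fields are defined by the
corresponding Input model in .schemas):

    result = run_prompt('hint_generation', payload)
"""
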
import json
from typing import Dict, Any
from . import prompts
from .schemas import (
    ItemExplanationInput, ItemExplanationOutput,
    MasteryDiagnosticInput, MasteryDiagnosticOutput,
    NextItemSelectorInput, NextItemSelectorOutput,
    SkillFeedbackInput, SkillFeedbackOutput,
    HintGenerationInput, HintGenerationOutput,
    ReflectionInput, ReflectionOutput,
    InstructorInsightInput, InstructorInsightRow,
    ExplanationCompressionInput, ExplanationCompressionOutput,
    QuestionAuthoringInput, QuestionAuthoringOutput,
    ToneNormalizerInput, ToneNormalizerOutput,
)
from .validation import parse_and_validate
from .cache import make_key, get as cache_get, set as cache_set
from .adapters.qwen_adapter import QwenAdapter

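# Per-prompt sampling presets. Deterministic tasks (diagnostics, selection,
# compression, tone normalization) run at low temperature; more generative
# tasks (hint generation, question authoring) use a higher temperature and
# larger token budgets.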
PRESETS = {
    'item_explanation': dict(temperature=0.2, max_tokens=256),
    'mastery_diagnostic': dict(temperature=0.2, max_tokens=128),
    'next_item_selector': dict(temperature=0.2, max_tokens=128),
    'skill_feedback': dict(temperature=0.3, max_tokens=256),
    'hint_generation': dict(temperature=0.6, max_tokens=200),
    'reflection': dict(temperature=0.3, max_tokens=120),
    'instructor_insight': dict(temperature=0.2, max_tokens=160),
    'explanation_compression': dict(temperature=0.2, max_tokens=80),
    'question_authoring': dict(temperature=0.6, max_tokens=400),
    'tone_normalizer': dict(temperature=0.2, max_tokens=60),
}

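# System prompt builders, one per prompt name; each is called with no
# arguments when the prompt runs (see run_prompt).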
SYSTEMS = {
    'item_explanation': prompts.item_explanation,
    'mastery_diagnostic': prompts.mastery_diagnostic,
    'next_item_selector': prompts.next_item_selector,
    'skill_feedback': prompts.skill_feedback,
    'hint_generation': prompts.hint_generation,
    'reflection': prompts.reflection,
    'instructor_insight': prompts.instructor_insight,
    'explanation_compression': prompts.explanation_compression,
    'question_authoring': prompts.question_authoring,
    'tone_normalizer': prompts.tone_normalizer,
}

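# Pydantic models that validate the caller-supplied payload before it is
# serialized into the user message.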
INPUT_MODELS = {
    'item_explanation': ItemExplanationInput,
    'mastery_diagnostic': MasteryDiagnosticInput,
    'next_item_selector': NextItemSelectorInput,
    'skill_feedback': SkillFeedbackInput,
    'hint_generation': HintGenerationInput,
    'reflection': ReflectionInput,
    'instructor_insight': InstructorInsightInput,
    'explanation_compression': ExplanationCompressionInput,
    'question_authoring': QuestionAuthoringInput,
    'tone_normalizer': ToneNormalizerInput,
}

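# Pydantic models that validate the model's JSON response.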
OUTPUT_MODELS = {
    'item_explanation': ItemExplanationOutput,
    'mastery_diagnostic': MasteryDiagnosticOutput,
    'next_item_selector': NextItemSelectorOutput,
    'skill_feedback': SkillFeedbackOutput,
    'hint_generation': HintGenerationOutput,
    'reflection': ReflectionOutput,
    'instructor_insight': InstructorInsightRow,  # list validated separately
    'explanation_compression': ExplanationCompressionOutput,
    'question_authoring': QuestionAuthoringOutput,
    'tone_normalizer': ToneNormalizerOutput,
}

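# Shared adapter instance, created lazily on first use (see _get_adapter).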
_adapter = None
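# Prompts whose cache key also folds in question-level fields (see _cache_key).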
SPECIAL_CACHE_KEYS = {'item_explanation', 'hint_generation'}


def _get_adapter(model_id: str) -> QwenAdapter:
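    """Return the shared QwenAdapter, creating it on first call.

    Note: the adapter is built for the first model_id seen in this process;
    later calls reuse it even if a different model_id is passed.
    """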
    global _adapter
    if _adapter is None:
        _adapter = QwenAdapter(model_name=model_id)
    return _adapter

def _cache_key(prompt_name: str, input_data: Dict[str, Any], model_id: str, temperature: float) -> str:
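    """Build a deterministic cache key from the prompt name, the canonical
    JSON dump of the input, the model id, and the sampling temperature.
    For prompts in SPECIAL_CACHE_KEYS a question-level component is appended
    as well.
    """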
    special = None
    if prompt_name in SPECIAL_CACHE_KEYS:
        if prompt_name == 'item_explanation':
            q = input_data.get('question', '')
            ua = input_data.get('user_answer', '')
            special = f"{q}\u241f{ua}"
        elif prompt_name == 'hint_generation':
            q = input_data.get('question', '')
            special = q
    base = json.dumps(input_data, sort_keys=True)
    parts = [prompt_name, base, model_id, temperature, special or '-']
    return make_key(*parts)


def run_prompt(prompt_name: str, input_payload: Dict[str, Any], *, model_id: str = 'Qwen/Qwen3-7B-Instruct', seed: int = 42) -> Any:
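    """Validate the payload against the prompt's input schema, return a cached
    result if one exists, otherwise call the model with the prompt's preset,
    validate the JSON output, cache it, and return it as plain Python data.
    """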
    if prompt_name not in PRESETS:
        raise ValueError(f'Unknown prompt: {prompt_name}')

    input_model = INPUT_MODELS[prompt_name]
    parsed_input = input_model.parse_obj(input_payload)

    preset = PRESETS[prompt_name]
    payload_dict = parsed_input.dict(by_alias=True)
    ckey = _cache_key(prompt_name, payload_dict, model_id, preset['temperature'])
    cached = cache_get(ckey)
    if cached is not None:
        return json.loads(cached)

    # Get adapter with lazy initialization
    adapter = _get_adapter(model_id)

    system = SYSTEMS[prompt_name]()
    user = json.dumps(payload_dict, ensure_ascii=False)

    text = adapter.generate(
        system=system,
        user=f"Return JSON only. No commentary.\nInput: {user}",
        temperature=preset['temperature'],
        max_tokens=preset['max_tokens'],
        stop=None,
        seed=seed,
    )

    if prompt_name == 'instructor_insight':
        # This prompt returns a JSON array; validate each row individually.
        data = json.loads(text)
        if not isinstance(data, list):
            raise ValueError('Expected a JSON array')
        out_obj = [InstructorInsightRow.parse_obj(x).dict() for x in data]
    else:
        out_model = OUTPUT_MODELS[prompt_name]
        out_obj = parse_and_validate(out_model, text)
        # Unwrap root-style models before falling back to dict(): a Pydantic v1
        # __root__ model also exposes dict(), so checking dict() first would
        # wrap the result as {'__root__': ...}.
        if hasattr(out_obj, 'root'):        # Pydantic v2 RootModel
            out_obj = out_obj.root
        elif hasattr(out_obj, '__root__'):  # Pydantic v1 custom root type
            out_obj = out_obj.__root__
        elif hasattr(out_obj, 'dict'):
            out_obj = out_obj.dict(by_alias=True)

    cache_set(ckey, json.dumps(out_obj, ensure_ascii=False))
    return out_obj