"""API client for Common Standards Project with retry logic and rate limiting."""

from __future__ import annotations

import json
import time
from typing import Any

import requests
from loguru import logger

from tools.config import get_settings
from tools.models import (
    Jurisdiction,
    JurisdictionDetails,
    StandardSet,
    StandardSetReference,
)

settings = get_settings()

# Cache file for jurisdictions
JURISDICTIONS_CACHE_FILE = settings.raw_data_dir / "jurisdictions.json"

# Rate limiting: Max requests per minute
MAX_REQUESTS_PER_MINUTE = settings.max_requests_per_minute
_request_timestamps: list[float] = []


class APIError(Exception):
    """Raised when API request fails after all retries."""

    pass


def _get_headers() -> dict[str, str]:
    """Get authentication headers for API requests."""
    if not settings.csp_api_key:
        logger.error("CSP_API_KEY not found in .env file")
        raise ValueError("CSP_API_KEY environment variable not set")
    return {"Api-Key": settings.csp_api_key}


def _enforce_rate_limit() -> None:
    """Enforce rate limiting by tracking request timestamps."""
    global _request_timestamps
    now = time.time()

    # Remove timestamps older than 1 minute
    _request_timestamps = [ts for ts in _request_timestamps if now - ts < 60]

    # If at limit, wait until the oldest request leaves the 60-second window
    if len(_request_timestamps) >= MAX_REQUESTS_PER_MINUTE:
        sleep_time = 60 - (now - _request_timestamps[0])
        logger.warning(f"Rate limit reached. Waiting {sleep_time:.1f} seconds...")
        time.sleep(sleep_time)
        _request_timestamps = []

    # Record this request at the current time (i.e. after any sleep)
    _request_timestamps.append(time.time())


def _make_request(
    endpoint: str, params: dict[str, Any] | None = None, max_retries: int = 3
) -> dict[str, Any]:
    """
    Make API request with exponential backoff retry logic.

    Args:
        endpoint: API endpoint path (e.g., "/jurisdictions")
        params: Query parameters
        max_retries: Maximum number of retry attempts

    Returns:
        Parsed JSON response

    Raises:
        APIError: After all retries exhausted or on fatal errors
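
    Example:
        # Illustrative internal call; assumes a valid CSP_API_KEY in .env.
        payload = _make_request("/jurisdictions")
        data = payload.get("data", [])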
    """
    url = f"{settings.csp_base_url}{endpoint}"
    headers = _get_headers()

    for attempt in range(max_retries):
        try:
            _enforce_rate_limit()

            logger.debug(
                f"API request: {endpoint} (attempt {attempt + 1}/{max_retries})"
            )
            response = requests.get(url, headers=headers, params=params, timeout=30)

            # Handle specific status codes
            if response.status_code == 401:
                logger.error("Invalid API key (401 Unauthorized)")
                raise APIError("Authentication failed. Check your CSP_API_KEY in .env")

            if response.status_code == 404:
                logger.error(f"Resource not found (404): {endpoint}")
                raise APIError(f"Resource not found: {endpoint}")

            if response.status_code == 429:
                # Rate limited by server
                retry_after = int(response.headers.get("Retry-After", 60))
                logger.warning(
                    f"Server rate limit hit. Waiting {retry_after} seconds..."
                )
                time.sleep(retry_after)
                continue

            response.raise_for_status()
            logger.info(f"API request successful: {endpoint}")
            return response.json()

        except requests.exceptions.Timeout as e:
            wait_time = 2**attempt  # Exponential backoff: 1s, 2s, 4s
            logger.warning(f"Request timeout. Retrying in {wait_time}s...")
            if attempt < max_retries - 1:
                time.sleep(wait_time)
            else:
                raise APIError(f"Request timeout after {max_retries} attempts") from e

        except requests.exceptions.ConnectionError as e:
            wait_time = 2**attempt
            logger.warning(f"Connection error. Retrying in {wait_time}s...")
            if attempt < max_retries - 1:
                time.sleep(wait_time)
            else:
                raise APIError(f"Connection failed after {max_retries} attempts") from e

        except requests.exceptions.HTTPError as e:
            # Don't retry on 4xx errors (except 429)
            if 400 <= response.status_code < 500 and response.status_code != 429:
                raise APIError(f"HTTP {response.status_code}: {response.text}") from e
            # Retry on 5xx errors
            wait_time = 2**attempt
            logger.warning(
                f"Server error {response.status_code}. Retrying in {wait_time}s..."
            )
            if attempt < max_retries - 1:
                time.sleep(wait_time)
            else:
                raise APIError(f"Server error after {max_retries} attempts") from e

    raise APIError("Request failed after all retries")


def get_jurisdictions(
    search_term: str | None = None,
    type_filter: str | None = None,
    force_refresh: bool = False,
) -> list[Jurisdiction]:
    """
    Fetch all jurisdictions from the API or local cache.

    Jurisdictions are cached locally in data/raw/jurisdictions.json to avoid
    repeated API calls. Use force_refresh=True to fetch fresh data from the API.

    Args:
        search_term: Optional filter for jurisdiction title (case-insensitive partial match)
        type_filter: Optional filter for jurisdiction type (case-insensitive).
                     Valid values: "school", "organization", "state", "nation"
        force_refresh: If True, fetch fresh data from API and update cache

    Returns:
        List of Jurisdiction models
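
    Example:
        # Illustrative usage; the filter values below are examples only.
        states = get_jurisdictions(type_filter="state")
        california = get_jurisdictions(search_term="California", force_refresh=True)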
    """
    jurisdictions: list[Jurisdiction] = []
    raw_data: list[dict[str, Any]] = []

    # Check cache first (unless forcing refresh)
    if not force_refresh and JURISDICTIONS_CACHE_FILE.exists():
        try:
            logger.info("Loading jurisdictions from cache")
            with open(JURISDICTIONS_CACHE_FILE, encoding="utf-8") as f:
                cached_response = json.load(f)
            raw_data = cached_response.get("data", [])
            logger.info(f"Loaded {len(raw_data)} jurisdictions from cache")
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to load cache: {e}. Fetching from API...")
            force_refresh = True

    # Fetch from API if cache doesn't exist or force_refresh is True
    if force_refresh or not raw_data:
        logger.info("Fetching jurisdictions from API")
        response = _make_request("/jurisdictions")
        raw_data = response.get("data", [])

        # Save to cache
        try:
            settings.raw_data_dir.mkdir(parents=True, exist_ok=True)
            with open(JURISDICTIONS_CACHE_FILE, "w", encoding="utf-8") as f:
                json.dump(response, f, indent=2, ensure_ascii=False)
            logger.info(
                f"Cached {len(raw_data)} jurisdictions to {JURISDICTIONS_CACHE_FILE}"
            )
        except IOError as e:
            logger.warning(f"Failed to save cache: {e}")

    # Parse into Pydantic models
    jurisdictions = [Jurisdiction(**j) for j in raw_data]

    # Apply type filter if provided (case-insensitive)
    if type_filter:
        type_lower = type_filter.lower()
        original_count = len(jurisdictions)
        jurisdictions = [j for j in jurisdictions if j.type.lower() == type_lower]
        logger.info(
            f"Filtered to {len(jurisdictions)} jurisdictions of type '{type_filter}' (from {original_count})"
        )

    # Apply search filter if provided (case-insensitive partial match)
    if search_term:
        search_lower = search_term.lower()
        original_count = len(jurisdictions)
        jurisdictions = [j for j in jurisdictions if search_lower in j.title.lower()]
        logger.info(
            f"Filtered to {len(jurisdictions)} jurisdictions matching '{search_term}' (from {original_count})"
        )

    return jurisdictions


def get_jurisdiction_details(
    jurisdiction_id: str, force_refresh: bool = False, hide_hidden_sets: bool = True
) -> JurisdictionDetails:
    """
    Fetch jurisdiction metadata including standard set references.

    Jurisdiction metadata is cached locally in data/raw/jurisdictions/{jurisdiction_id}/data.json
    to avoid repeated API calls. Use force_refresh=True to fetch fresh data from the API.

    Note: This returns metadata about standard sets (IDs, titles, subjects) but NOT the
    full standard set content. Use download_standard_set() to get full standard set data.

    Args:
        jurisdiction_id: The jurisdiction GUID
        force_refresh: If True, fetch fresh data from API and update cache
        hide_hidden_sets: If True, hide deprecated/outdated sets (default: True)

    Returns:
        JurisdictionDetails model with jurisdiction metadata and standardSets array
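
    Example:
        # Illustrative usage; the GUID below is a hypothetical placeholder.
        details = get_jurisdiction_details("EXAMPLE-JURISDICTION-GUID")
        for ref in details.standardSets:
            print(ref.id, ref.title, ref.subject)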
    """
    cache_dir = settings.raw_data_dir / "jurisdictions" / jurisdiction_id
    cache_file = cache_dir / "data.json"
    raw_data: dict[str, Any] = {}

    # Check cache first (unless forcing refresh)
    if not force_refresh and cache_file.exists():
        try:
            logger.info(f"Loading jurisdiction {jurisdiction_id} from cache")
            with open(cache_file, encoding="utf-8") as f:
                cached_response = json.load(f)
            raw_data = cached_response.get("data", {})
            logger.info("Loaded jurisdiction metadata from cache")
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to load cache: {e}. Fetching from API...")
            force_refresh = True

    # Fetch from API if cache doesn't exist or force_refresh is True
    if force_refresh or not raw_data:
        logger.info(f"Fetching jurisdiction {jurisdiction_id} from API")
        params = {"hideHiddenSets": "true" if hide_hidden_sets else "false"}
        response = _make_request(f"/jurisdictions/{jurisdiction_id}", params=params)
        raw_data = response.get("data", {})

        # Save to cache
        try:
            cache_dir.mkdir(parents=True, exist_ok=True)
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(response, f, indent=2, ensure_ascii=False)
            logger.info(f"Cached jurisdiction metadata to {cache_file}")
        except IOError as e:
            logger.warning(f"Failed to save cache: {e}")

    # Parse into Pydantic model
    return JurisdictionDetails(**raw_data)


def download_standard_set(set_id: str, force_refresh: bool = False) -> StandardSet:
    """
    Download full standard set data with caching.

    Standard set data is cached locally in data/raw/standardSets/{set_id}/data.json
    to avoid repeated API calls. Use force_refresh=True to fetch fresh data from the API.

    Args:
        set_id: The standard set GUID
        force_refresh: If True, fetch fresh data from API and update cache

    Returns:
        StandardSet model with complete standard set data including hierarchy
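
    Example:
        # Illustrative usage; the GUID below is a hypothetical placeholder.
        standard_set = download_standard_set("EXAMPLE-SET-GUID", force_refresh=True)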
    """
    cache_dir = settings.raw_data_dir / "standardSets" / set_id
    cache_file = cache_dir / "data.json"
    raw_data: dict[str, Any] = {}

    # Check cache first (unless forcing refresh)
    if not force_refresh and cache_file.exists():
        try:
            logger.info(f"Loading standard set {set_id} from cache")
            with open(cache_file, encoding="utf-8") as f:
                cached_response = json.load(f)
            raw_data = cached_response.get("data", {})
            logger.info("Loaded standard set from cache")
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to load cache: {e}. Fetching from API...")
            force_refresh = True

    # Fetch from API if cache doesn't exist or force_refresh is True
    if force_refresh or not raw_data:
        logger.info(f"Downloading standard set {set_id} from API")
        response = _make_request(f"/standard_sets/{set_id}")
        raw_data = response.get("data", {})

        # Save to cache
        try:
            cache_dir.mkdir(parents=True, exist_ok=True)
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(response, f, indent=2, ensure_ascii=False)
            logger.info(f"Cached standard set to {cache_file}")
        except IOError as e:
            logger.warning(f"Failed to save cache: {e}")

    # Parse into Pydantic model
    return StandardSet(**raw_data)


def _filter_standard_set(
    standard_set: StandardSetReference,
    education_levels: list[str] | None = None,
    publication_status: str | None = None,
    valid_year: str | None = None,
    title_search: str | None = None,
    subject_search: str | None = None,
) -> bool:
    """
    Check if a standard set matches all provided filters (AND logic).

    Args:
        standard_set: StandardSetReference model from jurisdiction metadata
        education_levels: List of grade levels to match (any match)
        publication_status: Publication status to match
        valid_year: Valid year string to match
        title_search: Partial string match on title (case-insensitive)
        subject_search: Partial string match on subject (case-insensitive)

    Returns:
        True if standard set matches all provided filters
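
    Example:
        # Illustrative usage; the GUID and filter values are hypothetical.
        refs = get_jurisdiction_details("EXAMPLE-JURISDICTION-GUID").standardSets
        high_school_math = [
            ref
            for ref in refs
            if _filter_standard_set(
                ref, education_levels=["09", "10", "11", "12"], subject_search="math"
            )
        ]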
    """
    # Filter by education levels (any match)
    if education_levels:
        set_levels = {level.upper() for level in standard_set.educationLevels}
        filter_levels = {level.upper() for level in education_levels}
        if not set_levels.intersection(filter_levels):
            return False

    # Filter by publication status
    if publication_status:
        if (
            standard_set.document.publicationStatus
            and standard_set.document.publicationStatus.lower()
            != publication_status.lower()
        ):
            return False

    # Filter by valid year
    if valid_year:
        if standard_set.document.valid != valid_year:
            return False

    # Filter by title search (partial match, case-insensitive)
    if title_search:
        if title_search.lower() not in standard_set.title.lower():
            return False

    # Filter by subject search (partial match, case-insensitive)
    if subject_search:
        if subject_search.lower() not in standard_set.subject.lower():
            return False

    return True


def download_standard_sets_by_jurisdiction(
    jurisdiction_id: str,
    force_refresh: bool = False,
    education_levels: list[str] | None = None,
    publication_status: str | None = None,
    valid_year: str | None = None,
    title_search: str | None = None,
    subject_search: str | None = None,
) -> list[str]:
    """
    Download standard sets for a jurisdiction with optional filtering.

    Args:
        jurisdiction_id: The jurisdiction GUID
        force_refresh: If True, re-download each matching standard set, ignoring
            its cache (jurisdiction metadata is still read from cache when present)
        education_levels: List of grade levels to filter by
        publication_status: Publication status to filter by
        valid_year: Valid year string to filter by
        title_search: Partial string match on title
        subject_search: Partial string match on subject

    Returns:
        List of downloaded standard set IDs
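
    Example:
        # Illustrative usage; the GUID and filter values are hypothetical.
        set_ids = download_standard_sets_by_jurisdiction(
            "EXAMPLE-JURISDICTION-GUID",
            subject_search="science",
            education_levels=["06", "07", "08"],
        )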
    """
    # Get jurisdiction metadata
    jurisdiction_data = get_jurisdiction_details(jurisdiction_id, force_refresh=False)
    standard_sets = jurisdiction_data.standardSets

    # Apply filters
    filtered_sets = [
        s
        for s in standard_sets
        if _filter_standard_set(
            s,
            education_levels=education_levels,
            publication_status=publication_status,
            valid_year=valid_year,
            title_search=title_search,
            subject_search=subject_search,
        )
    ]

    # Download each filtered standard set
    downloaded_ids: list[str] = []
    for standard_set in filtered_sets:
        set_id = standard_set.id
        try:
            download_standard_set(set_id, force_refresh=force_refresh)
            downloaded_ids.append(set_id)
        except Exception as e:
            logger.error(f"Failed to download standard set {set_id}: {e}")

    return downloaded_ids
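

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the library surface.
    # Assumes a valid CSP_API_KEY in .env and that the Jurisdiction model
    # exposes an `id` field; the search term and grade filter below are
    # illustrative values, not required ones.
    matches = get_jurisdictions(search_term="Common Core", type_filter="organization")
    if matches:
        downloaded = download_standard_sets_by_jurisdiction(
            matches[0].id,
            education_levels=["09", "10", "11", "12"],
        )
        logger.info(f"Downloaded {len(downloaded)} standard sets")
    else:
        logger.warning("No matching jurisdictions found")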