| """EEE Validator — HuggingFace Space webhook handler. | |
| Listens for PR events on evaleval/EEE_datastore, validates changed data | |
| files with Pydantic, checks for duplicates, and comments results on the PR. | |
| """ | |
| import logging | |
| import os | |
| import re | |
| import tempfile | |
| import threading | |
| from datetime import datetime, timezone | |
| from huggingface_hub import HfApi, WebhookPayload, WebhooksServer | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import EntryNotFoundError | |
| from dedup import DATASET_REPO_ID, DedupReport, check_duplicates, load_manifest | |
| from post_merge import handle_merge | |
| from validate_data import FileValidationResult, validate_with_pydantic | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s %(levelname)s %(name)s: %(message)s", | |
| ) | |
| logger = logging.getLogger(__name__) | |
| api = HfApi() | |
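# NOTE: HfApi() picks up its token implicitly (an HF_TOKEN env var / Space secret, or a
# cached login); that identity is the one used when posting validation comments below.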

# ---------------------------------------------------------------------------
# Changed file discovery
# ---------------------------------------------------------------------------

def find_changed_files(pr_num: int) -> list[str]:
    """Find added/modified .json and .jsonl data files in a PR.

    Compares the PR tree to main rather than relying on
    DiscussionWithDetails.diff, which can be None for dataset repos.
    """
    revision = f"refs/pr/{pr_num}"

    def _list_files(rev: str) -> dict[str, str | None]:
        """Return {path: oid} for all files at a given revision."""
        files = {}
        for entry in api.list_repo_tree(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            revision=rev,
            recursive=True,
        ):
            if hasattr(entry, "rfilename"):
                files[entry.rfilename] = getattr(entry, "oid", None)
        return files

    pr_files = _list_files(revision)
    main_files = _list_files("main")

    changed: list[str] = []
    for path, oid in pr_files.items():
        if not path.startswith("data/"):
            continue
        if not (path.endswith(".json") or path.endswith(".jsonl")):
            continue
        # New file, or existing file with different content
        if path not in main_files or main_files[path] != oid:
            changed.append(path)
    return changed
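# Note: a file counts as changed when its oid (content hash) on the PR branch differs
# from the one on main, or when the path does not exist on main at all.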

# ---------------------------------------------------------------------------
# File download
# ---------------------------------------------------------------------------

def download_pr_files(
    file_paths: list[str], pr_num: int, tmp_dir: str
) -> dict[str, str]:
    """Download files from a PR branch and return map of repo-path -> local-path."""
    downloaded: dict[str, str] = {}
    revision = f"refs/pr/{pr_num}"
    for file_path in file_paths:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=file_path,
                repo_type="dataset",
                revision=revision,
                local_dir=tmp_dir,
            )
            downloaded[file_path] = local_path
            logger.info("Downloaded %s -> %s", file_path, local_path)
        except EntryNotFoundError:
            logger.warning("File not found in PR: %s", file_path)
        except Exception:
            logger.exception("Failed to download %s", file_path)
    return downloaded
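# hf_hub_download with local_dir mirrors the repo-relative path under tmp_dir
# (e.g. a file at data/<name>.jsonl lands at <tmp_dir>/data/<name>.jsonl), so the
# repo-path -> local-path mapping returned above stays unambiguous.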

# ---------------------------------------------------------------------------
# Validation orchestration
# ---------------------------------------------------------------------------

def validate_files(
    downloaded: dict[str, str],
) -> list[FileValidationResult]:
    """Validate all downloaded files and return results."""
    results: list[FileValidationResult] = []
    for repo_path, local_path in downloaded.items():
        file_type = "jsonl" if repo_path.endswith(".jsonl") else "json"
        result = validate_with_pydantic(local_path, file_type)
        # Store the repo-relative path for reporting
        result.file_path = repo_path
        results.append(result)
    return results

# ---------------------------------------------------------------------------
# Deduplication orchestration
# ---------------------------------------------------------------------------

def run_dedup(
    file_paths: list[str], downloaded: dict[str, str]
) -> DedupReport:
    """Load manifest and check all files for duplicates."""
    manifest = load_manifest(api)
    # Read file contents as bytes
    file_contents: dict[str, bytes] = {}
    for repo_path, local_path in downloaded.items():
        with open(local_path, "rb") as f:
            file_contents[repo_path] = f.read()
    return check_duplicates(file_paths, file_contents, manifest)

# ---------------------------------------------------------------------------
# Comment formatting
# ---------------------------------------------------------------------------

def format_comment(
    pr_num: int,
    validation_results: list[FileValidationResult],
    dedup_report: DedupReport,
) -> str:
    """Format the PR comment as markdown."""
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    num_passed = sum(1 for r in validation_results if r.valid)
    num_failed = sum(1 for r in validation_results if not r.valid)
    total = len(validation_results)
    failed = [r for r in validation_results if not r.valid]

    if num_failed == 0:
        status_line = "## ✅ EEE Validation — Ready to Merge"
    else:
        status_line = "## ❌ EEE Validation — Changes Requested"

    lines = [
        status_line,
        f"**PR:** #{pr_num} | **Run:** {now}",
        "",
        f"**{num_passed}/{total} files passed**",
    ]

    if num_failed > 0:
        lines.append("")
        lines.append("### Failures")
        lines.append("| File | Details |")
        lines.append("|------|---------|")
        for r in failed:
            error_summary = "; ".join(r.errors[:5])
            if len(r.errors) > 5:
                error_summary += f" ... and {len(r.errors) - 5} more error(s)"
            lines.append(f"| `{r.file_path}` | {error_summary} |")

    # Dedup section
    has_any_dupes = False
    dedup_lines: list[str] = []
    for dr in dedup_report.results:
        if dr.exact_duplicate_of:
            dedup_lines.append(
                f"- **Exact duplicate:** `{dr.file_path}` is identical to "
                f"existing `{dr.exact_duplicate_of}`"
            )
            has_any_dupes = True
        if dr.near_duplicate_of:
            dedup_lines.append(
                f"- **Potential near-duplicate:** `{dr.file_path}` shares a fingerprint "
                f"with existing `{dr.near_duplicate_of}` "
                f"(identical content minus timestamps/UUIDs)"
            )
            has_any_dupes = True

    if has_any_dupes:
        lines.append("")
        lines.append("### Duplicate Check")
        lines.extend(dedup_lines)

    return "\n".join(lines)

# ---------------------------------------------------------------------------
# Core validation logic (shared by webhook + startup sweep)
# ---------------------------------------------------------------------------
REPORT_HEADER = "## ✅ EEE Validation"
REPORT_HEADER_FAIL = "## ❌ EEE Validation"


def process_pr(pr_num: int) -> dict:
    """Run full validation + dedup on a PR and post results as a comment."""
    logger.info("Processing PR #%d", pr_num)

    # Guard: skip if already validated for the current state
    if not pr_needs_validation(pr_num):
        logger.info("PR #%d already validated for current state, skipping", pr_num)
        return {"status": "skipped", "reason": "already validated"}

    # Find changed data files by comparing PR tree to main
    changed_files = find_changed_files(pr_num)
    if not changed_files:
        logger.info("No data files changed in PR #%d", pr_num)
        return {"status": "skipped", "reason": "no data files changed"}
    logger.info("Found %d changed data file(s): %s", len(changed_files), changed_files)

    # Create temp directory for downloads
    tmp_dir = tempfile.mkdtemp(prefix=f"eee_validate_{pr_num}_")

    # Download changed files from the PR branch
    downloaded = download_pr_files(changed_files, pr_num, tmp_dir)
    if not downloaded:
        logger.warning("No files could be downloaded for PR #%d", pr_num)
        return {"status": "error", "reason": "no files downloaded"}

    # Validate files
    validation_results = validate_files(downloaded)

    # Run dedup check
    dedup_report = run_dedup(changed_files, downloaded)

    # Format and post comment
    comment = format_comment(pr_num, validation_results, dedup_report)
    logger.info("Posting validation comment on PR #%d", pr_num)
    api.comment_discussion(
        repo_id=DATASET_REPO_ID,
        discussion_num=pr_num,
        comment=comment,
        repo_type="dataset",
    )

    return {
        "status": "ok",
        "pr": pr_num,
        "files_checked": len(validation_results),
        "passed": sum(1 for r in validation_results if r.valid),
        "failed": sum(1 for r in validation_results if not r.valid),
    }
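# Whatever process_pr() returns is passed straight back by the webhook handler below as
# the endpoint's JSON response, so it is kept to plain, JSON-serialisable values.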

# ---------------------------------------------------------------------------
# Startup sweep — catch PRs missed while the Space was asleep
# ---------------------------------------------------------------------------

def pr_needs_validation(pr_num: int) -> bool:
    """Check if a PR has commits newer than the last validation report."""
    details = api.get_discussion_details(
        repo_id=DATASET_REPO_ID,
        discussion_num=pr_num,
        repo_type="dataset",
    )

    last_report_time = None
    last_commit_time = None
    for event in details.events:
        if event.type == "comment" and event.content and (
            event.content.startswith(REPORT_HEADER)
            or event.content.startswith(REPORT_HEADER_FAIL)
        ):
            last_report_time = event.created_at
        if event.type == "commit":
            last_commit_time = event.created_at

    # No report yet — needs validation
    if last_report_time is None:
        return True
    # Has commits after the last report — needs re-validation
    if last_commit_time is not None and last_commit_time > last_report_time:
        return True
    return False

def startup_sweep() -> None:
    """Scan open PRs and validate any that are missing a report."""
    logger.info("Running startup sweep for unvalidated PRs...")
    try:
        discussions = api.get_repo_discussions(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
        for disc in discussions:
            if not disc.is_pull_request or disc.status != "open":
                continue
            if not pr_needs_validation(disc.num):
                logger.info("PR #%d is up to date, skipping", disc.num)
                continue
            logger.info("PR #%d needs validation, processing", disc.num)
            try:
                process_pr(disc.num)
            except Exception:
                logger.exception("Startup sweep failed for PR #%d", disc.num)
    except Exception:
        logger.exception("Startup sweep failed to list discussions")
    logger.info("Startup sweep complete")


# Run sweep in background thread so it doesn't block the webhook server startup
threading.Thread(target=startup_sweep, daemon=True).start()

# ---------------------------------------------------------------------------
# Webhook endpoint
# ---------------------------------------------------------------------------
PR_REF_RE = re.compile(r"^refs/pr/(\d+)$")


def _extract_pr_nums_from_refs(payload) -> list[int]:
    """Extract PR numbers from updatedRefs in repo.content events."""
    if not payload.updatedRefs:
        return []
    pr_nums = []
    for ref in payload.updatedRefs:
        m = PR_REF_RE.match(ref.ref)
        if m:
            pr_nums.append(int(m.group(1)))
    return pr_nums
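# On the Hub, PR branches are advertised as "refs/pr/<num>"; other refs in a
# repo.content payload (e.g. "refs/heads/main") simply don't match and are ignored.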

async def validate(payload: WebhookPayload):
    """Handle incoming webhook events from HuggingFace."""
    logger.info("Received webhook event: %s", payload.event)

    # Filter: ignore comments
    if payload.event.scope == "discussion.comment":
        logger.info("Skipping comment event")
        return {"status": "skipped", "reason": "comment event"}
    if payload.repo.type != "dataset":
        logger.info("Skipping non-dataset event (type=%s)", payload.repo.type)
        return {"status": "skipped", "reason": "not a dataset repo"}

    # Route 1: repo.content events (new commits pushed to PR branches)
    if payload.event.scope == "repo.content":
        pr_nums = _extract_pr_nums_from_refs(payload)
        if not pr_nums:
            logger.info("No PR refs in repo.content event")
            return {"status": "skipped", "reason": "no PR refs"}
        results = []
        for pr_num in pr_nums:
            try:
                results.append(process_pr(pr_num))
            except Exception:
                logger.exception("Failed to process PR #%d", pr_num)
                results.append({"status": "error", "pr": pr_num})
        return {"status": "ok", "results": results}

    # Route 2: discussion events (status changes, merges)
    if not payload.discussion or not payload.discussion.isPullRequest:
        logger.info("Skipping non-PR event")
        return {"status": "skipped", "reason": "not a pull request"}
    pr_num = payload.discussion.num

    # Handle merged PRs — update manifest + dataset card
    if payload.discussion.status == "merged":
        try:
            return handle_merge(pr_num)
        except Exception:
            logger.exception("Post-merge failed for PR #%d", pr_num)
            return {"status": "error", "reason": "post-merge failed"}

    # Handle open PRs — validate + dedup
    try:
        return process_pr(pr_num)
    except Exception:
        logger.exception("Failed to process PR #%d", pr_num)
        return {"status": "error", "reason": "processing failed"}

# ---------------------------------------------------------------------------
# Server startup
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    app = WebhooksServer(
        ui=None,  # No UI, just webhook endpoint
        webhook_secret=os.environ.get("WEBHOOKS_SECRET"),
    )
    app.add_webhook(validate)

    # Start the server on port 7860 (HF Spaces default)
    port = int(os.environ.get("PORT", 7860))
    logger.info("Starting webhook server on port %d", port)
    app.launch(server_name="0.0.0.0", server_port=port)
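    # Assumed setup (not enforced here): the evaleval/EEE_datastore repo has a webhook
    # configured to hit this Space, with the shared secret stored as WEBHOOKS_SECRET;
    # WebhooksServer typically exposes the handler at /webhooks/<function-name>,
    # i.e. /webhooks/validate for the handler above.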