Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # app/main.py | |
| from fastapi import FastAPI, UploadFile, File, Request, Form, Query | |
| from fastapi.responses import HTMLResponse, PlainTextResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from cbow_logic import MeaningCalculator | |
| from ppo_logic import generate_summary | |
| import numpy as np | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| import uvicorn | |
| import os | |
| import praw | |
| import random | |
| from vit_captioning.generate import CaptionGenerator | |
| from cbow_logic import MeaningCalculator | |
# Reddit API client used by get_sample(). Credentials are read from the
# REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET environment variables (os.getenv
# yields None if they are unset).
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent="script:ContentDistilleryBot:v0.1 (by u/ClementHa)"
)
app = FastAPI()
# Jinja2 template directory, given as a path relative to the current working directory.
templates = Jinja2Templates(directory="templates")
# Single MeaningCalculator instance created once at import time.
calculator = MeaningCalculator()
# Serve static files from <directory of this file>/vit_captioning/static;
# this path is anchored to __file__, not to the working directory.
static_dir = Path(__file__).parent / "vit_captioning" / "static"
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Landing page at `/`
async def landing():
    """Return the landing page HTML.

    Reads ``vit_captioning/static/landing.html`` (relative to the current
    working directory) on every call. The file is decoded as UTF-8
    explicitly so the result does not depend on the platform's locale
    encoding.
    """
    return Path("vit_captioning/static/landing.html").read_text(encoding="utf-8")
def health_check():
    """Report that the service is up, as a one-key status mapping."""
    return dict(status="ok")
# Captioning page at `/captioning`
async def captioning():
    """Return the image-captioning page HTML.

    Reads ``vit_captioning/static/captioning/index.html`` (relative to the
    current working directory) on every call. The file is decoded as UTF-8
    explicitly so the result does not depend on the platform's locale
    encoding.
    """
    page = Path("vit_captioning/static/captioning/index.html")
    return page.read_text(encoding="utf-8")
async def contentdistillery():
    """Return the Content Distillery static page HTML.

    Reads ``content_distillery/static/content_distillery.html`` (relative to
    the current working directory) on every call. The file is decoded as
    UTF-8 explicitly, as ``contentdistillery_page`` already does.
    """
    page = Path("content_distillery/static/content_distillery.html")
    return page.read_text(encoding="utf-8")
# Captioning model used by the caption generation endpoint (generate, below).
# Keep that endpoint's path consistent with your JS fetch()!
# Constructed once at import time; the checkpoint path is relative to the
# current working directory.
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False
)
async def generate(file: UploadFile = File(...)):
    """Caption an uploaded image.

    The upload is written to a private temporary file (``generate_caption``
    takes a filesystem path). That file is captioned and then removed, even
    if captioning raises.

    Args:
        file: The uploaded image.

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the image.
    """
    import tempfile

    # Never build the path from the client-supplied filename: a name such as
    # "../../etc/x" would escape the temp directory, and it may be None.
    # Keep only the extension, in case the image loader relies on it.
    suffix = Path(file.filename or "").suffix
    fd, temp_file = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        return caption_generator.generate_caption(temp_file)
    finally:
        # Delete the upload so the temp directory does not grow without bound.
        os.remove(temp_file)
async def cbow_form(request: Request):
    """Render the CBOW expression form with no expression or results yet."""
    context = {"request": request}
    return templates.TemplateResponse("cbow.html", context)
async def cbow(request: Request, expression: str = Form(...)):
    """Evaluate a submitted word expression and render the results.

    Args:
        request: The incoming request, passed through to the template context.
        expression: The expression from the form; it is lower-cased before
            evaluation and echoed back in lower case.

    Returns:
        The rendered ``cbow.html`` template with ``expression`` and
        ``results`` in its context.
    """
    expression = expression.lower()
    # Reuse the module-level calculator so MeaningCalculator() construction
    # happens once at import, not on every request.
    results = calculator.evaluate_expression(expression=expression)
    return templates.TemplateResponse("cbow.html", {
        "request": request,
        "expression": expression,
        "results": results,
    })
async def contentdistillery_page():
    """Return the HTML of ``contentdistillery.html`` (working-directory relative), decoded as UTF-8."""
    page_path = Path("contentdistillery.html")
    html = page_path.read_text(encoding="utf-8")
    return html
async def generate_summary_from_post(post: str = Form(...)):
    """Summarize the submitted post text and return the summary unchanged."""
    summary = generate_summary(post)
    return summary
# Supported ``source`` values -> (subreddit name, listing method on the
# subreddit object). Only "top" and "hot" are used.
_SAMPLE_FEEDS = {
    "reddit_romance": ("relationships", "top"),
    "reddit_aita": ("AmItheAsshole", "hot"),
    "reddit_careers": ("careerguidance", "hot"),
    "reddit_cars": ("cars", "hot"),
    "reddit_whatcarshouldibuy": ("whatcarshouldibuy", "top"),
    "reddit_nosleep": ("nosleep", "top"),
    "reddit_maliciouscompliance": ("maliciouscompliance", "hot"),
    "reddit_talesfromtechsupport": ("talesfromtechsupport", "top"),
    "reddit_decidingtobebetter": ("decidingtobebetter", "hot"),
    "reddit_askphilosophy": ("askphilosophy", "top"),
}


def get_sample(source: str = Query(...)):
    """Return the text of a random self-post from a supported subreddit.

    Args:
        source: One of the keys of ``_SAMPLE_FEEDS``.

    Returns:
        The stripped ``selftext`` of one post chosen at random from the first
        10 entries of the subreddit's listing. Otherwise it returns one of
        these strings: "Unsupported source.", "No suitable post found.", or
        "Error fetching Reddit post: <error>". Errors are returned as text
        rather than raised, so the caller always receives a string.
    """
    try:
        feed = _SAMPLE_FEEDS.get(source)
        if feed is None:
            return "Unsupported source."
        subreddit_name, listing = feed
        submissions = getattr(reddit.subreddit(subreddit_name), listing)(limit=10)
        # Strip each post once and drop empty / whitespace-only ones.
        texts = (s.selftext.strip() for s in submissions)
        posts = [text for text in texts if text]
        if posts:
            return random.choice(posts)
        return "No suitable post found."
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"