# ============================================================================ # CONTENTFORGE AI - FASTAPI BACKEND # REST API for multi-modal AI platform # ============================================================================ from fastapi import FastAPI, HTTPException, Header, File, UploadFile from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional import torch import os from huggingface_hub import login import base64 from io import BytesIO import numpy as np import wave import struct # ============================================================================ # AUTHENTICATION # ============================================================================ HF_TOKEN = os.environ.get("HF_TOKEN") if HF_TOKEN: print("šŸ” Authenticating with HuggingFace...") login(token=HF_TOKEN) print("āœ… Authenticated!\n") from transformers import ( T5Tokenizer, T5ForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor, AutoProcessor, MusicgenForConditionalGeneration ) from peft import PeftModel from qwen_vl_utils import process_vision_info from diffusers import StableDiffusionPipeline from PIL import Image # ============================================================================ # FASTAPI APP SETUP # ============================================================================ app = FastAPI( title="ContentForge AI API", description="Multi-modal AI API for education and social media content generation", version="1.0.0" ) # CORS - Allow requests from your frontend app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production: ["https://yourwebsite.vercel.app"] allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Simple API key authentication (improve this for production!) API_KEYS = { "demo_key_123": "Demo User", "sk_test_456": "Test User", } def verify_api_key(x_api_key: str = Header(None)): """Verify API key from header""" if x_api_key not in API_KEYS: raise HTTPException(status_code=401, detail="Invalid API Key") return API_KEYS[x_api_key] # ============================================================================ # LOAD MODELS # ============================================================================ device = "cuda" if torch.cuda.is_available() else "cpu" print(f"šŸ–„ļø Using device: {device}") print("šŸ“¦ Loading models...\n") # 1. T5 Model print("šŸ“ Loading T5...") t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer") t5_model = T5ForConditionalGeneration.from_pretrained( "Bashaarat1/t5-small-arxiv-summarizer" ).to(device) t5_model.eval() print("āœ… T5 loaded!") # 2. Qwen VLM print("šŸ¤– Loading Qwen...") qwen_base = Qwen2VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2-VL-2B-Instruct", device_map="auto", torch_dtype=torch.bfloat16 ) qwen_model = PeftModel.from_pretrained( qwen_base, "Bashaarat1/qwen-finetuned-scienceqa" ) qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") qwen_model.eval() print("āœ… Qwen loaded!") # 3. MusicGen print("šŸŽµ Loading MusicGen...") music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small") music_model = MusicgenForConditionalGeneration.from_pretrained( "Bashaarat1/fine-tuned-musicgen-small" ).to(device) music_model.eval() print("āœ… MusicGen loaded!") # 4. Stable Diffusion print("šŸŽØ Loading Stable Diffusion...") sd_pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32, safety_checker=None ).to(device) print("āœ… Stable Diffusion loaded!") print("\nšŸŽ‰ All models loaded! API ready.\n") # ============================================================================ # REQUEST/RESPONSE MODELS # ============================================================================ class SummarizeRequest(BaseModel): text: str max_length: int = 128 class SummarizeResponse(BaseModel): summary: str original_words: int summary_words: int class QARequest(BaseModel): question: str image_base64: Optional[str] = None class QAResponse(BaseModel): answer: str class ImageRequest(BaseModel): prompt: str negative_prompt: str = "" num_steps: int = 25 class ImageResponse(BaseModel): image_base64: str class MusicRequest(BaseModel): prompt: str duration: int = 10 class MusicResponse(BaseModel): audio_base64: str sampling_rate: int format: str # ============================================================================ # HELPER FUNCTIONS # ============================================================================ def numpy_to_wav(audio_data: np.ndarray, sampling_rate: int) -> bytes: """Convert numpy array to WAV format bytes""" # Normalize audio to -1 to 1 range audio_data = np.clip(audio_data, -1, 1) # Convert to 16-bit PCM audio_int16 = (audio_data * 32767).astype(np.int16) # Create WAV file in memory wav_io = BytesIO() with wave.open(wav_io, 'wb') as wav_file: wav_file.setnchannels(1) # Mono wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(sampling_rate) wav_file.writeframes(audio_int16.tobytes()) return wav_io.getvalue() # ============================================================================ # API ENDPOINTS # ============================================================================ @app.get("/") def root(): """API health check""" return { "status": "online", "message": "ContentForge AI API", "version": "1.0.0", "endpoints": [ "/summarize", "/qa", "/generate-image", "/generate-music" ] } @app.post("/summarize", response_model=SummarizeResponse) def summarize( request: SummarizeRequest, user: str = Header(None, alias="x-api-key") ): """Summarize text using fine-tuned T5""" verify_api_key(user) if not request.text.strip(): raise HTTPException(status_code=400, detail="Text cannot be empty") try: inputs = t5_tokenizer( f"summarize: {request.text}", return_tensors="pt", max_length=512, truncation=True ).to(device) with torch.no_grad(): outputs = t5_model.generate( **inputs, max_length=request.max_length, min_length=30, num_beams=4, early_stopping=True ) summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True) return SummarizeResponse( summary=summary, original_words=len(request.text.split()), summary_words=len(summary.split()) ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/qa", response_model=QAResponse) def question_answer( request: QARequest, user: str = Header(None, alias="x-api-key") ): """Answer questions with optional image using Qwen VLM""" verify_api_key(user) if not request.question.strip(): raise HTTPException(status_code=400, detail="Question cannot be empty") try: image = None if request.image_base64: # Decode base64 image image_data = base64.b64decode(request.image_base64) image = Image.open(BytesIO(image_data)).convert('RGB') # Prepare messages if image is not None: messages = [{ "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": request.question} ] }] else: messages = [{ "role": "user", "content": [{"type": "text", "text": request.question}] }] text_prompt = qwen_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) if image is not None: img_inputs, _ = process_vision_info(messages) inputs = qwen_processor( text=[text_prompt], images=img_inputs, return_tensors="pt" ).to(device) else: inputs = qwen_processor( text=[text_prompt], return_tensors="pt" ).to(device) with torch.no_grad(): outputs = qwen_model.generate(**inputs, max_new_tokens=200) answer = qwen_processor.batch_decode( outputs[:, inputs.input_ids.size(1):], skip_special_tokens=True )[0].strip() return QAResponse(answer=answer) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/generate-image", response_model=ImageResponse) def generate_image( request: ImageRequest, user: str = Header(None, alias="x-api-key") ): """Generate image using Stable Diffusion""" verify_api_key(user) if not request.prompt.strip(): raise HTTPException(status_code=400, detail="Prompt cannot be empty") try: with torch.no_grad(): image = sd_pipe( request.prompt, negative_prompt=request.negative_prompt, num_inference_steps=request.num_steps, guidance_scale=7.5 ).images[0] # Convert image to base64 buffered = BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() return ImageResponse(image_base64=img_str) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/generate-music", response_model=MusicResponse) def generate_music( request: MusicRequest, user: str = Header(None, alias="x-api-key") ): """Generate music using MusicGen""" verify_api_key(user) if not request.prompt.strip(): raise HTTPException(status_code=400, detail="Prompt cannot be empty") try: inputs = music_processor( text=[request.prompt], padding=True, return_tensors="pt" ).to(device) max_tokens = int(request.duration * 50) with torch.no_grad(): audio_values = music_model.generate( **inputs, max_new_tokens=max_tokens, do_sample=True ) sampling_rate = music_model.config.audio_encoder.sampling_rate audio_data = audio_values[0, 0].cpu().numpy() # Convert to WAV format wav_bytes = numpy_to_wav(audio_data, sampling_rate) # Encode to base64 audio_str = base64.b64encode(wav_bytes).decode() return MusicResponse( audio_base64=audio_str, sampling_rate=sampling_rate, format="wav" ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # ============================================================================ # RUN SERVER # ============================================================================ if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)