Bashaarat1 committed
Commit 83c150d · verified · 1 Parent(s): 2e66473

Create app.py

Files changed (1)
  1. app.py +358 -0
app.py ADDED
@@ -0,0 +1,358 @@
# ============================================================================
# CONTENTFORGE AI - FASTAPI BACKEND
# REST API for multi-modal AI platform
# ============================================================================

from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import torch
import os
from huggingface_hub import login
import base64
from io import BytesIO
import numpy as np

# ============================================================================
# AUTHENTICATION
# ============================================================================

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("🔐 Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("✅ Authenticated!\n")

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
    AutoProcessor, MusicgenForConditionalGeneration
)
from peft import PeftModel
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionPipeline
from PIL import Image

# ============================================================================
# FASTAPI APP SETUP
# ============================================================================

app = FastAPI(
    title="ContentForge AI API",
    description="Multi-modal AI API for education and social media content generation",
    version="1.0.0"
)

# CORS - Allow requests from your frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production: ["https://yourwebsite.vercel.app"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Simple API key authentication (improve this for production!)
API_KEYS = {
    "demo_key_123": "Demo User",
    "sk_test_456": "Test User",
}

def verify_api_key(x_api_key: str = Header(None)):
    """Verify the API key supplied in the x-api-key header."""
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return API_KEYS[x_api_key]
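
# NOTE (added sketch, not wired into the routes below): FastAPI can also
# enforce the key declaratively through its security utilities, instead of
# calling verify_api_key() by hand in every endpoint. This assumes the same
# API_KEYS dict above; swap `Security(api_key_header)` into a route's
# signature to use it.
from fastapi import Security
from fastapi.security import APIKeyHeader

api_key_header = APIKeyHeader(name="x-api-key", auto_error=False)

def require_api_key(key: str = Security(api_key_header)) -> str:
    """Dependency-style variant of verify_api_key()."""
    if key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return API_KEYS[key]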

# ============================================================================
# LOAD MODELS
# ============================================================================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
print("📦 Loading models...\n")

# 1. T5 Model
print("📝 Loading T5...")
t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer")
t5_model = T5ForConditionalGeneration.from_pretrained(
    "Bashaarat1/t5-small-arxiv-summarizer"
).to(device)
t5_model.eval()
print("✅ T5 loaded!")

# 2. Qwen VLM
print("🤖 Loading Qwen...")
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
qwen_model = PeftModel.from_pretrained(
    qwen_base,
    "Bashaarat1/qwen-finetuned-scienceqa"
)
qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
qwen_model.eval()
print("✅ Qwen loaded!")

# 3. MusicGen
print("🎵 Loading MusicGen...")
music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
    "Bashaarat1/fine-tuned-musicgen-small"
).to(device)
music_model.eval()
print("✅ MusicGen loaded!")

# 4. Stable Diffusion
print("🎨 Loading Stable Diffusion...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None
).to(device)
print("✅ Stable Diffusion loaded!")

print("\n🎉 All models loaded! API ready.\n")
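
# Optional (added note): on smaller GPUs, diffusers' attention slicing cuts
# peak VRAM during image generation at a modest speed cost:
# sd_pipe.enable_attention_slicing()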

# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================

class SummarizeRequest(BaseModel):
    text: str
    max_length: int = 128

class SummarizeResponse(BaseModel):
    summary: str
    original_words: int
    summary_words: int

class QARequest(BaseModel):
    question: str
    image_base64: Optional[str] = None

class QAResponse(BaseModel):
    answer: str

class ImageRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    num_steps: int = 25

class ImageResponse(BaseModel):
    image_base64: str

class MusicRequest(BaseModel):
    prompt: str
    duration: int = 10

class MusicResponse(BaseModel):
    audio_base64: str
    sampling_rate: int
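
# Example request bodies (added for reference; values are illustrative):
#   POST /summarize       {"text": "Transformer models ...", "max_length": 128}
#   POST /qa              {"question": "What does the figure show?", "image_base64": "<optional>"}
#   POST /generate-image  {"prompt": "a watercolor fox", "num_steps": 25}
#   POST /generate-music  {"prompt": "calm lo-fi beat", "duration": 10}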

# ============================================================================
# API ENDPOINTS
# ============================================================================

@app.get("/")
def root():
    """API health check"""
    return {
        "status": "online",
        "message": "ContentForge AI API",
        "version": "1.0.0",
        "endpoints": [
            "/summarize",
            "/qa",
            "/generate-image",
            "/generate-music"
        ]
    }

@app.post("/summarize", response_model=SummarizeResponse)
def summarize(
    request: SummarizeRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Summarize text using fine-tuned T5"""
    verify_api_key(user)

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        inputs = t5_tokenizer(
            f"summarize: {request.text}",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = t5_model.generate(
                **inputs,
                max_length=request.max_length,
                min_length=30,
                num_beams=4,
                early_stopping=True
            )

        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

        return SummarizeResponse(
            summary=summary,
            original_words=len(request.text.split()),
            summary_words=len(summary.split())
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
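
# Example call (added sketch; assumes the server is reachable on port 7860):
#   curl -X POST http://localhost:7860/summarize \
#     -H "x-api-key: demo_key_123" -H "Content-Type: application/json" \
#     -d '{"text": "Large language models ...", "max_length": 128}'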

@app.post("/qa", response_model=QAResponse)
def question_answer(
    request: QARequest,
    user: str = Header(None, alias="x-api-key")
):
    """Answer questions with optional image using Qwen VLM"""
    verify_api_key(user)

    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        image = None
        if request.image_base64:
            # Decode base64 image
            image_data = base64.b64decode(request.image_base64)
            image = Image.open(BytesIO(image_data)).convert('RGB')

        # Prepare messages
        if image is not None:
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": request.question}
                ]
            }]
        else:
            messages = [{
                "role": "user",
                "content": [{"type": "text", "text": request.question}]
            }]

        text_prompt = qwen_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        if image is not None:
            img_inputs, _ = process_vision_info(messages)
            inputs = qwen_processor(
                text=[text_prompt],
                images=img_inputs,
                return_tensors="pt"
            ).to(device)
        else:
            inputs = qwen_processor(
                text=[text_prompt],
                return_tensors="pt"
            ).to(device)

        with torch.no_grad():
            outputs = qwen_model.generate(**inputs, max_new_tokens=200)

        # Decode only the newly generated tokens (everything past the prompt)
        answer = qwen_processor.batch_decode(
            outputs[:, inputs.input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()

        return QAResponse(answer=answer)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
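
# Client-side image encoding (added sketch; "diagram.png" is a placeholder):
#   import base64
#   with open("diagram.png", "rb") as f:
#       image_base64 = base64.b64encode(f.read()).decode()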

@app.post("/generate-image", response_model=ImageResponse)
def generate_image(
    request: ImageRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate image using Stable Diffusion"""
    verify_api_key(user)

    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")

    try:
        with torch.no_grad():
            image = sd_pipe(
                request.prompt,
                negative_prompt=request.negative_prompt,
                num_inference_steps=request.num_steps,
                guidance_scale=7.5
            ).images[0]

        # Convert image to base64
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return ImageResponse(image_base64=img_str)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
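
# Client-side decoding (added sketch): the response carries a PNG as base64.
#   import base64
#   with open("out.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["image_base64"]))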

@app.post("/generate-music", response_model=MusicResponse)
def generate_music(
    request: MusicRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate music using MusicGen"""
    verify_api_key(user)

    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")

    try:
        inputs = music_processor(
            text=[request.prompt],
            padding=True,
            return_tensors="pt"
        ).to(device)

        # MusicGen emits roughly 50 audio tokens per second of audio
        max_tokens = int(request.duration * 50)

        with torch.no_grad():
            audio_values = music_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True
            )

        sampling_rate = music_model.config.audio_encoder.sampling_rate
        audio_data = audio_values[0, 0].cpu().numpy()

        # Convert audio to base64 (serialized as a .npy buffer via np.save)
        audio_bytes = BytesIO()
        np.save(audio_bytes, audio_data)
        audio_str = base64.b64encode(audio_bytes.getvalue()).decode()

        return MusicResponse(
            audio_base64=audio_str,
            sampling_rate=sampling_rate
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
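
# Client-side decoding (added sketch): the audio travels as a base64-encoded
# .npy buffer, so base64-decode it and load with np.load:
#   import base64, io
#   import numpy as np
#   raw = base64.b64decode(resp.json()["audio_base64"])
#   audio = np.load(io.BytesIO(raw))  # float waveform
#   # write it out at resp.json()["sampling_rate"], e.g. with the soundfile package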

# ============================================================================
# RUN SERVER
# ============================================================================

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
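
# Run locally with `python app.py`, or equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860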