File size: 11,671 Bytes
83c150d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30b9f33
 
83c150d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30b9f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83c150d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30b9f33
 
 
 
 
83c150d
 
 
30b9f33
 
83c150d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# ============================================================================
# CONTENTFORGE AI - FASTAPI BACKEND
# REST API for multi-modal AI platform
# ============================================================================

from fastapi import FastAPI, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import torch
import os
from huggingface_hub import login
import base64
from io import BytesIO
import numpy as np
import wave
import struct

# ============================================================================
# AUTHENTICATION
# ============================================================================

# Log in to the HuggingFace Hub when a token is supplied through the HF_TOKEN
# environment variable. When the variable is unset or empty the login step is
# skipped entirely.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("πŸ” Authenticating with HuggingFace...")
    login(token=HF_TOKEN)
    print("βœ… Authenticated!\n")

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Qwen2VLForConditionalGeneration, Qwen2VLProcessor,
    AutoProcessor, MusicgenForConditionalGeneration
)
from peft import PeftModel
from qwen_vl_utils import process_vision_info
from diffusers import StableDiffusionPipeline
from PIL import Image

# ============================================================================
# FASTAPI APP SETUP
# ============================================================================

# Application object; title, description and version are the metadata handed
# to FastAPI here.
app = FastAPI(
    title="ContentForge AI API",
    description="Multi-modal AI API for education and social media content generation",
    version="1.0.0"
)

# CORS - Allow requests from your frontend
# NOTE(review): every origin, method and header is allowed while
# allow_credentials=True. Restrict allow_origins before production (see the
# inline example below) and confirm how the installed Starlette version
# treats "*" origins combined with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production: ["https://yourwebsite.vercel.app"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Simple API key authentication (improve this for production!)
# Maps each accepted API key to a user label; verify_api_key() returns that
# label. NOTE(review): the keys are hard-coded in source - move them to
# configuration/secrets before exposing this service.
API_KEYS = {
    "demo_key_123": "Demo User",
    "sk_test_456": "Test User",
}

def verify_api_key(x_api_key: str = Header(None)):
    """Return the user label registered for ``x_api_key``.

    A key that is missing (``None``) or not present in ``API_KEYS`` is
    rejected with an HTTP 401 ``HTTPException``.
    """
    if x_api_key in API_KEYS:
        return API_KEYS[x_api_key]
    raise HTTPException(status_code=401, detail="Invalid API Key")

# ============================================================================
# LOAD MODELS
# ============================================================================

# All models below are loaded once, at import time, and kept as module-level
# globals that the endpoint handlers use directly.
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸ–₯️ Using device: {device}")
print("πŸ“¦ Loading models...\n")

# 1. T5 Model
# Tokenizer and weights come from the same fine-tuned checkpoint; the model is
# moved to `device` and put in evaluation mode. Used by /summarize.
print("πŸ“ Loading T5...")
t5_tokenizer = T5Tokenizer.from_pretrained("Bashaarat1/t5-small-arxiv-summarizer")
t5_model = T5ForConditionalGeneration.from_pretrained(
    "Bashaarat1/t5-small-arxiv-summarizer"
).to(device)
t5_model.eval()
print("βœ… T5 loaded!")

# 2. Qwen VLM
# The base Qwen2-VL checkpoint is loaded in bfloat16 with device_map="auto",
# then wrapped with the PEFT adapter from Bashaarat1/qwen-finetuned-scienceqa.
# The processor is taken from the base checkpoint. Used by /qa.
# NOTE(review): /qa moves its inputs to `device`; confirm that this matches
# where device_map="auto" actually places the model.
print("πŸ€– Loading Qwen...")
qwen_base = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
qwen_model = PeftModel.from_pretrained(
    qwen_base,
    "Bashaarat1/qwen-finetuned-scienceqa"
)
qwen_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
qwen_model.eval()
print("βœ… Qwen loaded!")

# 3. MusicGen
# Processor and weights come from the same fine-tuned checkpoint; the model is
# moved to `device` and put in evaluation mode. Used by /generate-music.
print("🎡 Loading MusicGen...")
music_processor = AutoProcessor.from_pretrained("Bashaarat1/fine-tuned-musicgen-small")
music_model = MusicgenForConditionalGeneration.from_pretrained(
    "Bashaarat1/fine-tuned-musicgen-small"
).to(device)
music_model.eval()
print("βœ… MusicGen loaded!")

# 4. Stable Diffusion
# Weights are loaded as float16 on CUDA and float32 on CPU. Used by
# /generate-image.
# NOTE(review): safety_checker=None builds the pipeline without its safety
# checker - confirm this is intended for a publicly reachable API.
print("🎨 Loading Stable Diffusion...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None
).to(device)
print("βœ… Stable Diffusion loaded!")

print("\nπŸŽ‰ All models loaded! API ready.\n")

# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================

class SummarizeRequest(BaseModel):
    """Request body for POST /summarize."""
    # Text to summarize; the endpoint rejects it when it is blank after strip().
    text: str
    # Forwarded to t5_model.generate() as max_length.
    max_length: int = 128

class SummarizeResponse(BaseModel):
    """Response body for POST /summarize."""
    # Decoded model output with special tokens removed.
    summary: str
    # Whitespace-delimited word count of the submitted text.
    original_words: int
    # Whitespace-delimited word count of the summary.
    summary_words: int

class QARequest(BaseModel):
    """Request body for POST /qa."""
    # Question text; the endpoint rejects it when it is blank after strip().
    question: str
    # Optional base64-encoded image; when present the endpoint decodes it and
    # converts it to RGB before passing it to the model.
    image_base64: Optional[str] = None

class QAResponse(BaseModel):
    """Response body for POST /qa."""
    # Generated answer text, stripped of surrounding whitespace.
    answer: str

class ImageRequest(BaseModel):
    """Request body for POST /generate-image."""
    # Text prompt; the endpoint rejects it when it is blank after strip().
    prompt: str
    # Forwarded to the pipeline as negative_prompt.
    negative_prompt: str = ""
    # Forwarded to the pipeline as num_inference_steps.
    num_steps: int = 25

class ImageResponse(BaseModel):
    """Response body for POST /generate-image."""
    # Base64 text of the generated image saved in PNG format.
    image_base64: str

class MusicRequest(BaseModel):
    """Request body for POST /generate-music."""
    # Text prompt; the endpoint rejects it when it is blank after strip().
    prompt: str
    # The endpoint multiplies this by 50 to obtain max_new_tokens.
    # Presumably seconds of audio - confirm against the model's frame rate.
    duration: int = 10

class MusicResponse(BaseModel):
    """Response body for POST /generate-music."""
    # Base64 text of the WAV bytes produced by numpy_to_wav().
    audio_base64: str
    # Taken from music_model.config.audio_encoder.sampling_rate.
    sampling_rate: int
    # The endpoint always sets this to "wav".
    format: str

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def numpy_to_wav(audio_data: np.ndarray, sampling_rate: int) -> bytes:
    """Encode audio samples as a mono 16-bit PCM WAV file held in memory.

    Args:
        audio_data: Samples as a NumPy array or array-like. Values are
            expected in [-1.0, 1.0]; anything outside that range is clipped
            (not rescaled), and NaN samples are written as silence.
        sampling_rate: Frame rate in Hz written to the WAV header.

    Returns:
        The complete WAV file (header plus frames) as bytes.
    """
    samples = np.asarray(audio_data)
    if not np.issubdtype(samples.dtype, np.floating):
        # Integer/bool input: promote so the scaling below is done in float.
        samples = samples.astype(np.float64)

    # NaN has no defined int16 representation, so casting it would produce a
    # platform-dependent sample (and a RuntimeWarning). Map it to silence and
    # pin +/-inf to full scale before clipping.
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    samples = np.clip(samples, -1, 1)

    # Scale to 16-bit PCM. Native byte order is intentional: the wave module
    # byte-swaps frames itself on big-endian hosts.
    audio_int16 = (samples * 32767).astype(np.int16)

    # Create WAV file in memory
    wav_io = BytesIO()
    with wave.open(wav_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sampling_rate)
        wav_file.writeframes(audio_int16.tobytes())

    return wav_io.getvalue()

# ============================================================================
# API ENDPOINTS
# ============================================================================

@app.get("/")
def root():
    """Health check: report service status, version and available routes."""
    available_routes = [
        "/summarize",
        "/qa",
        "/generate-image",
        "/generate-music",
    ]
    payload = {"status": "online"}
    payload["message"] = "ContentForge AI API"
    payload["version"] = "1.0.0"
    payload["endpoints"] = available_routes
    return payload

@app.post("/summarize", response_model=SummarizeResponse)
def summarize(
    request: SummarizeRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Summarize text using the fine-tuned T5 model.

    Args:
        request: Text to summarize plus ``max_length``, which is forwarded to
            ``generate`` as the maximum output length.
        user: Raw value of the ``x-api-key`` header; it is validated with
            ``verify_api_key`` before any work is done.

    Returns:
        SummarizeResponse holding the summary and the whitespace-delimited
        word counts of the input text and of the summary.

    Raises:
        HTTPException: 401 for an unknown API key, 400 for blank text or a
            non-positive ``max_length``, 500 if tokenization or generation
            fails.
    """
    verify_api_key(user)

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if request.max_length < 1:
        # A non-positive limit is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="max_length must be a positive integer"
        )

    # The default floor of 30 tokens must never exceed the caller's ceiling,
    # otherwise generate() is handed contradictory length constraints.
    min_length = min(30, request.max_length)

    try:
        # Input is truncated to 512 tokens, so very long texts are only
        # partially seen by the model.
        inputs = t5_tokenizer(
            f"summarize: {request.text}",
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(device)

        with torch.no_grad():
            outputs = t5_model.generate(
                **inputs,
                max_length=request.max_length,
                min_length=min_length,
                num_beams=4,
                early_stopping=True
            )

        summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

        return SummarizeResponse(
            summary=summary,
            original_words=len(request.text.split()),
            summary_words=len(summary.split())
        )

    except Exception as e:
        # Boundary handler: surface any model/tokenizer failure as HTTP 500
        # while keeping the original exception chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/qa", response_model=QAResponse)
def question_answer(
    request: QARequest,
    user: str = Header(None, alias="x-api-key")
):
    """Answer a question, optionally about an image, using the Qwen VLM.

    Args:
        request: The question and an optional base64-encoded image.
        user: Raw value of the ``x-api-key`` header; it is validated with
            ``verify_api_key`` before any work is done.

    Returns:
        QAResponse holding the generated answer, stripped of whitespace.

    Raises:
        HTTPException: 401 for an unknown API key, 400 for a blank question
            or an image that cannot be decoded, 500 if prompt preparation or
            generation fails.
    """
    verify_api_key(user)

    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        image = None
        if request.image_base64:
            try:
                # Decode base64 image
                image_data = base64.b64decode(request.image_base64)
                image = Image.open(BytesIO(image_data)).convert('RGB')
            except (ValueError, OSError) as e:
                # Malformed base64 (ValueError) or bytes that are not a
                # readable image (OSError) are the caller's fault: 400.
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid image_base64: {e}"
                ) from e

        # Prepare messages: the image entry, when present, precedes the text.
        content = [{"type": "text", "text": request.question}]
        if image is not None:
            content.insert(0, {"type": "image", "image": image})
        messages = [{"role": "user", "content": content}]

        text_prompt = qwen_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Only pass `images` to the processor when an image was supplied.
        processor_kwargs = {"text": [text_prompt], "return_tensors": "pt"}
        if image is not None:
            img_inputs, _ = process_vision_info(messages)
            processor_kwargs["images"] = img_inputs
        inputs = qwen_processor(**processor_kwargs).to(device)

        with torch.no_grad():
            outputs = qwen_model.generate(**inputs, max_new_tokens=200)

        # generate() returns prompt + completion; slice off the prompt tokens
        # so only the newly generated answer is decoded.
        answer = qwen_processor.batch_decode(
            outputs[:, inputs.input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()

        return QAResponse(answer=answer)

    except HTTPException:
        # Let the 400 raised for a bad image through unchanged.
        raise
    except Exception as e:
        # Boundary handler: surface any other failure as HTTP 500 while
        # keeping the original exception chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/generate-image", response_model=ImageResponse)
def generate_image(
    request: ImageRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate an image from a text prompt using Stable Diffusion.

    Args:
        request: Prompt, optional negative prompt and the number of
            inference steps forwarded to the pipeline.
        user: Raw value of the ``x-api-key`` header; it is validated with
            ``verify_api_key`` before any work is done.

    Returns:
        ImageResponse holding the generated image as base64-encoded PNG.

    Raises:
        HTTPException: 401 for an unknown API key, 400 for a blank prompt or
            a non-positive ``num_steps``, 500 if generation or encoding
            fails.
    """
    verify_api_key(user)

    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if request.num_steps < 1:
        # A non-positive step count is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="num_steps must be a positive integer"
        )

    try:
        with torch.no_grad():
            image = sd_pipe(
                request.prompt,
                negative_prompt=request.negative_prompt,
                num_inference_steps=request.num_steps,
                guidance_scale=7.5
            ).images[0]

        # Convert image to base64
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return ImageResponse(image_base64=img_str)

    except Exception as e:
        # Boundary handler: surface any pipeline failure as HTTP 500 while
        # keeping the original exception chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/generate-music", response_model=MusicResponse)
def generate_music(
    request: MusicRequest,
    user: str = Header(None, alias="x-api-key")
):
    """Generate music from a text prompt using MusicGen.

    Args:
        request: Prompt and ``duration``; the duration is multiplied by 50
            to obtain the number of tokens to generate.
        user: Raw value of the ``x-api-key`` header; it is validated with
            ``verify_api_key`` before any work is done.

    Returns:
        MusicResponse holding base64-encoded WAV audio, the sampling rate
        read from the model config, and the format string ``"wav"``.

    Raises:
        HTTPException: 401 for an unknown API key, 400 for a blank prompt or
            a non-positive ``duration``, 500 if generation or encoding
            fails.
    """
    verify_api_key(user)

    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if request.duration < 1:
        # Would yield max_new_tokens <= 0; reject as a client error.
        raise HTTPException(
            status_code=400,
            detail="duration must be a positive integer"
        )

    try:
        inputs = music_processor(
            text=[request.prompt],
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Hard-coded 50 tokens per unit of duration.
        # NOTE(review): presumably the codec frame rate (tokens per second) -
        # confirm against music_model.config.audio_encoder.
        max_tokens = int(request.duration * 50)

        with torch.no_grad():
            audio_values = music_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True
            )

        sampling_rate = music_model.config.audio_encoder.sampling_rate
        # Take the first item and its first channel from the generated batch.
        audio_data = audio_values[0, 0].cpu().numpy()

        # Convert to WAV format
        wav_bytes = numpy_to_wav(audio_data, sampling_rate)

        # Encode to base64
        audio_str = base64.b64encode(wav_bytes).decode()

        return MusicResponse(
            audio_base64=audio_str,
            sampling_rate=sampling_rate,
            format="wav"
        )

    except Exception as e:
        # Boundary handler: surface any model failure as HTTP 500 while
        # keeping the original exception chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

# ============================================================================
# RUN SERVER
# ============================================================================

if __name__ == "__main__":
    # Started directly as a script: serve the app with uvicorn on every
    # network interface, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)