File size: 5,853 Bytes
33cfa2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import json
import re
import base64
import aiohttp # Async test. Need to install
import asyncio
# --- Configuration ---
# Base URL of the backend; override with the GEMINI_FLOW2API_URL env variable.
BASE_URL = os.getenv('GEMINI_FLOW2API_URL', 'http://127.0.0.1:8000')
# Chat-completions endpoint. Trailing slashes on BASE_URL are stripped so a
# value such as "http://host:8000/" does not produce "//v1/chat/completions".
BACKEND_URL = BASE_URL.rstrip('/') + "/v1/chat/completions"
# Sent verbatim as the Authorization header, so the value must include the
# scheme prefix (the default does). Override with GEMINI_FLOW2API_APIKEY.
API_KEY = os.getenv('GEMINI_FLOW2API_APIKEY', 'Bearer han1234')
# os.getenv() with a default never returns None, so an "is None" test can
# never fire; reject an empty or blank value instead.
if not API_KEY.strip():
    raise ValueError('[gemini flow2api] api key not set')
# Model names sent in the request's "model" field.
MODEL_LANDSCAPE = "gemini-3.0-pro-image-landscape"
MODEL_PORTRAIT = "gemini-3.0-pro-image-portrait"
# Matches a Markdown image link and captures its URL: ![alt](url).
_IMAGE_LINK_RE = re.compile(r'!\[.*?\]\((.*?)\)')

# Total time budget, in seconds, for the streaming generation request.
_REQUEST_TIMEOUT_SECONDS = 120


def _build_content_payload(prompt: str, images: list[bytes]) -> str | list[dict]:
    """Return the message content: the bare prompt, or text plus inline images."""
    if not images:
        return prompt
    print(f"[Backend] 正在处理 {len(images)} 张图片输入...")
    parts = [{"type": "text", "text": prompt}]
    for img_bytes in images:
        b64_str = base64.b64encode(img_bytes).decode('utf-8')
        # Every image is labelled image/jpeg regardless of its real format.
        parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{b64_str}"}
        })
    return parts


def _first_delta(chunk) -> dict:
    """Return ``choices[0].delta`` of a decoded chunk, or {} if it is absent."""
    if not isinstance(chunk, dict):
        return {}
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices:
        return {}
    first = choices[0]
    if not isinstance(first, dict):
        return {}
    delta = first.get("delta")
    return delta if isinstance(delta, dict) else {}


def _extract_error_message(line_str: str) -> str:
    """Turn a bare ``{"error...`` stream line into a readable error message."""
    try:
        chunk = json.loads(line_str)
    except json.JSONDecodeError:
        chunk = None
    msg = _first_delta(chunk).get("reasoning_content")
    if not msg and isinstance(chunk, dict):
        # NOTE(review): assumes an {"error": {"message": ...}} or
        # {"error": "..."} shape; confirm against the backend.
        error = chunk.get("error")
        msg = error.get("message") if isinstance(error, dict) else error
    # Fall back to the raw line so the backend's answer is never lost.
    msg = str(msg) if msg else line_str
    if '401' in msg:
        msg += '\nAccess Token 已失效,需重新配置。'
    elif '400' in msg:
        msg += '\n返回内容被拦截。'
    return msg


async def request_backend_generation(
        prompt: str,
        images: list[bytes] | None = None,
        model: str | None = None) -> bytes | None:
    """
    Ask the backend to generate an image and return the downloaded bytes.

    The request is sent as a streaming chat completion. Any
    ``reasoning_content`` in the stream is echoed to stdout, and the last
    Markdown image link found in a ``content`` delta is downloaded.

    :param prompt: text prompt
    :param images: optional input images as raw bytes, sent inline as
        base64 data URLs
    :param model: model name; MODEL_LANDSCAPE is used when omitted or empty
    :return: image bytes on success; None if the stream contained no image
        link or the image download did not return HTTP 200
    :raises RuntimeError: on a non-200 response from the backend or when the
        stream carries an error payload
    :raises Exception: any other error (network, timeout, ...) is logged and
        re-raised unchanged
    """
    images = images or []
    # Default to the landscape model when the caller does not pick one.
    use_model = model or MODEL_LANDSCAPE

    # 1. Build the request.
    payload = {
        "model": use_model,
        "messages": [{"role": "user", "content": _build_content_payload(prompt, images)}],
        "stream": True
    }
    headers = {
        "Authorization": API_KEY,
        "Content-Type": "application/json"
    }
    image_url = None
    print(f"[Backend] Model: {use_model} | 发起请求: {prompt[:20]}...")
    # aiohttp expects a ClientTimeout object; a bare number is deprecated.
    timeout = aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT_SECONDS)
    try:
        async with aiohttp.ClientSession() as session:
            # 2. Send the request and consume the event stream line by line.
            async with session.post(BACKEND_URL, json=payload, headers=headers,
                                    timeout=timeout) as response:
                if response.status != 200:
                    err_text = await response.text()
                    print(f"[Backend Error] Status {response.status}: {err_text}")
                    raise RuntimeError(f"API Error: {response.status}: {err_text}")
                async for line in response.content:
                    line_str = line.decode('utf-8').strip()
                    # An error arrives as a bare JSON line without the
                    # "data: " prefix, so parse the line itself.
                    if line_str.startswith('{"error'):
                        raise RuntimeError(_extract_error_message(line_str))
                    if not line_str.startswith('data: '):
                        continue
                    data_str = line_str[6:]
                    if data_str == '[DONE]':
                        break
                    try:
                        chunk = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    delta = _first_delta(chunk)
                    # Echo the model's reasoning as it streams in.
                    reasoning = delta.get("reasoning_content")
                    if reasoning:
                        print(reasoning, end="", flush=True)
                    # Look for a Markdown image link in the content.
                    content_text = delta.get("content")
                    if isinstance(content_text, str):
                        img_match = _IMAGE_LINK_RE.search(content_text)
                        if img_match:
                            image_url = img_match.group(1)
                            print(f"\n[Backend] 捕获图片链接: {image_url}")
            # 3. Download the generated image.
            if image_url:
                async with session.get(image_url) as img_resp:
                    if img_resp.status == 200:
                        return await img_resp.read()
                    print(f"[Backend Error] 图片下载失败: {img_resp.status}")
    except Exception as e:
        print(f"[Backend Exception] {e}")
        raise
    return None
if __name__ == '__main__':
    async def main():
        """Interactive smoke test: read a prompt, generate an image, save it."""
        print("=== AI 绘图接口测试 ===")
        # Fall back to a built-in prompt when the user just presses Enter.
        user_prompt = input("请输入提示词 (例如 '一只猫'): ").strip() or "A cute cat in the garden"
        print(f"正在请求: {user_prompt}")

        # No images are passed, so this exercises text-to-image only.
        # To try image-to-image, read a local file and pass its bytes:
        #   with open("output_test.jpg", "rb") as f: img_data = f.read()
        #   image_data = await request_backend_generation(user_prompt, [img_data])
        image_data = await request_backend_generation(user_prompt)
        if not image_data:
            print("\n[Failed] 生成失败")
            return

        out_path = "output_test.jpg"
        with open(out_path, "wb") as out_file:
            out_file.write(image_data)
        print(f"\n[Success] 图片已保存为 {out_path},大小: {len(image_data)} bytes")

    # Run the test; on Windows, switch to the selector event loop policy
    # first for compatibility.
    running_on_windows = os.name == 'nt'
    if running_on_windows:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(main())