| | from typing import Literal, Optional, TYPE_CHECKING |
| |
|
| | import numpy as np |
| | from fastapi import FastAPI |
| | from fastapi.responses import StreamingResponse, JSONResponse |
| | from pydantic import BaseModel, Field |
| | from platform import system |
| |
|
| | if TYPE_CHECKING: |
| | from flux_pipeline import FluxPipeline |
| |
|
# Exclusive upper bound for randomly generated seeds: 16-bit on Windows,
# 32-bit on every other platform.
# NOTE(review): presumably the smaller Windows bound accommodates a narrower
# default integer width in NumPy there — confirm before changing.
MAX_RAND = 2**16 - 1 if system() == "Windows" else 2**32 - 1
| |
|
| |
|
class AppState:
    """Typed view of the application state shared by the route handlers.

    Only declares the attribute for static type checkers; nothing is
    assigned here. The handlers in this module read the pipeline through
    `app.state.model`.
    """

    # Forward reference: FluxPipeline is imported only under TYPE_CHECKING.
    model: "FluxPipeline"
| |
|
| |
|
class FastAPIApp(FastAPI):
    """FastAPI application whose `state` attribute is typed as `AppState`.

    Adds no behavior; the annotation only lets type checkers resolve
    `app.state.model` to the Flux pipeline.
    """

    state: AppState
| |
|
| |
|
# Request body for the `/lora` endpoint.
#   scale  - scaling factor for the LoRA weights (defaults to 1.0)
#   path   - path to the LoRA checkpoint
#   name   - name of the LoRA checkpoint
#   action - "load" (default) or "unload"
class LoraArgs(BaseModel):
    scale: Optional[float] = Field(default=1.0)
    path: Optional[str] = Field(default=None)
    name: Optional[str] = Field(default=None)
    action: Optional[Literal["load", "unload"]] = Field(default="load")
| |
|
| |
|
# Response schema declared for the `/lora` endpoint.
#   status  - "success" or "error"
#   message - optional detail text; the handler fills it in on errors
class LoraLoadResponse(BaseModel):
    status: Literal["success", "error"]
    message: Optional[str] = Field(default=None)
| |
|
| |
|
def _random_seed() -> int:
    """Return a random seed that satisfies the `seed` field's own bounds.

    The value is drawn from [1, MAX_RAND) so that a generated default is
    never 0 (the field declares `gt=0`, `lt=MAX_RAND`) and can be sent back
    by a client without failing validation. The NumPy scalar is converted
    to a plain `int` so `model_dump()` yields a built-in type.
    """
    return int(np.random.randint(1, MAX_RAND))


# Request body for the `/generate` endpoint.
#   prompt     - text prompt used for image generation (required)
#   width      - image width, default 720
#   height     - image height, default 1024
#   num_steps  - number of generation steps, default 24
#   guidance   - influence of the prompt on the generation, default 3.5
#   seed       - generation seed, 0 < seed < MAX_RAND; random when omitted
#   strength   - 0.0 - 1.0, percent of diffusion steps to run when an
#                init image is supplied, default 1.0
#   init_image - base64 encoded image or path to an image
class GenerateArgs(BaseModel):
    prompt: str
    width: Optional[int] = Field(default=720)
    height: Optional[int] = Field(default=1024)
    num_steps: Optional[int] = Field(default=24)
    guidance: Optional[float] = Field(default=3.5)
    seed: Optional[int] = Field(default_factory=_random_seed, gt=0, lt=MAX_RAND)
    strength: Optional[float] = 1.0
    init_image: Optional[str] = None
| |
|
| |
|
app = FastAPIApp()


@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Run the Flux flow transformer and stream back the resulting image.

    Args:
        args (GenerateArgs): Image generation settings:

            - `prompt`: Text prompt that drives the generation.

            - `width`: Width of the output image.

            - `height`: Height of the output image.

            - `num_steps`: How many generation steps to run.

            - `guidance`: How strongly the prompt influences the
                generated image.

            - `seed`: Seed used for the generation.

            - `strength`: Value in 0.0 - 1.0 giving the percent of
                diffusion steps to run; the init_image is set as the
                noised latent at the corresponding step.

            - `init_image`: Init image, given as a base64 encoded image
                or as a path to an image.

    Returns:
        StreamingResponse: The generated image, streamed as jpeg bytes.
    """
    # Every validated field is forwarded to the pipeline as a keyword argument.
    generation_kwargs = args.model_dump()
    image_stream = app.state.model.generate(**generation_kwargs)
    return StreamingResponse(image_stream, media_type="image/jpeg")
| |
|
| |
|
@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.

    Args:
        args (LoraArgs): Arguments for the LoRA action:

            - `scale`: The scaling factor for the LoRA weights.
            - `path`: The path to the LoRA checkpoint. Required for "load".
            - `name`: The name of the LoRA checkpoint.
            - `action`: The action to perform, either "load" or "unload".

    Returns:
        JSONResponse: `{"status": "success"}` with HTTP 200 when the action
        completes; `{"status": "error", "message": ...}` with HTTP 400 for an
        invalid request, or HTTP 500 when the pipeline raises.
    """
    # `action` is Optional, so an explicit null reaches this handler; reject
    # anything that is not a known action before touching the model.
    if args.action not in ("load", "unload"):
        return JSONResponse(
            content={
                "status": "error",
                "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
            },
            status_code=400,
        )

    # A load request without a checkpoint path is a client error; report it
    # as such instead of forwarding None to the pipeline and surfacing
    # whatever it raises as a 500.
    if args.action == "load" and not args.path:
        return JSONResponse(
            content={
                "status": "error",
                "message": "A LoRA 'path' is required when action is 'load'",
            },
            status_code=400,
        )

    try:
        if args.action == "load":
            app.state.model.load_lora(args.path, args.scale, args.name)
        else:
            # Prefer the name as identifier and fall back to the path.
            # NOTE(review): when neither is given, None is forwarded to
            # unload_lora as before — confirm the pipeline handles that.
            app.state.model.unload_lora(args.name or args.path)
    except Exception as e:
        # Endpoint boundary: turn any pipeline failure into a JSON error body.
        return JSONResponse(
            status_code=500, content={"status": "error", "message": str(e)}
        )
    return JSONResponse(status_code=200, content={"status": "success"})
| |
|