Spaces:
Build error
Build error
| # Based on liuhaotian/LLaVA-1.6 | |
| import sys | |
| import os | |
| import argparse | |
| import time | |
| import subprocess | |
| import gradio as gr | |
| import llava.serve.gradio_web_server as gws | |
| # Execute the pip install command with additional options | |
| subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'wheel', 'setuptools']) | |
| subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U']) | |
| def start_controller(): | |
| print("Starting the controller") | |
| controller_command = [ | |
| sys.executable, | |
| "-m", | |
| "llava.serve.controller", | |
| "--host", | |
| "0.0.0.0", | |
| "--port", | |
| "10000", | |
| ] | |
| print("Controller Command:", controller_command) | |
| return subprocess.Popen(controller_command) | |
| def start_worker(model_path: str, bits=16): | |
| print(f"Starting the model worker for the model {model_path}") | |
| model_name = model_path.strip("/").split("/")[-1] | |
| assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit." | |
| if bits != 16: | |
| model_name += f"-{bits}bit" | |
| model_name += "-lora" | |
| worker_command = [ | |
| sys.executable, | |
| "-m", | |
| "llava.serve.model_worker", | |
| "--host", | |
| "0.0.0.0", | |
| "--controller", | |
| "http://localhost:10000", | |
| "--model-path", | |
| model_path, | |
| "--model-name", | |
| model_name, | |
| "--model-base", | |
| "liuhaotian/llava-1.5-7b", | |
| "--use-flash-attn", | |
| ] | |
| print("Worker Command:", worker_command) | |
| return subprocess.Popen(worker_command) | |
| def handle_text_prompt(text, temperature=0.2, top_p=0.7, max_new_tokens=512): | |
| """ | |
| Custom API endpoint to handle text prompts. | |
| Replace the placeholder logic with actual model inference. | |
| """ | |
| # TODO: Replace the following placeholder with actual model inference code | |
| print(f"Received prompt: {text}") | |
| print(f"Parameters - Temperature: {temperature}, Top P: {top_p}, Max New Tokens: {max_new_tokens}") | |
| # Example response (replace with actual model response) | |
| response = f"Model response to '{text}' with temperature={temperature}, top_p={top_p}, max_new_tokens={max_new_tokens}" | |
| return response | |
| def add_text_with_image(text, image, mode): | |
| """ | |
| Custom API endpoint to add text with an image. | |
| Replace the placeholder logic with actual processing. | |
| """ | |
| # TODO: Replace the following placeholder with actual processing code | |
| print(f"Adding text: {text}") | |
| print(f"Image path: {image}") | |
| print(f"Image processing mode: {mode}") | |
| # Example response (replace with actual processing code) | |
| response = f"Added text '{text}' with image at '{image}' using mode '{mode}'." | |
| return response | |
| def build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=5): | |
| """ | |
| Builds a Gradio Blocks interface with custom API endpoints. | |
| """ | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# AstroLLaVA") | |
| gr.Markdown("Welcome to the AstroLLaVA interface. Use the API endpoints to interact with the model.") | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Prompt the Model") | |
| text_input = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...") | |
| temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature") | |
| top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Top P") | |
| max_tokens_slider = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max New Tokens") | |
| submit_button = gr.Button("Submit Prompt") | |
| with gr.Column(): | |
| chatbot_output = gr.Textbox(label="Model Response", interactive=False) | |
| submit_button.click( | |
| fn=handle_text_prompt, | |
| inputs=[text_input, temperature_slider, top_p_slider, max_tokens_slider], | |
| outputs=chatbot_output, | |
| api_name="prompt_model" # Custom API endpoint name | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Add Text with Image") | |
| add_text_input = gr.Textbox(label="Add Text", placeholder="Enter text to add...") | |
| add_image_input = gr.Image(label="Upload Image") | |
| image_process_mode = gr.Radio(choices=["Crop", "Resize", "Pad", "Default"], value="Default", label="Image Process Mode") | |
| add_submit_button = gr.Button("Add Text with Image") | |
| with gr.Column(): | |
| add_output = gr.Textbox(label="Add Text Response", interactive=False) | |
| add_submit_button.click( | |
| fn=add_text_with_image, | |
| inputs=[add_text_input, add_image_input, image_process_mode], | |
| outputs=add_output, | |
| api_name="add_text_with_image" # Another custom API endpoint | |
| ) | |
| # Additional API endpoints can be added here following the same structure | |
| return demo | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="AstroLLaVA Gradio App") | |
| parser.add_argument("--host", type=str, default="0.0.0.0", help="Hostname to listen on") | |
| parser.add_argument("--port", type=int, default=7860, help="Port number") | |
| parser.add_argument("--controller-url", type=str, default="http://localhost:10000", help="Controller URL") | |
| parser.add_argument("--concurrency-count", type=int, default=5, help="Number of concurrent requests") | |
| parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"], help="Model list mode") | |
| parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly") | |
| parser.add_argument("--moderate", action="store_true", help="Enable moderation") | |
| parser.add_argument("--embed", action="store_true", help="Enable embed mode") | |
| args = parser.parse_args() | |
| gws.args = args | |
| gws.models = [] | |
| gws.title_markdown += """ AstroLLaVA """ | |
| print(f"AstroLLaVA arguments: {gws.args}") | |
| model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2") | |
| bits = int(os.getenv("bits", 4)) | |
| concurrency_count = int(os.getenv("concurrency_count", 5)) | |
| controller_proc = start_controller() | |
| worker_proc = start_worker(model_path, bits=bits) | |
| # Wait for worker and controller to start | |
| print("Waiting for worker and controller to start...") | |
| time.sleep(30) | |
| exit_status = 0 | |
| try: | |
| # Build the custom Gradio demo with additional API endpoints | |
| demo = build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count) | |
| print("Launching Gradio with custom API endpoints...") | |
| demo.queue( | |
| status_update_rate=10, | |
| api_open=False | |
| ).launch( | |
| server_name=args.host, | |
| server_port=args.port, | |
| share=args.share | |
| ) | |
| except Exception as e: | |
| print(f"An error occurred: {e}") | |
| exit_status = 1 | |
| finally: | |
| worker_proc.kill() | |
| controller_proc.kill() | |
| sys.exit(exit_status) | |