Commit 719ecfd
alessandro trinca tornidor committed
Parent(s): 995f2bf
[refactor] fix missing parse_args() functions, fix inference()
Files changed:
- app.py +54 -36
- utils/constants.py +47 -0
app.py
CHANGED
@@ -1,32 +1,20 @@
 import argparse
-import cv2
-import gradio as gr
 import json
 import logging
-import nh3
-import numpy as np
 import os
-import re
 import sys
-import
-
-
-
+from typing import Callable
+
+import gradio as gr
+import nh3
+from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
-from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
-from typing import Callable
 
-from
-from model.llava import conversation as conversation_lib
-from model.llava.mm_utils import tokenizer_image_token
-from model.segment_anything.utils.transforms import ResizeLongestSide
-from utils import session_logger
-from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+from utils import constants, session_logger
 
-session_logger.change_logging(logging.DEBUG)
+session_logger.change_logging(logging.DEBUG)
 
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")

@@ -48,6 +36,37 @@ def health() -> str:
     return json.dumps({"msg": "request failed"})
 
 
+@session_logger.set_uuid_logging
+def parse_args(args_to_parse):
+    parser = argparse.ArgumentParser(description="LISA chat")
+    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
+    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        type=str,
+        choices=["fp32", "bf16", "fp16"],
+        help="precision for inference",
+    )
+    parser.add_argument("--image_size", default=1024, type=int, help="image size")
+    parser.add_argument("--model_max_length", default=512, type=int)
+    parser.add_argument("--lora_r", default=8, type=int)
+    parser.add_argument(
+        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
+    )
+    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
+    parser.add_argument("--load_in_8bit", action="store_true", default=False)
+    parser.add_argument("--load_in_4bit", action="store_true", default=False)
+    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
+    parser.add_argument(
+        "--conv_type",
+        default="llava_v1",
+        type=str,
+        choices=["llava_v1", "llava_llama_2"],
+    )
+    return parser.parse_args(args_to_parse)
+
+
 @session_logger.set_uuid_logging
 def get_cleaned_input(input_str):
     logging.info(f"start cleaning of input_str: {input_str}.")
@@ -85,12 +104,11 @@ def get_inference_model_by_args(args_to_parse):
 
     @session_logger.set_uuid_logging
     def inference(input_str, input_image):
-
-
-
-
-        logging.info(f"
-
+        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
+        output_str = get_cleaned_input(input_str)
+        logging.info(f"cleaned output_str: {output_str}, type {type(output_str)}.")
+        output_image = input_image
+        logging.info(f"output_image type: {type(output_image)}.")
         return output_image, output_str
 
     return inference

@@ -100,20 +118,20 @@ def get_inference_model_by_args(args_to_parse):
 def get_gradio_interface(fn_inference: Callable):
     return gr.Interface(
         fn_inference,
-
+        inputs=[
             gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
             gr.Image(type="filepath", label="Input Image")
-
-
+        ],
+        outputs=[
             gr.Image(type="pil", label="Segmentation Output"),
-
-
-        title=title,
-        description=description,
-        article=article,
-        examples=examples,
-        allow_flagging="auto"
-    )
+            gr.Textbox(lines=1, placeholder=None, label="Text Output")
+        ],
+        title=constants.title,
+        description=constants.description,
+        article=constants.article,
+        examples=constants.examples,
+        allow_flagging="auto"
+    )
 
 
 args = parse_args(sys.argv[1:])
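
Together, the two hunks above rebuild inference() as a closure returned by get_inference_model_by_args() and hand it to gr.Interface through explicit inputs=/outputs= keyword lists. A hedged sketch of how these pieces typically wire into the FastAPI app declared at the top of the file; the mount_gradio_app call is an assumption suggested by the CUSTOM_GRADIO_PATH constant, not a line shown in this diff:

    # Sketch, not the repo's code: wiring the functions from this diff together.
    import sys
    import gradio as gr
    from fastapi import FastAPI

    app = FastAPI(title="lisa_app", version="1.0")
    args = parse_args(sys.argv[1:])                   # module-level call shown above
    fn_inference = get_inference_model_by_args(args)  # returns the inner inference()
    io = get_gradio_interface(fn_inference)
    # Assumed follow-up step: serve the Gradio UI under CUSTOM_GRADIO_PATH ("/").
    app = gr.mount_gradio_app(app, io, path="/")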
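get_cleaned_input() appears in this diff only as context; its body is not part of the commit. Given the nh3 import that survives the refactor, a sanitizer in that spirit might look like the following (purely illustrative; the tags argument and the function body are assumptions, not the repo's code):

    import logging
    import nh3

    def get_cleaned_input(input_str: str) -> str:
        # Strip all HTML tags from the user prompt, keeping only the text.
        logging.info(f"start cleaning of input_str: {input_str}.")
        return nh3.clean(input_str, tags=set())
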
utils/constants.py
ADDED
@@ -0,0 +1,47 @@
+# Gradio
+examples = [
+    [
+        "Where can the driver see the car speed in this image? Please output segmentation mask.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "Can you segment the food that tastes spicy and hot?",
+        "./resources/imgs/example2.jpg",
+    ],
+    [
+        "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
+        "./resources/imgs/example1.jpg",
+    ],
+    [
+        "What can make the woman stand higher? Please output segmentation mask and explain why.",
+        "./resources/imgs/example3.jpg",
+    ],
+]
+output_labels = ["Segmentation Output"]
+
+title = "LISA: Reasoning Segmentation via Large Language Model"
+
+description = """
+<font size=4>
+This is the online demo of LISA. \n
+If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n
+**Note**: **Different prompts can lead to significantly varied results**. \n
+**Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
+**Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n
+**Usage**: <br>
+ (1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask."; <br>
+ (2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why."; <br>
+ (3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA). <br>
+Hope you can enjoy our work!
+</font>
+"""
+
+article = """
+<p style='text-align: center'>
+<a href='https://arxiv.org/abs/2308.00692' target='_blank'>
+Preprint Paper
+</a>
+\n
+<p style='text-align: center'>
+<a href='https://github.com/dvlab-research/LISA' target='_blank'> Github Repo </a></p>
+"""
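
Since app.py now reads all of its UI copy from utils.constants, the interface can be reproduced standalone. A self-contained sketch (the echo function is a stand-in for the real model, and the import path assumes you run from the repo root):

    # Standalone sketch: feeding utils/constants.py into a gr.Interface.
    import gradio as gr
    from utils import constants

    def echo(text, image):
        # Placeholder for the real inference(): returns the inputs unchanged.
        return image, text

    demo = gr.Interface(
        echo,
        inputs=[
            gr.Textbox(lines=1, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image"),
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, label="Text Output"),
        ],
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,  # each entry is [prompt, image path]
    )

    if __name__ == "__main__":
        demo.launch()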