TangYiJay committed (verified)
Commit ff494c0 · Parent(s): c2dffd2
Files changed (1):
  1. app.py +20 -16
app.py CHANGED
@@ -1,13 +1,18 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForVision2Seq
-import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
 from PIL import Image
+import torch
 
-# Load model and processor
 MODEL_ID = "liuhaotian/llava-v1.6-vicuna-7b"
-processor = AutoProcessor.from_pretrained(MODEL_ID)
-model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
+# Load model and processor (use correct classes)
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = LlavaForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
@@ -23,37 +28,36 @@ def detect_object(image, prompt):
         return "⚠️ Please upload a base image first."
 
     query = (
-        f"Ignore the base image and focus only on new or added objects. "
-        f"Base image and detection image are given. {prompt or 'Describe and identify the materials of new objects.'}"
+        f"Ignore the base image and only analyze the differences. "
+        f"{prompt or 'Detect new objects and identify their material type.'}"
     )
 
     inputs = processor(
         text=query,
         images=[base_image, image],
         return_tensors="pt"
-    ).to(device, torch.float16)
+    ).to(device, torch.float16 if torch.cuda.is_available() else torch.float32)
 
     output = model.generate(**inputs, max_new_tokens=256)
     result = processor.decode(output[0], skip_special_tokens=True)
     return result
 
-# Build Gradio UI
-with gr.Blocks(title="LLaVA Object & Material Detector") as demo:
-    gr.Markdown("## 🧠 LLaVA-1.6 (Vicuna-7B) Object & Material Detection\nUpload a base image first, then upload another image to detect new objects while ignoring the base.")
+with gr.Blocks(title="LLaVA Object Detector") as demo:
+    gr.Markdown("## 🧠 LLaVA 1.6 Vicuna-7B — Visual Detection & Material Identification")
 
     with gr.Row():
         with gr.Column():
             base_img = gr.Image(label="Base Image", type="pil")
             set_base_btn = gr.Button("Set as Base Image")
-            set_base_status = gr.Textbox(label="Status")
+            base_status = gr.Textbox(label="Status")
 
         with gr.Column():
             target_img = gr.Image(label="Detection Image", type="pil")
-            user_prompt = gr.Textbox(label="Custom Instruction", placeholder="e.g. Detect added objects and describe their material.")
+            prompt = gr.Textbox(label="Instruction", placeholder="Detect new objects and describe material")
             run_btn = gr.Button("Run Detection")
-            output_box = gr.Textbox(label="Model Output")
+            output_box = gr.Textbox(label="Output")
 
-    set_base_btn.click(set_base, inputs=base_img, outputs=set_base_status)
-    run_btn.click(detect_object, inputs=[target_img, user_prompt], outputs=output_box)
+    set_base_btn.click(set_base, inputs=base_img, outputs=base_status)
+    run_btn.click(detect_object, inputs=[target_img, prompt], outputs=output_box)
 
 demo.launch()
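Both hunks reference scaffolding that sits outside the changed ranges: a module-level `base_image`, the `set_base` callback wired to `set_base_btn`, and the top of `detect_object`. Since the commit doesn't show any of it, here is a minimal sketch of what it presumably looks like; the names come from the diff, but the bodies are assumptions:

```python
# Sketch of the unchanged scaffolding the hunks reference (assumed, not part of this commit).
base_image = None  # module-level slot filled by the "Set as Base Image" button


def set_base(image):
    """Store the uploaded PIL image as the reference; wired to set_base_btn.click."""
    global base_image
    if image is None:
        return "⚠️ No image provided."
    base_image = image
    return "✅ Base image set."


def detect_object(image, prompt):
    """Compare `image` against base_image; the body continues as in the hunk above."""
    if image is None:
        return "⚠️ Please upload a detection image."
    if base_image is None:
        return "⚠️ Please upload a base image first."
    ...
```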
 
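One caveat the commit's "use correct classes" comment doesn't fully resolve: `liuhaotian/llava-v1.6-vicuna-7b` is the original-repo checkpoint and, as far as I can tell, ships no `transformers`-format config, so `LlavaForConditionalGeneration.from_pretrained` may still fail even with `trust_remote_code=True`. The `transformers`-native LLaVA-1.6 weights live under the `llava-hf` organization, load with the `LlavaNext*` classes, and expect one `<image>` placeholder per image in the prompt. A minimal sketch assuming the `llava-hf/llava-v1.6-vicuna-7b-hf` checkpoint (note that LLaVA-1.6 was trained on single images, so passing two at once is itself best-effort):

```python
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Assumption: the HF-converted checkpoint, not the original liuhaotian repo.
MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = LlavaNextProcessor.from_pretrained(MODEL_ID)
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=dtype, low_cpu_mem_usage=True
).to(device)

base_image = Image.open("base.jpg")   # placeholder paths for illustration
image = Image.open("detect.jpg")

# Vicuna-style prompt with one <image> token per image passed to the processor.
query = (
    "USER: <image>\n<image>\n"
    "Ignore the base image and only analyze the differences. ASSISTANT:"
)
inputs = processor(text=query, images=[base_image, image], return_tensors="pt").to(device, dtype)
output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0], skip_special_tokens=True))
```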