davanstrien (HF Staff) committed
Commit bbe7feb · verified · 1 Parent(s): 251c25d

Update app.py

Files changed (1): app.py (+93 -84)
app.py CHANGED
@@ -1,105 +1,103 @@
 import gradio as gr
-from PIL import Image as PILImage
+from PIL import Image
 import os
+import torch
 import json
 import spaces
-from typing import Optional
-from pydantic import BaseModel, Field
-import outlines
-from outlines.inputs import Chat, Image
-from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from qwen_vl_utils import process_vision_info
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
-
-# Define the metadata schema
-class CatalogCardMetadata(BaseModel):
-    """Structured metadata from a library catalog card."""
-
-    title: Optional[str] = Field(None, description="The main title or name on the card")
-    author: Optional[str] = Field(
-        None, description="Author, creator, or associated person/organization"
-    )
-    date: Optional[str] = Field(
-        None,
-        description="Any dates mentioned (publication, creation, or coverage dates)",
-    )
-    call_number: Optional[str] = Field(
-        None, description="Library classification or call number"
-    )
-    physical_description: Optional[str] = Field(
-        None, description="Details about the physical item (size, extent, format)"
-    )
-    subjects: Optional[list[str]] = Field(
-        None, description="Subject headings or topics"
-    )
-    notes: Optional[str] = Field(
-        None, description="Any additional notes or information"
-    )
-
-
-# Load model and processor with Outlines
-print("Loading Qwen3-VL-30B-A3B-Instruct model with Outlines...")
-hf_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen3-VL-30B-A3B-Instruct", torch_dtype="auto", device_map="auto"
+# Load model and processor
+print("Loading Qwen3-VL-30B-A3B-Instruct model...")
+model = AutoModelForImageTextToText.from_pretrained(
+    "Qwen/Qwen3-VL-30B-A3B-Instruct",
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
 )
-hf_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
-model = outlines.from_transformers(hf_model, hf_processor)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
 print("Model loaded successfully!")
 
-EXTRACTION_PROMPT = """Extract all metadata from this library catalog card. Include title, author, dates, call number, physical description, subjects, and notes. If a field is not present, omit it."""
+EXTRACTION_PROMPT = """Extract all metadata from this library catalog card and return it as valid JSON with the following fields:
+- title: The main title or name on the card
+- author: Author, creator, or associated person/organization
+- date: Any dates mentioned (publication, creation, or coverage dates)
+- call_number: Library classification or call number
+- physical_description: Details about the physical item (size, extent, format)
+- subjects: Subject headings or topics
+- notes: Any additional notes or information
+
+Return ONLY the JSON object, nothing else. If a field is not present on the card, use null for that field."""
 
 @spaces.GPU
 def extract_metadata(image):
-    """Extract structured metadata from catalog card image using Outlines."""
+    """Extract structured metadata from catalog card image."""
     if image is None:
         return "Please upload an image."
 
     try:
         # Ensure image is PIL Image
-        print(f"DEBUG: Received image type: {type(image)}")
-        if not isinstance(image, PILImage.Image):
-            image = PILImage.open(image).convert("RGB")
-        print(f"DEBUG: After conversion, image type: {type(image)}")
-        print(f"DEBUG: Image format before setting: {image.format}")
-
-        # Set format (required by Outlines Image class)
-        if not image.format:
-            image.format = "PNG"
-        print(f"DEBUG: Image format after setting: {image.format}")
-
-        # Wrap in Outlines Image
-        outlines_image = Image(image)
-        print(f"DEBUG: Outlines Image created: {type(outlines_image)}")
-        print(f"DEBUG: Outlines Image.image type: {type(outlines_image.image)}")
-
-        # Create Chat prompt with Image (using simpler list format)
-        prompt = Chat(
-            messages=[
-                {
-                    "role": "user",
-                    "content": [EXTRACTION_PROMPT, outlines_image],
-                }
-            ]
+        if not isinstance(image, Image.Image):
+            image = Image.open(image).convert("RGB")
+
+        # Format messages for Qwen3-VL
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": EXTRACTION_PROMPT}
+                ]
+            }
+        ]
+
+        # Prepare inputs
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
         )
-        print(f"DEBUG: Chat prompt created successfully")
-
-        # Generate with structured output - guaranteed valid JSON
-        print(f"DEBUG: Starting generation...")
-        result = model(prompt, CatalogCardMetadata, max_new_tokens=512)
-        print(f"DEBUG: Generation complete, result type: {type(result)}")
-
-        # Parse and format (always valid JSON with Outlines)
-        metadata = CatalogCardMetadata.model_validate_json(result)
-        return json.dumps(metadata.model_dump(exclude_none=True), indent=2)
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt"
+        )
+        inputs = inputs.to(model.device)
+
+        # Generate
+        with torch.inference_mode():
+            generated_ids = model.generate(
+                **inputs,
+                max_new_tokens=512,
+                temperature=0.1,
+                do_sample=False
+            )
+
+        # Trim input tokens from output
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+
+        # Decode output
+        output_text = processor.batch_decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )[0]
+
+        # Try to parse as JSON for pretty formatting
+        try:
+            json_data = json.loads(output_text)
+            return json.dumps(json_data, indent=2)
+        except json.JSONDecodeError:
+            # If not valid JSON, return as-is
+            return output_text
 
     except Exception as e:
-        import traceback
-        error_msg = f"Error during extraction: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
-        print(error_msg)
-        return error_msg
-
+        return f"Error during extraction: {str(e)}"
 
 # Create Gradio interface
 with gr.Blocks(title="Library Card Metadata Extractor") as demo:
@@ -118,14 +116,25 @@ with gr.Blocks(title="Library Card Metadata Extractor") as demo:
     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("### 📤 Upload Catalog Card")
-            image_input = gr.Image(label="Library Catalog Card", type="pil")
+            image_input = gr.Image(
+                label="Library Catalog Card",
+                type="pil"
+            )
             submit_btn = gr.Button("🔍 Extract Metadata", variant="primary", size="lg")
 
         with gr.Column(scale=1):
            gr.Markdown("### 📋 Extracted Metadata (JSON)")
-            output = gr.Code(label="Metadata", language="json", lines=15)
+            output = gr.Code(
+                label="Metadata",
+                language="json",
+                lines=15
+            )
 
-    submit_btn.click(fn=extract_metadata, inputs=image_input, outputs=output)
+    submit_btn.click(
+        fn=extract_metadata,
+        inputs=image_input,
+        outputs=output
+    )
 
     gr.Markdown("---")
 
@@ -143,7 +152,7 @@ with gr.Blocks(title="Library Card Metadata Extractor") as demo:
         inputs=image_input,
         outputs=output,
         fn=extract_metadata,
-        cache_examples=False,
+        cache_examples=False
     )
 
     gr.Markdown("---")
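
Note on the new parsing step: because the updated app.py relies on prompting rather than constrained decoding, json.loads in extract_metadata will fall back to raw text whenever the model wraps its reply in Markdown code fences. Below is a minimal sketch of a more tolerant parsing step, assuming the reply may be fenced; the parse_model_json helper is hypothetical and not part of this commit.

import json
import re

def parse_model_json(output_text: str) -> str:
    """Hypothetical helper: strip Markdown code fences before parsing the model's JSON reply."""
    cleaned = output_text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)  # drop a leading ``` or ```json fence
    cleaned = re.sub(r"\s*```$", "", cleaned)  # drop a trailing ``` fence
    try:
        return json.dumps(json.loads(cleaned), indent=2)
    except json.JSONDecodeError:
        return output_text  # fall back to the raw text, as app.py already does

Swapping this in for the inner try/except around json.loads(output_text) would keep the pretty-printed output while tolerating fenced replies.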