joeee2321512 committed on
Commit
f814852
·
verified ·
1 Parent(s): f6d3976

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +106 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
+ from peft import PeftModel
6
+ import gc
7
+ import os
8
+
9
# Force synchronous CUDA kernel launches so that a device-side error is
# reported at the call that caused it. This is a debugging aid and it
# slows inference down.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# --- Configuration ---
base_model_id = "joeee2321512/Qwen2.5-VL-3B-Instruct-finetuned"
adapter_id = "joeee2321512/Basira"
# Hub access token, read once from the `token_HF` environment variable
# (None when the variable is unset).
hf_token = os.getenv("token_HF")

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(base_model_id, token=hf_token)
processor.tokenizer.padding_side = "right"

print("Loading and applying adapter...")
# Use the same token as for the base model so that every Hub download in
# this script is made with the same credentials.
model = PeftModel.from_pretrained(model, adapter_id, token=hf_token)
print("Model loaded successfully!")
35
+
36
+ # --- The Inference Function ---
37
def perform_ocr_on_image(image_input: Image.Image) -> str:
    """Transcribe the Arabic text in an image with the fine-tuned model.

    This is the callback that Gradio invokes for each submission.

    Args:
        image_input: The uploaded picture as a PIL image, or ``None`` when
            nothing was uploaded.

    Returns:
        The text generated by the model with surrounding whitespace
        removed; ``"Please upload an image."`` when ``image_input`` is
        ``None``; or an ``"An error occurred: ..."`` message when inference
        raised an exception.
    """
    if image_input is None:
        return "Please upload an image."

    try:
        # Format the prompt using the chat template
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {"type": "text", "text": (
                        "Analyze the input image and detect all Arabic text. "
                        "Output only the extracted text—verbatim and in its original script—"
                        "without any added commentary, translation, punctuation or formatting. "
                        "Present each line of text as plain UTF-8 strings, with no extra characters or words."
                    )},
                ],
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Prepare inputs for the model
        inputs = processor(text=text, images=image_input, return_tensors="pt").to(model.device)

        # Generate prediction
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)

        # The generated sequence starts with the prompt tokens. Drop them by
        # position before decoding instead of splitting the decoded text on
        # the word "assistant", which would cut the result short whenever
        # the transcription itself contains that word.
        prompt_length = inputs["input_ids"].shape[1]
        new_token_ids = generated_ids[:, prompt_length:]
        decoded = processor.batch_decode(new_token_ids, skip_special_tokens=True)
        return decoded[0].strip()

    except Exception as e:
        # Boundary handler: report the failure in the UI instead of crashing
        # the request.
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"

    finally:
        # Clean up memory on both the success and the failure path.
        gc.collect()
        torch.cuda.empty_cache()
94
+
95
+ # --- Create and Launch the Gradio Interface ---
96
# Single-image-in, text-out web UI around the OCR callback.
demo = gr.Interface(
    fn=perform_ocr_on_image,
    inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
    outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
    allow_flagging="never",
)

# Start the web server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+ torch
3
+ transformers
4
+ peft
5
+ accelerate
6
+ bitsandbytes
7
+ Pillow
8
+ gradio
9
+ sentencepiece
10
+ qwen-vl-utils
11
+ torchvision