Commit 52932d2 · Parent(s): 070106a
added maskformer based object extraction

Files changed:
- extract_tools.py +16 -30
- utils.py +5 -1
extract_tools.py CHANGED

@@ -124,39 +124,25 @@ def generate_bounding_box_tool(input_data:str)->str:
     object_data = yolo_world_model.run_yolo_infer(image_path,object_prompts)
     return object_data

+
 @tool
-def object_extraction(
+def object_extraction(image_path:str)->str:
     "Use this tool to identify the objects within the image"
-
-
-
-
-
-
-    try:
-        processor = BlipProcessor.from_pretrained(hf_model)
-        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
-    except:
-        logging.error("unable to load the Blip model ")
-
-    logging.info("Image Caption model loaded ! ")
-
-    # unconditional image captioning
-    inputs = processor(image, return_tensors ='pt').to(device)
-    output = caption_model.generate(**inputs, max_new_tokens=50)
-    llm = get_groq_model()
-    getobject_chain = create_object_extraction_chain(llm=llm)
-
-    extracted_objects = getobject_chain.invoke({
-        'context': processor.decode(output[0], skip_special_tokens=True)
-    }).objects
-
-    print("Extracted objects : ",extracted_objects)
-    ## clear the GPU cache
+    objects = []
+    maskformer_model.to(device)
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
+    inputs = maskformer_processor(image, return_tensors="pt")
+    inputs.to(device)
     with torch.no_grad():
-
-
-
+        outputs = maskformer_model(**inputs)
+    prediction = maskformer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.shape[:2]])[0]
+    segments_info = prediction['segments_info']
+    for segment in segments_info:
+        segment_label_id = segment['label_id']
+        segment_label = maskformer_model.config.id2label[segment_label_id]
+        objects.append(segment_label)
+    return "Detected objects are: "+ " ".join( objects)

 @tool
 def get_image_quality(image_path:str)->str:
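The new tool relies on three module-level objects, maskformer_processor, maskformer_model and device, whose initialization is not part of this hunk. A minimal sketch of how that setup could look, and of how the resulting tool might be exercised, is given below; the checkpoint name is borrowed from the utils.py hunk further down, while the device selection and the .invoke call (which assumes the @tool decorator is LangChain's) are assumptions rather than code from this commit.

# Sketch only: assumed module-level setup for the globals referenced by object_extraction.
# The checkpoint name is taken from the utils.py hunk below; device handling is an assumption.
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"
maskformer_processor = AutoImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic"
)
maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic"
)

# Assuming the @tool decorator comes from LangChain, the tool can be tried out directly:
# print(object_extraction.invoke("sample.jpg"))   # e.g. "Detected objects are: person dog car"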
utils.py CHANGED

@@ -50,4 +50,8 @@ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
         t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
         c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
         cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 0, 255], thickness=tf, lineType=cv2.LINE_AA)
-    return rgb_frame_copy
+    return rgb_frame_copy
+
+def object_extraction_using_maskformer(image_path):
+    processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+    model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")