jbilcke-hf commited on
Commit
afd17e4
·
1 Parent(s): 1503cd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import warnings
2
  warnings.filterwarnings('ignore')
3
 
@@ -26,7 +27,7 @@ import copy
26
 
27
  import numpy as np
28
  import torch
29
- from PIL import Image
30
 
31
  # Grounding DINO
32
  import GroundingDINO.groundingdino.datasets.transforms as T
@@ -35,6 +36,50 @@ from GroundingDINO.groundingdino.util import box_ops
35
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
36
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def load_image(image_path):
40
  # # load image
@@ -54,6 +99,18 @@ def load_image(image_path):
54
  return image_pil, image
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
58
  caption = caption.lower()
59
  caption = caption.strip()
 
1
+
2
  import warnings
3
  warnings.filterwarnings('ignore')
4
 
 
27
 
28
  import numpy as np
29
  import torch
30
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
31
 
32
  # Grounding DINO
33
  import GroundingDINO.groundingdino.datasets.transforms as T
 
36
  from GroundingDINO.groundingdino.util.slconfig import SLConfig
37
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
38
 
39
+ import cv2
40
+ import numpy as np
41
+ import matplotlib.pyplot as plt
42
+ from lama_cleaner.model_manager import ModelManager
43
+ from lama_cleaner.schema import Config as lama_Config
44
+
45
+ # segment anything
46
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
47
+
48
+ # diffusers
49
+ import PIL
50
+ import requests
51
+ import torch
52
+ from io import BytesIO
53
+ from diffusers import StableDiffusionInpaintPipeline
54
+ from huggingface_hub import hf_hub_download
55
+
56
+ from utils import computer_info
57
+ # relate anything
58
+ from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
59
+ from ram_train_eval import RamModel,RamPredictor
60
+ from mmengine.config import Config as mmengine_Config
61
+ from lama_cleaner.helper import (
62
+ load_img,
63
+ numpy_to_bytes,
64
+ resize_max_size,
65
+ )
66
+
67
+ config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
68
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
69
+ ckpt_filenmae = "groundingdino_swint_ogc.pth"
70
+ sam_checkpoint = './sam_vit_h_4b8939.pth'
71
+ output_dir = "outputs"
72
+ device = 'cpu'
73
+
74
+ os.makedirs(output_dir, exist_ok=True)
75
+ groundingdino_model = None
76
+ sam_device = None
77
+ sam_model = None
78
+ sam_predictor = None
79
+ sam_mask_generator = None
80
+ sd_pipe = None
81
+ lama_cleaner_model= None
82
+ ram_model = None
83
 
84
  def load_image(image_path):
85
  # # load image
 
99
  return image_pil, image
100
 
101
 
102
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
103
+ args = SLConfig.fromfile(model_config_path)
104
+ model = build_model(args)
105
+ args.device = device
106
+
107
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
108
+ checkpoint = torch.load(cache_file, map_location=device)
109
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
110
+ print("Model loaded from {} \n => {}".format(cache_file, log))
111
+ _ = model.eval()
112
+ return model
113
+
114
  def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
115
  caption = caption.lower()
116
  caption = caption.strip()