Parth1503 committed on
Commit 9517f59 · verified · 1 Parent(s): 0c2d792

Upload 3 files

Files changed (3)
  1. Clip.py +166 -0
  2. app.py +28 -0
  3. requirements.txt +11 -0
Clip.py ADDED
@@ -0,0 +1,166 @@
+ from transformers import CLIPModel, CLIPProcessor
+ from PIL import Image
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import cv2
+ import os
+
+ # Set as an environment variable (a bare Python assignment has no effect);
+ # only relevant if TensorFlow is imported elsewhere in the environment.
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_name = "openai/clip-vit-base-patch32"
+ model = CLIPModel.from_pretrained(model_name).to(device)
+ processor = CLIPProcessor.from_pretrained(model_name)
+
+ # This function extracts patches from an image and returns them along with their coordinates.
+ def image_patch(img, patch_size=(100, 100), stride=2):
+     img_w, img_h = img.size
+     print(f"Image dimensions: width={img_w}, height={img_h}")
+     patches = []
+
+     for i in range(0, img_h - patch_size[1] + 1, stride):
+         for j in range(0, img_w - patch_size[0] + 1, stride):
+             patch = img.crop((j, i, j + patch_size[0], i + patch_size[1]))
+             patches.append((patch, (j, i)))
+
+     return patches
+
+ # Threshold the score heatmap and draw a box around the largest high-score region.
+ def bounding_box(img, heatmap):
+     img_copy = np.array(img).copy()
+     found = False
+
+     normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+     _, binary = cv2.threshold(normalized, 200, 255, cv2.THRESH_BINARY)
+     contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     if contours:
+         largest = max(contours, key=cv2.contourArea)
+         x, y, w, h = cv2.boundingRect(largest)
+         cv2.rectangle(img_copy, (x, y), (x + w, y + h), (255, 0, 0), 2)
+         found = True
+
+     return img_copy, found
+
+ # def main():
+ #     print("Starting the object detection process...")
+ #
+ #     img_path = r"C:\Users\sahas\OneDrive\Desktop\GenMatch\Photo of a dog.jpg"
+ #
+ #     score_patches = []
+ #     prompt = ["a photo of a human", "a close up of a dog's face"]
+ #
+ #     try:
+ #         # Open the image
+ #         img = Image.open(img_path)
+ #         print(f"Image opened successfully: {img_path}")
+ #
+ #         # Extract patches from the image
+ #         patches = image_patch(img)
+ #         print(f"Extracted {len(patches)} patches from the image.")
+ #
+ #         # Process all patches with the CLIP model to get the probabilities
+ #         patch_batch = [p for p, (x, y) in patches]
+ #         input = processor(text=prompt, images=patch_batch, return_tensors="pt", padding=True)
+ #         input = {k: v.to(device) for k, v in input.items()}
+ #         with torch.no_grad():
+ #             output = model(**input)
+ #
+ #         logits = output.logits_per_image
+ #         prob = logits.softmax(dim=1)
+ #
+ #         for i, (patch, (x, y)) in enumerate(patches):
+ #             score = prob[i][0].item()
+ #             score_patches.append((patch, (x, y), score))
+ #
+ #         # Create heatmap based on scores
+ #         img_h, img_w = img.size
+ #         pat_h, pat_w = patches[0][0].size
+ #
+ #         heatmap = np.zeros((img_h, img_w))
+ #
+ #         for _, (x, y), score in score_patches:
+ #             heatmap[y:y + pat_h, x:x + pat_w] += score
+ #
+ #         fig, ax = plt.subplots()
+ #         ax.imshow(img)
+ #         ax.imshow(heatmap, cmap='viridis', alpha=0.6)
+ #         ax.axis('off')
+ #         plt.show()
+ #
+ #         print("Generating images with bounding box")
+ #
+ #         box_img = bounding_box(img, heatmap)
+ #
+ #         plt.imshow(box_img)
+ #         plt.axis('off')
+ #         plt.show()
+ #
+ #     except FileNotFoundError:
+ #         print(f"Error opening image: {img_path}")
+ #         return
+ #
+ # if __name__ == "__main__":
+ #     main()
+
+ def run_detection_pipeline(input_image, text_prompt):
+     print("Starting the object detection process...")
+
+     img = input_image
+     prompt = [text_prompt, "a photo of a blank background"]
+     score_patches = []
+     all_scores = []
+
+     patches = image_patch(img)
+     print(f"Extracted {len(patches)} patches from the image.")
+
+     if not patches:
+         print("Warning: No patches were extracted from the image.")
+         return "No patches were extracted from the image.", input_image
+
+     # Score every patch against the prompt with CLIP.
+     patch_batch = [p for p, (x, y) in patches]
+     input_data = processor(text=prompt, images=patch_batch, return_tensors="pt", padding=True)
+     input_data = {k: v.to(device) for k, v in input_data.items()}
+     with torch.no_grad():
+         output = model(**input_data)
+
+     logits = output.logits_per_image
+     prob = logits.softmax(dim=1)
+
+     for i, (patch, (x, y)) in enumerate(patches):
+         score = prob[i][0].item()
+         score_patches.append((patch, (x, y), score))
+         all_scores.append(score)
+
+     confidence_threshold = 0.20
+     max_score = max(all_scores) if all_scores else 0
+     print(f"Max confidence score: {max_score:.4f}")
+
+     if max_score < confidence_threshold:
+         msg = f"Could not find '{text_prompt}' with enough confidence."
+         return msg, input_image
+
+     # Accumulate patch scores into a heatmap; note PIL's .size is (width, height).
+     img_w, img_h = img.size
+     pat_w, pat_h = patches[0][0].size
+     heatmap = np.zeros((img_h, img_w))
+
+     for _, (x, y), score in score_patches:
+         heatmap[y:y + pat_h, x:x + pat_w] += score
+
+     print("Generating image with bounding box...")
+     box_img, found = bounding_box(img, heatmap)
+
+     if not found:
+         msg = "No object detected matching the prompt."
+     else:
+         msg = "Object detected and highlighted."
+
+     return msg, Image.fromarray(box_img)
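
A quick way to sanity-check Clip.py outside the Gradio UI is to call run_detection_pipeline directly. The snippet below is a minimal local smoke test, assuming a sample image named dog.jpg sits next to the script (the filename and prompt are placeholders):

    from PIL import Image
    from Clip import run_detection_pipeline

    # Hypothetical sample image; any RGB photo works.
    img = Image.open("dog.jpg").convert("RGB")
    msg, boxed = run_detection_pipeline(img, "a close up of a dog's face")
    print(msg)
    boxed.save("boxed_output.png")  # boxed is returned as a PIL.Image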
app.py ADDED
@@ -0,0 +1,28 @@
+ import gradio as gr
+ from Clip import run_detection_pipeline
+
+ print("Loading the application...")
+
+ iface = gr.Interface(
+     fn=run_detection_pipeline,
+     inputs=[
+         gr.Image(type="pil", label="Upload Image"),
+         gr.Textbox(label="Text Prompt", placeholder="e.g., a photo of a dog's face")
+     ],
+     outputs=[
+         gr.Textbox(label="Result"),
+         gr.Image(type="pil", label="Detection Result")
+     ],
+     title="GenMatch: Open-Vocabulary Object Detector",
+     description="Upload an image and type what you want to find. The model will draw a box around it.",
+     # Adding examples makes your app much easier to test and demonstrate!
+     # Create a folder named 'examples' and put some images inside it.
+     # examples=[
+     #     [r"C:\Users\sahas\OneDrive\Desktop\GenMatch\Bounding Box.png", "Object detected and highlighted."],
+     #     [r"C:\Users\sahas\OneDrive\Desktop\GenMatch\Photo of a dog.jpg", "Could not find '{prompt}' with enough confidence."],
+     # ]
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     iface.launch(debug=True)
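
Once an examples/ folder with sample images exists (as the commented-out block above suggests), the examples parameter can be enabled with repo-relative paths. Each example row should mirror the interface inputs, i.e. an image plus a text prompt; the filenames below are placeholders:

    examples=[
        ["examples/dog.jpg", "a close up of a dog's face"],
        ["examples/person.jpg", "a photo of a person"],
    ],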
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ pandas
+ numpy
+ matplotlib
+ scikit-learn
+ seaborn
+ tensorflow
+ torch
+ torchvision
+ transformers
+ pillow
+ opencv-python