Commit 74af434
Parent(s): caa3cab

init

Files changed:
- app.py +221 -0
- hf_model/CountEX.py +543 -0
- hf_model/__init__.py +16 -0
- hf_model/mmdet2groundingdino_swinb.py +259 -0
- hf_model/mmdet2groundingdino_swinl.py +259 -0
- hf_model/mmdet2groundingdino_swint.py +259 -0
- hf_model/modeling_grounding_dino.py +0 -0
- utils.py +455 -0
app.py
ADDED
@@ -0,0 +1,221 @@
+import os
+import gradio as gr
+import torch
+from PIL import Image, ImageDraw
+from transformers import GroundingDinoProcessor
+from hf_model import CountEX
+from utils import post_process_grounded_object_detection
+
+# Global variables for model and processor
+model = None
+processor = None
+device = None
+
+
+def load_model():
+    """Load model and processor once at startup"""
+    global model, processor, device
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load model - change path for HF Spaces
+    model_id = "BBVisual/CountEX-KC"  # Change to your HF model repo
+    model = CountEX.from_pretrained(model_id)
+    model = model.to(torch.bfloat16)
+    model = model.to(device)
+    model.eval()
+
+    # Load processor
+    processor_id = "fushh7/llmdet_swin_tiny_hf"
+    processor = GroundingDinoProcessor.from_pretrained(processor_id)
+
+    return model, processor, device
+
+
+def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
+    """
+    Main inference function for counting objects
+
+    Args:
+        image: Input PIL Image
+        pos_caption: Positive prompt (objects to count)
+        neg_caption: Negative prompt (objects to exclude)
+        box_threshold: Detection confidence threshold
+        point_radius: Radius of visualization points
+        point_color: Color of visualization points
+
+    Returns:
+        Annotated image and count
+    """
+    global model, processor, device
+
+    if model is None:
+        load_model()
+
+    # Ensure image is RGB
+    if image.mode != "RGB":
+        image = image.convert("RGB")
+
+    # Ensure captions end with period
+    if not pos_caption.endswith('.'):
+        pos_caption = pos_caption + '.'
+    if neg_caption and not neg_caption.endswith('.'):
+        neg_caption = neg_caption + '.'
+
+    # Process positive caption
+    pos_inputs = processor(
+        images=image,
+        text=pos_caption,
+        return_tensors="pt",
+        padding=True
+    )
+    pos_inputs = pos_inputs.to(device)
+    pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
+
+    # Process negative caption if provided
+    use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
+
+    if use_neg:
+        neg_inputs = processor(
+            images=image,
+            text=neg_caption,
+            return_tensors="pt",
+            padding=True
+        )
+        neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
+        neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
+
+        # Add negative inputs to positive inputs dict
+        pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
+        pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
+        pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
+        pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
+        pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
+        pos_inputs['use_neg'] = True
+    else:
+        pos_inputs['use_neg'] = False
+
+    # Run inference
+    with torch.no_grad():
+        outputs = model(**pos_inputs)
+
+    # Post-process outputs
+    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
+    outputs["pred_logits"] = outputs["logits"]
+
+    # Use custom threshold if provided, otherwise use model default
+    threshold = box_threshold if box_threshold > 0 else model.box_threshold
+    results = post_process_grounded_object_detection(outputs, box_threshold=threshold)[0]
+
+    # Extract points
+    boxes = results["boxes"]
+    boxes = [box.tolist() for box in boxes]
+    points = [[box[0], box[1]] for box in boxes]
+
+    # Visualize results
+    img_w, img_h = image.size
+    img_draw = image.copy()
+    draw = ImageDraw.Draw(img_draw)
+
+    for point in points:
+        x = point[0] * img_w
+        y = point[1] * img_h
+        draw.ellipse(
+            [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
+            fill=point_color
+        )
+
+    count = len(points)
+
+    return img_draw, f"Count: {count}"
+
+
+# Create Gradio interface
+def create_demo():
+    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
+        gr.Markdown("""
+        # CountEx: Discriminative Visual Counting
+
+        Count specific objects in images using positive and negative text prompts.
+
+        **Positive Prompt**: Describe what you want to count (e.g., "Green Apple")
+
+        **Negative Prompt**: Describe what you want to exclude (e.g., "Red Apple")
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(type="pil", label="Input Image")
+
+                pos_caption = gr.Textbox(
+                    label="Positive Prompt",
+                    placeholder="e.g., Green Apple",
+                    value="Green Apple"
+                )
+
+                neg_caption = gr.Textbox(
+                    label="Negative Prompt (optional)",
+                    placeholder="e.g., Red Apple",
+                    value="Red Apple"
+                )
+
+                with gr.Accordion("Advanced Settings", open=False):
+                    box_threshold = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.0,
+                        step=0.01,
+                        label="Detection Threshold (0 = use model default)"
+                    )
+
+                    point_radius = gr.Slider(
+                        minimum=1,
+                        maximum=20,
+                        value=5,
+                        step=1,
+                        label="Point Radius"
+                    )
+
+                    point_color = gr.Dropdown(
+                        choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
+                        value="blue",
+                        label="Point Color"
+                    )
+
+                submit_btn = gr.Button("Count Objects", variant="primary")
+
+            with gr.Column(scale=1):
+                output_image = gr.Image(type="pil", label="Result")
+                count_output = gr.Textbox(label="Count Result")
+
+        # Example images
+        # gr.Examples(
+        #     examples=[
+        #         ["examples/apples.jpg", "Green Apple", "Red Apple"],
+        #         ["examples/cars.jpg", "Red Car", "Blue Car"],
+        #         ["examples/people.jpg", "Person wearing hat", "Person without hat"],
+        #     ],
+        #     inputs=[input_image, pos_caption, neg_caption],
+        #     outputs=[output_image, count_output],
+        #     fn=count_objects,
+        #     cache_examples=False,
+        # )
+
+        submit_btn.click(
+            fn=count_objects,
+            inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
+            outputs=[output_image, count_output]
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    # Load model at startup
+    print("Loading model...")
+    load_model()
+    print("Model loaded!")
+
+    # Create and launch demo
+    demo = create_demo()
+    demo.launch()
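
Outside the committed files, here is a minimal sketch of exercising the inference path above without launching the Gradio UI. The image path "apples.jpg" is a placeholder, and the snippet assumes the Space's root directory is on PYTHONPATH so that app.py and its imports resolve.

from PIL import Image
from app import load_model, count_objects

load_model()  # downloads the CountEX weights and the GroundingDino processor once
image = Image.open("apples.jpg")  # placeholder path; any RGB image works
# box_threshold=0.0 makes count_objects fall back to the model's default threshold
annotated, count_text = count_objects(
    image,
    pos_caption="Green Apple",
    neg_caption="Red Apple",
    box_threshold=0.0,
    point_radius=5,
    point_color="blue",
)
annotated.save("result.png")
print(count_text)  # prints "Count: <N>"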
hf_model/CountEX.py
ADDED
@@ -0,0 +1,543 @@
+# coding=utf-8
+"""
+Negative Grounding DINO Model for Object Detection with Negative Caption Support.
+
+This module extends the original GroundingDinoForObjectDetection to support negative captions
+for improved object detection performance.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Dict, List, Optional, Tuple, Union
+from transformers.modeling_outputs import ModelOutput
+import torch.nn.functional as F
+from .modeling_grounding_dino import (
+    GroundingDinoForObjectDetection,
+    GroundingDinoObjectDetectionOutput,
+    GroundingDinoEncoderOutput,
+)
+
+
+# density_fpn_head.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def _bilinear(x, size):
+    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
+
+
+class DensityFPNHead(nn.Module):
+    def __init__(self,
+                 in_channels: int = 512,
+                 mid_channels: int = 128,
+                 act_layer=nn.ReLU,
+                 norm_layer=nn.BatchNorm2d):
+        super().__init__()
+
+        # ---- 1×1 lateral convs (P3–P6) ----
+        self.lateral = nn.ModuleList([
+            nn.Conv2d(in_channels, mid_channels, 1) for _ in range(4)
+        ])
+
+        # ---- smooth convs after add ----
+        self.smooth = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
+                norm_layer(mid_channels),
+                act_layer(inplace=True),
+            ) for _ in range(3)  # P6→P5, P5→P4, P4→P3
+        ])
+
+        self.up_blocks = nn.ModuleList([
+            nn.Sequential(
+                act_layer(inplace=True),
+                nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
+                norm_layer(mid_channels),
+                act_layer(inplace=True),
+            ) for _ in range(3)  # 167×94 → … → 1336×752
+        ])
+
+        # ---- output 3×3 conv -> 1 ----
+        self.out_conv = nn.Conv2d(mid_channels, 1, 3, padding=1, bias=False)
+
+    def forward(self, feats):
+        assert len(feats) == 4, "Expect feats list = [P3,P4,P5,P6]"
+
+        # lateral 1×1
+        lat = [l(f) for l, f in zip(self.lateral, feats)]
+
+        # top-down FPN fusion
+        x = lat[-1]  # P6
+        for i in range(3)[::-1]:  # P5,P4,P3
+            x = _bilinear(x, lat[i].shape[-2:])
+            x = x + lat[i]
+            x = self.smooth[i](x)
+
+        # three-stage upsample + conv
+        for up in self.up_blocks:
+            h, w = x.shape[-2], x.shape[-1]
+            x = _bilinear(x, (h * 2, w * 2))
+            x = up(x)
+
+        x = self.out_conv(x)
+        return F.relu(x)
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def l2norm(x, dim=-1, eps=1e-6):
+    return x / (x.norm(dim=dim, keepdim=True) + eps)
+
+# -----------------------------------
+# 1) CommonFinderSimple
+#    learn r "common prototypes" that represent what the positive and negative query sets share
+#    nothing fancy: only MHA pooling + two light regularizations (shareability + diversity)
+# -----------------------------------
+class CommonFinderSimple(nn.Module):
+    """
+    Inputs:
+        Q_pos: [B, K, D]
+        Q_neg: [B, K, D]
+    Returns:
+        C_rows: [B, r, D]   # r common prototypes, copied across the batch (unit-normalized)
+        loss: scalar        # small regularization: shareability + diversity
+        stats: dict
+    """
+    def __init__(self, d_model=256, r=64, nhead=4,
+                 share_w=0.02, div_w=0.02, ln_after=False):
+        super().__init__()
+        self.r = r
+        self.share_w = share_w
+        self.div_w = div_w
+
+        proto = torch.randn(r, d_model)
+        self.proto = nn.Parameter(l2norm(proto, -1))  # r×D learnable "core queries"
+        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
+        self.post = nn.Linear(d_model, d_model)
+        self.ln = nn.LayerNorm(d_model) if ln_after else nn.Identity()
+
+    def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
+        B, K, D = Q_pos.shape
+        seeds = self.proto[None].expand(B, -1, -1).contiguous()  # [B,r,D]
+        X = torch.cat([Q_pos, Q_neg], dim=1)  # [B,2K,D]
+
+        # use the seeds to attention-pool once over the positive and negative sets to get r "common prototypes"
+        C, _ = self.attn(query=seeds, key=X, value=X)  # [B,r,D]
+        C = l2norm(self.ln(self.post(C)), -1)  # unit-normalize
+
+        # ---- Simple regularization: encourage C to be close to both Q_pos and Q_neg, and diverse from each other ----
+        # Shareability: average of maximum cosine similarity between C and Q_pos/Q_neg
+        cos_pos = torch.einsum('brd,bkd->brk', C, l2norm(Q_pos, -1))  # [B,r,K]
+        cos_neg = torch.einsum('brd,bkd->brk', C, l2norm(Q_neg, -1))
+        share_term = -(cos_pos.amax(dim=-1).mean() + cos_neg.amax(dim=-1).mean())
+
+        # Diversity: cosine between C should not collapse
+        C0 = l2norm(self.proto, -1)  # [r,D]
+        gram = C0 @ C0.t()  # [r,r]
+        div_term = (gram - torch.eye(self.r, device=gram.device)).pow(2).mean()
+
+        loss = self.share_w * share_term + self.div_w * div_term
+        stats = {
+            'share_term': share_term.detach(),
+            'div_term': div_term.detach(),
+            'mean_cos_pos': cos_pos.mean().detach(),
+            'mean_cos_neg': cos_neg.mean().detach()
+        }
+        return C, loss, stats
+
+
+# -----------------------------------
+# 2) NegExclusiveSimple
+#    Remove "common" information from negative queries: two simple strategies, usable independently or together
+#    (A) Soft removal: subtract the projection onto C (residual keeps non-common)
+#    (B) Filtering: only keep the Top-M negative samples least similar to C
+# -----------------------------------
+class NegExclusiveSimple(nn.Module):
+    """
+    Inputs:
+        Q_neg: [B,K,D]
+        C_rows: [B,r,D]   # common prototypes
+    Args:
+        mode: 'residual' | 'filter' | 'both'
+        M: Top-M for 'filter'
+        thresh: Filter threshold (keep if max_cos_neg < thresh), None means only use Top-M
+    Returns:
+        neg_refs: [B, M_or_K, D]   # negative references (for the next fusion step)
+        aux: dict
+    """
+    def __init__(self, mode='residual', M=16, thresh=None):
+        super().__init__()
+        assert mode in ('residual', 'filter', 'both')
+        self.mode = mode
+        self.M = M
+        self.thresh = thresh
+
+    def forward(self, Q_neg: torch.Tensor, C_rows: torch.Tensor):
+        B, K, D = Q_neg.shape
+        r = C_rows.size(1)
+        Qn = l2norm(Q_neg, -1)
+        C = l2norm(C_rows, -1)
+
+        sim = torch.einsum('bkd,brd->bkr', Qn, C).amax(dim=-1)  # [B,K]
+
+        outputs = {}
+        if self.mode in ('residual', 'both'):
+            # proj = (Q · C^T) C -> [B,K,D]; weights [B,K,r] first, then multiply by C [B,r,D]
+            w = torch.einsum('bkd,brd->bkr', Qn, C)  # [B,K,r]
+            proj = torch.einsum('bkr,brd->bkd', w, C)  # [B,K,D]
+            neg_resid = l2norm(Qn - proj, -1)  # non-common residual
+            outputs['residual'] = neg_resid
+
+        if self.mode in ('filter', 'both'):
+            excl_score = 1.0 - sim  # large = away from common
+            if self.thresh is not None:
+                mask = (sim < self.thresh).float()
+                excl_score = excl_score * mask + (-1e4) * (1 - mask)
+            M = min(self.M, K)
+            topv, topi = torch.topk(excl_score, k=M, dim=1)  # [B,M]
+            neg_top = torch.gather(Qn, 1, topi.unsqueeze(-1).expand(-1, -1, D))
+            outputs['filtered'] = neg_top
+
+        if self.mode == 'residual':
+            neg_refs = outputs['residual']
+        elif self.mode == 'filter':
+            neg_refs = outputs['filtered']
+        else:
+            R = outputs['residual']  # [B,K,D]
+            excl_score = 1.0 - sim
+            M = min(self.M, K)
+            topv, topi = torch.topk(excl_score, k=M, dim=1)
+            neg_refs = torch.gather(R, 1, topi.unsqueeze(-1).expand(-1, -1, D))  # [B,M,D]
+
+        aux = {
+            'mean_sim_to_common': sim.mean().detach(),
+            'kept_M': neg_refs.size(1)
+        }
+        return neg_refs, aux
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def l2norm(x, dim=-1, eps=1e-6):
+    return x / (x.norm(dim=dim, keepdim=True) + eps)
+
+class FusionNoGate(nn.Module):
+    """
+    Direct fusion (no gating): fuse neg_ref into Q_pos via one cross-attn.
+    Variants:
+      - 'residual_sub': Q_new = Q_pos - scale * LN(Z)
+      - 'residual_add': Q_new = Q_pos + scale * LN(Z)
+      - 'concat_linear': Q_new = Q_pos + Linear([Q_pos; Z])
+    """
+    def __init__(self, d_model=256, nhead=4, fusion_mode='residual_sub',
+                 init_scale=0.2, dropout_p=0.0):
+        super().__init__()
+        assert fusion_mode in ('residual_sub', 'residual_add', 'concat_linear')
+        self.fusion_mode = fusion_mode
+        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
+        self.ln_z = nn.LayerNorm(d_model)
+        self.drop = nn.Dropout(dropout_p) if dropout_p > 0 else nn.Identity()
+        self.scale = nn.Parameter(torch.tensor(float(init_scale)))
+        if fusion_mode == 'concat_linear':
+            self.mix = nn.Linear(2 * d_model, d_model)
+            nn.init.zeros_(self.mix.weight)
+            nn.init.zeros_(self.mix.bias)
+
+    def forward(self, Q_pos: torch.Tensor, neg_ref: torch.Tensor):
+        """
+        Q_pos: [B, K, D]
+        neg_ref: [B, M, D]
+        return: Q_new [B, K, D], stats dict
+        """
+        B, K, D = Q_pos.shape
+        M = neg_ref.size(1)
+        if M == 0:
+            return Q_pos, {'kept': 0, 'scale': self.scale.detach()}
+
+        # 1) Cross-attention:
+        Z, attn_w = self.attn(query=Q_pos, key=neg_ref, value=neg_ref)  # Z:[B,K,D]
+        Z = self.ln_z(Z)
+        Z = self.drop(Z)
+
+        # 2) fuse without gating
+        if self.fusion_mode == 'residual_sub':
+            Q_new = Q_pos - self.scale * Z
+            # print("z: ", Z.sum())
+            # print(torch.abs(Q_new - Q_pos).sum())
+        elif self.fusion_mode == 'residual_add':
+            Q_new = Q_pos + self.scale * Z
+        else:  # 'concat_linear'
+            fused = torch.cat([Q_pos, Z], dim=-1)  # [B,K,2D]
+            delta = self.mix(fused)  # [B,K,D]
+            Q_new = Q_pos + delta
+
+        stats = {
+            'kept': M,
+            'attn_mean': attn_w.mean().detach(),
+            'fusion_scale': self.scale.detach()
+        }
+        return Q_new, stats
+
+class QuerySideNegNaive(nn.Module):
+    def __init__(self, d_model=256, r=64, M=64, nhead=4,
+                 excl_mode='both', excl_thresh=0.5, gamma_max=0.7,
+                 share_w=0.02, div_w=0.02):
+        super().__init__()
+        self.common = CommonFinderSimple(d_model, r, nhead, share_w, div_w)
+        self.excl = NegExclusiveSimple(mode=excl_mode, M=M, thresh=excl_thresh)
+        self.fuse = FusionNoGate(d_model=d_model,
+                                 nhead=4,
+                                 fusion_mode='residual_sub',  # or 'concat_linear'
+                                 init_scale=0.25,
+                                 dropout_p=0.1)
+
+    def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
+        C_rows, l_common, common_stats = self.common(Q_pos, Q_neg)
+        neg_refs, excl_stats = self.excl(Q_neg, C_rows)
+        Q_new, fuse_stats = self.fuse(Q_pos, neg_refs)
+        loss = l_common
+        stats = {}
+        stats.update(common_stats); stats.update(excl_stats); stats.update(fuse_stats)
+        return Q_new, loss, stats
+
+    def set_fusion_scale(self, scale: float):
+        del self.fuse.scale
+        self.fuse.scale = nn.Parameter(torch.tensor(scale))
+
+
+class CountEX(GroundingDinoForObjectDetection):
+    """
+    Grounding DINO Model with negative caption support for improved object detection.
+
+    This model extends the original GroundingDinoForObjectDetection by adding
+    support for negative captions, which helps improve detection accuracy by
+    learning what NOT to detect.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Initialize negative fusion modules directly in __init__
+        self.query_side_neg_pipeline = QuerySideNegNaive()
+        self.density_head = DensityFPNHead()
+        self.config = config
+        self.box_threshold = getattr(config, 'box_threshold', 0.4)
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: torch.LongTensor,
+        token_type_ids: torch.LongTensor = None,
+        attention_mask: torch.LongTensor = None,
+        pixel_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
+        # Negative prompt parameters
+        neg_pixel_values: Optional[torch.FloatTensor] = None,
+        neg_input_ids: Optional[torch.LongTensor] = None,
+        neg_token_type_ids: Optional[torch.LongTensor] = None,
+        neg_attention_mask: Optional[torch.LongTensor] = None,
+        neg_pixel_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        use_neg = kwargs.get('use_neg', True)
+        # Get positive outputs
+        pos_kwargs = {
+            'exemplars': kwargs.get('pos_exemplars', None),
+        }
+        outputs = self.model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            pixel_mask=pixel_mask,
+            encoder_outputs=encoder_outputs,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **pos_kwargs,
+        )
+
+        spatial_shapes = outputs.spatial_shapes
+        token_num = 0
+        token_num_list = [0]
+        for i in range(len(spatial_shapes)):
+            token_num += spatial_shapes[i][0] * spatial_shapes[i][1]
+            token_num_list.append(token_num.item())
+
+        positive_feature_maps = []
+        encoder_last_hidden_state_vision = outputs.encoder_last_hidden_state_vision
+        for i in range(len(spatial_shapes)):
+            feature_map = encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i+1], :]
+            spatial_shape = spatial_shapes[i]
+            b, t, d = feature_map.shape
+            feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
+            positive_feature_maps.append(feature_map)
+
+        # Get negative outputs
+        neg_kwargs = {
+            'exemplars': kwargs.get('neg_exemplars', None),
+        }
+        # print(kwargs)
+        neg_outputs = self.model(
+            pixel_values=neg_pixel_values,
+            input_ids=neg_input_ids,
+            token_type_ids=neg_token_type_ids,
+            attention_mask=neg_attention_mask,
+            pixel_mask=neg_pixel_mask,
+            encoder_outputs=encoder_outputs,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **neg_kwargs,
+        )
+
+        neg_encoder_last_hidden_state_vision = neg_outputs.encoder_last_hidden_state_vision
+        neg_positive_feature_maps = []
+        for i in range(len(spatial_shapes)):
+            feature_map = neg_encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i+1], :]
+            spatial_shape = spatial_shapes[i]
+            b, t, d = feature_map.shape
+            feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
+            neg_positive_feature_maps.append(feature_map)
+
+        if return_dict:
+            hidden_states = outputs.intermediate_hidden_states
+            neg_hidden_states = neg_outputs.intermediate_hidden_states
+        else:
+            hidden_states = outputs[2]
+            neg_hidden_states = neg_outputs[2]
+
+        idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
+        enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
+        hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
+        init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
+        inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]
+
+        # drop the exemplar tokens if used
+        pos_exemplars = pos_kwargs.get('pos_exemplars', None)
+        neg_exemplars = neg_kwargs.get('neg_exemplars', None)
+        if pos_exemplars is not None or neg_exemplars is not None or attention_mask.shape[1] != enc_text_hidden_state.shape[1]:
+            enc_text_hidden_state = enc_text_hidden_state[:, :enc_text_hidden_state.shape[1] - 3, :]
+
+        # class logits + predicted bounding boxes
+        outputs_classes = []
+        outputs_coords = []
+
+        # Apply negative fusion
+        if use_neg:
+            # print("Using negative fusions")
+            #neg_hidden_states = self.negative_semantic_extractor(neg_hidden_states)
+            #hidden_states = self.negative_fusion_module(hidden_states, neg_hidden_states)
+            hidden_states = hidden_states.squeeze(0)
+            neg_hidden_states = neg_hidden_states.squeeze(0)
+            hidden_states, extra_loss, logs = self.query_side_neg_pipeline(hidden_states, neg_hidden_states)
+            hidden_states = hidden_states.unsqueeze(0)
+            neg_hidden_states = neg_hidden_states.unsqueeze(0)
+            # print("extra_loss: ", extra_loss)
+        else:
+            # print("Not using negative fusions")
+            extra_loss = None
+            logs = None
+            # print("Not using negative fusion")
+            # print("extra_loss: ", extra_loss)
+
+        # predict class and bounding box deltas for each stage
+        num_levels = hidden_states.shape[1]
+        for level in range(num_levels):
+            if level == 0:
+                reference = init_reference_points
+            else:
+                reference = inter_references_points[:, level - 1]
+            reference = torch.special.logit(reference, eps=1e-5)
+
+            # print("hidden_states[:, level]: ", hidden_states[:, level].shape)
+            # print("enc_text_hidden_state: ", enc_text_hidden_state.shape)
+            # print("attention_mask: ", attention_mask.shape)
+
+            assert attention_mask.shape[1] == enc_text_hidden_state.shape[1], "Attention mask and text hidden state have different lengths: {} != {}".format(attention_mask.shape[1], enc_text_hidden_state.shape[1])
+            outputs_class = self.class_embed[level](
+                vision_hidden_state=hidden_states[:, level],
+                text_hidden_state=enc_text_hidden_state,
+                text_token_mask=attention_mask.bool(),
+            )
+            delta_bbox = self.bbox_embed[level](hidden_states[:, level])
+
+            reference_coordinates = reference.shape[-1]
+            if reference_coordinates == 4:
+                outputs_coord_logits = delta_bbox + reference
+            elif reference_coordinates == 2:
+                delta_bbox[..., :2] += reference
+                outputs_coord_logits = delta_bbox
+            else:
+                raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
+            outputs_coord = outputs_coord_logits.sigmoid()
+            outputs_classes.append(outputs_class)
+            outputs_coords.append(outputs_coord)
+        outputs_class = torch.stack(outputs_classes)
+        outputs_coord = torch.stack(outputs_coords)
+
+        logits = outputs_class[-1]
+        pred_boxes = outputs_coord[-1]
+
+        loss, loss_dict, auxiliary_outputs = None, None, None
+        if not return_dict:
+            if auxiliary_outputs is not None:
+                output = (logits, pred_boxes) + auxiliary_outputs + outputs
+            else:
+                output = (logits, pred_boxes) + outputs
+            tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
+
+            return tuple_outputs
+
+        all_feats = []
+        for pf, npf in zip(positive_feature_maps, neg_positive_feature_maps):
+            pf = pf.permute(0, 3, 1, 2)
+            npf = npf.permute(0, 3, 1, 2)
+            all_feats.append(torch.cat([pf, npf], dim=1))
+
+
+        # pos_feat = positive_feature_maps[0].permute(0, 3, 1, 2)
+        # neg_feat = neg_positive_feature_maps[0].permute(0, 3, 1, 2)
+        # pos_minus_neg_feat = F.relu(pos_feat - neg_feat)
+        # density_feat_map = torch.cat([pos_feat, neg_feat, pos_minus_neg_feat], dim=1)
+        # density_feat_map = torch.cat([pos_feat, neg_feat], dim=1)
+        density_map_pred = self.density_head(all_feats)
+
+        dict_outputs = GroundingDinoObjectDetectionOutput(
+            loss=loss,
+            loss_dict=loss_dict,
+            logits=logits,
+            pred_boxes=pred_boxes,
+            last_hidden_state=outputs.last_hidden_state,
+            auxiliary_outputs=auxiliary_outputs,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
+            encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
+            encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
+            encoder_text_hidden_states=outputs.encoder_text_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+            intermediate_hidden_states=outputs.intermediate_hidden_states,
+            intermediate_reference_points=outputs.intermediate_reference_points,
+            init_reference_points=outputs.init_reference_points,
+            enc_outputs_class=outputs.enc_outputs_class,
+            enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
+            spatial_shapes=outputs.spatial_shapes,
+            positive_feature_maps=positive_feature_maps,
+            negative_feature_maps=neg_positive_feature_maps,
+            density_map_pred=density_map_pred,
+            extra_loss=extra_loss,
+            extra_logs=logs,
+        )
+
+        return dict_outputs
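
For orientation, here is a shape-level sketch of the query-side negative pipeline defined above, run on random tensors. The batch size, query count, and d_model below are illustrative values chosen for the example, not constants fixed by this file.

import torch
from hf_model.CountEX import QuerySideNegNaive

# Q_pos / Q_neg stand in for decoder query features from the positive and negative captions
B, K, D = 1, 900, 256  # illustrative sizes
Q_pos = torch.randn(B, K, D)
Q_neg = torch.randn(B, K, D)

pipeline = QuerySideNegNaive(d_model=D)
Q_new, reg_loss, stats = pipeline(Q_pos, Q_neg)

print(Q_new.shape)      # torch.Size([1, 900, 256]), same shape as Q_pos
print(float(reg_loss))  # shareability + diversity regularizer from CommonFinderSimple
print(stats['kept'])    # number of negative references kept by NegExclusiveSimple (<= M)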
hf_model/__init__.py
ADDED
@@ -0,0 +1,16 @@
+# coding=utf-8
+"""
+HF Model package for Grounding DINO with negative caption support.
+"""
+
+from .modeling_grounding_dino import (
+    GroundingDinoForObjectDetection,
+)
+
+from .CountEX import (
+    CountEX
+)
+
+__all__ = [
+    "CountEX",
+]
hf_model/mmdet2groundingdino_swinb.py
ADDED
@@ -0,0 +1,259 @@
+# mmdet to groundingdino
+import argparse
+from collections import OrderedDict
+import torch
+from mmengine.runner import CheckpointLoader
+
+# convert the functions from mmdet to groundingdino
+def correct_unfold_reduction_order(x):
+    out_channel, in_channel = x.shape
+    x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
+    x = x[:, [0, 2, 1, 3], :]
+    x = x.reshape(out_channel, in_channel)
+    return x
+
+def correct_unfold_norm_order(x):
+    in_channel = x.shape[0]
+    x = x.reshape(in_channel // 4, 4).transpose(0, 1)
+    x = x[[0, 2, 1, 3], :]
+    x = x.reshape(in_channel)
+    return x
+
+def convert(ckpt):
+    """Inverse mapping of checkpoint parameters to their original names."""
+    # Create a dictionary to hold the reversed checkpoint
+    new_ckpt = OrderedDict()
+
+    for k, v in list(ckpt.items()):
+        new_v = v  # Start with the original value
+
+        # Inverse rules based on the convert function (from specific to general)
+        if k.startswith('decoder'):
+            new_k = k.replace('decoder', 'transformer.decoder')
+            if 'norms.2' in new_k:
+                new_k = new_k.replace('norms.2', 'norm1')
+            if 'norms.1' in new_k:
+                new_k = new_k.replace('norms.1', 'catext_norm')
+            if 'norms.0' in new_k:
+                new_k = new_k.replace('norms.0', 'norm2')
+            if 'norms.3' in new_k:
+                new_k = new_k.replace('norms.3', 'norm3')
+            if 'cross_attn_text' in new_k:
+                new_k = new_k.replace('cross_attn_text', 'ca_text')
+                new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
+                new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
+                new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
+                new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
+            if 'ffn.layers.0.0' in new_k:
+                new_k = new_k.replace('ffn.layers.0.0', 'linear1')
+            if 'ffn.layers.1' in new_k:
+                new_k = new_k.replace('ffn.layers.1', 'linear2')
+            if 'self_attn.attn' in new_k:
+                new_k = new_k.replace('self_attn.attn', 'self_attn')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        # In the encoder part the final reg_layer_id is 6, which distinguishes it from the decoder
+        elif k.startswith('bbox_head.reg_branches.6'):
+            if k.startswith('bbox_head.reg_branches.6.0'):
+                new_k = k.replace('bbox_head.reg_branches.6.0',
+                                  'transformer.enc_out_bbox_embed.layers.0')
+            if k.startswith('bbox_head.reg_branches.6.2'):
+                new_k = k.replace('bbox_head.reg_branches.6.2',
+                                  'transformer.enc_out_bbox_embed.layers.1')
+            if k.startswith('bbox_head.reg_branches.6.4'):
+                new_k = k.replace('bbox_head.reg_branches.6.4',
+                                  'transformer.enc_out_bbox_embed.layers.2')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('query_embedding'):
+            new_k = k.replace('query_embedding', 'transformer.tgt_embed')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('bbox_head.reg_branches'):
+            # mmdet simply omits part of the parameter name; check the groundingdino checkpoint
+            # groundingdino has two groups of parameters with identical values,
+            # namely bbox_embed and transformer.decoder.embed,
+            # so mmdet simply "merged" the two groups of parameters
+            reg_layer_id = int(k.split('.')[2])
+            linear_id = int(k.split('.')[3])
+            weight_or_bias = k.split('.')[-1]
+            new_k1 = 'transformer.decoder.bbox_embed.' + \
+                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
+            new_k2 = 'bbox_embed.' + \
+                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
+
+            new_ckpt[new_k1] = new_v  # Add the key and value to the new checkpoint dict
+            new_ckpt[new_k2] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('bbox_head.cls_branches.6'):
+            # mmdet adds a bias term to contrastive_embed
+            # but the decoder covers 0-5, so 6 should correspond to enc_out.class_embed from the two-stage setup
+            new_k = 'transformer.enc_out_class_embed.bias'
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('bbox_head.cls_branches'):
+            # mmdet adds a bias term to contrastive_embed
+            new_k1 = 'transformer.decoder.class_embed.' + k[-6:]
+            new_k2 = 'class_embed.' + k[-6:]
+
+            new_ckpt[new_k1] = new_v  # Add the key and value to the new checkpoint dict
+            new_ckpt[new_k2] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('memory_trans_'):
+            if k.startswith('memory_trans_fc'):
+                new_k = k.replace('memory_trans_fc', 'transformer.enc_output')
+            elif k.startswith('memory_trans_norm'):
+                new_k = k.replace('memory_trans_norm', 'transformer.enc_output_norm')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('encoder'):
+            new_k = k.replace('encoder', 'transformer.encoder')
+            new_k = new_k.replace('norms.0', 'norm1')
+            new_k = new_k.replace('norms.1', 'norm2')
+            new_k = new_k.replace('norms.2', 'norm3')
+            new_k = new_k.replace('ffn.layers.0.0', 'linear1')
+            new_k = new_k.replace('ffn.layers.1', 'linear2')
+            if 'text_layers' in new_k:
+                new_k = new_k.replace('self_attn.attn', 'self_attn')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('level_embed'):
+            new_k = k.replace('level_embed', 'transformer.level_embed')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('neck.convs'):
+            new_k = k.replace('neck.convs', 'input_proj')
+            new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
+            new_k = new_k.replace('conv.weight', '0.weight')
+            new_k = new_k.replace('conv.bias', '0.bias')
+            new_k = new_k.replace('gn.weight', '1.weight')
+            new_k = new_k.replace('gn.bias', '1.bias')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif 'neck.extra_convs.0' in k:
+            new_k = k.replace('neck.extra_convs.0', 'neck.convs.3')
+            new_k = new_k.replace('neck.convs', 'input_proj')
+            new_k = new_k.replace('conv.weight', '0.weight')
+            new_k = new_k.replace('conv.bias', '0.bias')
+            new_k = new_k.replace('gn.weight', '1.weight')
+            new_k = new_k.replace('gn.bias', '1.bias')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('text_feat_map'):
+            new_k = k.replace('text_feat_map', 'feat_map')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('language_model.language_backbone.body.model'):
+            new_k = k.replace('language_model.language_backbone.body.model', 'bert')
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        elif k.startswith('backbone'):
+            new_k = k.replace('backbone', 'backbone.0')
+            if 'patch_embed.projection' in new_k:
+                new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
+            elif 'drop_after_pos' in new_k:
+                new_k = new_k.replace('drop_after_pos', 'pos_drop')
+
+            if 'stages' in new_k:
+                new_k = new_k.replace('stages', 'layers')
+            if 'ffn.layers.0.0' in new_k:
+                new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
+            elif 'ffn.layers.1' in new_k:
+                new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
+            elif 'attn.w_msa' in new_k:
+                new_k = new_k.replace('attn.w_msa', 'attn')
+
+            if 'downsample' in k:
+                if 'reduction.' in k:
+                    new_v = correct_unfold_reduction_order(v)
+                elif 'norm.' in k:
+                    new_v = correct_unfold_norm_order(v)
+
+            new_ckpt[new_k] = new_v  # Add the key and value to the new checkpoint dict
+
+        #########################################################################
+
+        else:
+            print('skip:', k)
+            continue
+
+        # if 'transformer.decoder.bbox_embed' in new_k:
+        #     new_k = new_k.replace('transformer.decoder.bbox_embed', 'bbox_embed')
+        # if new_k.startswith('module.'):
+        #     new_k = new_k.replace('module.', '')
+
+    return new_ckpt
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys to GroundingDINO style.')
+    parser.add_argument(
+        'src',
+        nargs='?',
+        default='grounding_dino_swin-b_pretrain_all-f9818a7c.pth',
+        help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument(
+        'dst',
+        nargs='?',
+        default='mmdet_swinb_cogcoor.pth_groundingdino.pth',
+        help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    # in mmdet the weights live under 'state_dict' rather than 'model'
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    weight = convert(state_dict)
+    torch.save(weight, args.dst)
+    # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
+    # sha = calculate_sha256(args.dst)
+    # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
+    # subprocess.Popen(['mv', args.dst, final_file])
+    print(f'Done!!, save to {args.dst}')
+
+if __name__ == '__main__':
+    main()
+
+# skip: dn_query_generator.label_embedding.weight
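
As a reference, a small sketch of calling the converter above programmatically rather than via its CLI entry point. The filenames are the script's own defaults, and the snippet assumes mmengine (and the hf_model package dependencies) are installed so the import resolves.

import torch
from mmengine.runner import CheckpointLoader
from hf_model.mmdet2groundingdino_swinb import convert

ckpt = CheckpointLoader.load_checkpoint(
    'grounding_dino_swin-b_pretrain_all-f9818a7c.pth', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # mmdet checkpoints keep weights under 'state_dict'
gdino_state_dict = convert(state_dict)     # remap mmdet keys to GroundingDINO-style keys
torch.save(gdino_state_dict, 'mmdet_swinb_cogcoor.pth_groundingdino.pth')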
hf_model/mmdet2groundingdino_swinl.py
ADDED
@@ -0,0 +1,259 @@
| 1 |
+
# mmdet to groundingdino
|
| 2 |
+
import argparse
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import torch
|
| 5 |
+
from mmengine.runner import CheckpointLoader
|
| 6 |
+
|
| 7 |
+
# convert the functions from mmdet to groundingdino
|
| 8 |
+
def correct_unfold_reduction_order(x):
|
| 9 |
+
out_channel, in_channel = x.shape
|
| 10 |
+
x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
|
| 11 |
+
x = x[:, [0, 2, 1, 3], :]
|
| 12 |
+
x = x.reshape(out_channel, in_channel)
|
| 13 |
+
return x
|
| 14 |
+
|
| 15 |
+
def correct_unfold_norm_order(x):
|
| 16 |
+
in_channel = x.shape[0]
|
| 17 |
+
x = x.reshape(in_channel // 4, 4).transpose(0, 1)
|
| 18 |
+
x = x[[0, 2, 1, 3], :]
|
| 19 |
+
x = x.reshape(in_channel)
|
| 20 |
+
return x
|
| 21 |
+
|
def convert(ckpt):
    """Inverse mapping of checkpoint parameters to their original names."""
    # Create a dictionary to hold the reversed checkpoint
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v  # Start with the original value

        # Inverse rules based on the convert function (from specific to general)
        if k.startswith('decoder'):
            new_k = k.replace('decoder', 'transformer.decoder')
            if 'norms.2' in new_k:
                new_k = new_k.replace('norms.2', 'norm1')
            if 'norms.1' in new_k:
                new_k = new_k.replace('norms.1', 'catext_norm')
            if 'norms.0' in new_k:
                new_k = new_k.replace('norms.0', 'norm2')
            if 'norms.3' in new_k:
                new_k = new_k.replace('norms.3', 'norm3')
            if 'cross_attn_text' in new_k:
                new_k = new_k.replace('cross_attn_text', 'ca_text')
                new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
                new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
                new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
                new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
            if 'ffn.layers.0.0' in new_k:
                new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            if 'ffn.layers.1' in new_k:
                new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'self_attn.attn' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        # the trailing reg_layer_id 6 belongs to the encoder head, which keeps it separate from the decoder layers
        elif k.startswith('bbox_head.reg_branches.6'):
            if k.startswith('bbox_head.reg_branches.6.0'):
                new_k = k.replace('bbox_head.reg_branches.6.0',
                                  'transformer.enc_out_bbox_embed.layers.0')
            if k.startswith('bbox_head.reg_branches.6.2'):
                new_k = k.replace('bbox_head.reg_branches.6.2',
                                  'transformer.enc_out_bbox_embed.layers.1')
            if k.startswith('bbox_head.reg_branches.6.4'):
                new_k = k.replace('bbox_head.reg_branches.6.4',
                                  'transformer.enc_out_bbox_embed.layers.2')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('query_embedding'):
            new_k = k.replace('query_embedding', 'transformer.tgt_embed')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.reg_branches'):
            # mmdet drops part of the parameter name here, so the groundingdino checkpoint has to be consulted
            # groundingdino keeps two sets of parameters with identical values,
            # namely bbox_embed and transformer.decoder.bbox_embed,
            # so mmdet simply "merged" the two sets into one
            reg_layer_id = int(k.split('.')[2])
            linear_id = int(k.split('.')[3])
            weight_or_bias = k.split('.')[-1]
            new_k1 = 'transformer.decoder.bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
            new_k2 = 'bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias

            new_ckpt[new_k1] = new_v  # Add the key and value to the original checkpoint dict
            new_ckpt[new_k2] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.cls_branches.6'):
            # mmdet adds a bias term to contrastive_embed
            # the decoder layers are 0~5, so index 6 corresponds to enc_out.class_embed of the two-stage scheme
            new_k = 'transformer.enc_out_class_embed.bias'

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.cls_branches'):
            # mmdet adds a bias term to contrastive_embed
            new_k1 = 'transformer.decoder.class_embed.' + k[-6:]
            new_k2 = 'class_embed.' + k[-6:]

            new_ckpt[new_k1] = new_v  # Add the key and value to the original checkpoint dict
            new_ckpt[new_k2] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('memory_trans_'):
            if k.startswith('memory_trans_fc'):
                new_k = k.replace('memory_trans_fc', 'transformer.enc_output')
            elif k.startswith('memory_trans_norm'):
                new_k = k.replace('memory_trans_norm', 'transformer.enc_output_norm')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('encoder'):
            new_k = k.replace('encoder', 'transformer.encoder')
            new_k = new_k.replace('norms.0', 'norm1')
            new_k = new_k.replace('norms.1', 'norm2')
            new_k = new_k.replace('norms.2', 'norm3')
            new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'text_layers' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('level_embed'):
            new_k = k.replace('level_embed', 'transformer.level_embed')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('neck.convs'):
            new_k = k.replace('neck.convs', 'input_proj')
            new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif 'neck.extra_convs.0' in k:
            new_k = k.replace('neck.extra_convs.0', 'neck.convs.4')
            new_k = new_k.replace('neck.convs', 'input_proj')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('text_feat_map'):
            new_k = k.replace('text_feat_map', 'feat_map')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('language_model.language_backbone.body.model'):
            new_k = k.replace('language_model.language_backbone.body.model', 'bert')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('backbone'):
            new_k = k.replace('backbone', 'backbone.0')
            if 'patch_embed.projection' in new_k:
                new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
            elif 'drop_after_pos' in new_k:
                new_k = new_k.replace('drop_after_pos', 'pos_drop')

            if 'stages' in new_k:
                new_k = new_k.replace('stages', 'layers')
                if 'ffn.layers.0.0' in new_k:
                    new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
                elif 'ffn.layers.1' in new_k:
                    new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
                elif 'attn.w_msa' in new_k:
                    new_k = new_k.replace('attn.w_msa', 'attn')

            if 'downsample' in k:
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        else:
            print('skip:', k)
            continue

        # if 'transformer.decoder.bbox_embed' in new_k:
        #     new_k = new_k.replace('transformer.decoder.bbox_embed', 'bbox_embed')
        # if new_k.startswith('module.'):
        #     new_k = new_k.replace('module.', '')

    return new_ckpt
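# --- Illustrative sketch (editor's addition, not part of the original script) ---
# On a toy state_dict the remapping behaves like this; values pass through unchanged
# except for the Swin 'downsample' weights handled above:
#
#   toy = OrderedDict({
#       'query_embedding.weight': torch.zeros(900, 256),
#       'text_feat_map.weight': torch.zeros(256, 768),
#       'dn_query_generator.label_embedding.weight': torch.zeros(256, 256),
#   })
#   out = convert(toy)            # prints "skip: dn_query_generator.label_embedding.weight"
#   sorted(out.keys())            # ['feat_map.weight', 'transformer.tgt_embed.weight']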
def main():
    parser = argparse.ArgumentParser(
        description='Convert keys to GroundingDINO style.')
    parser.add_argument(
        'src',
        nargs='?',
        default='grounding_dino_swin-l_pretrain_all-56d69e78.pth',
        help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument(
        'dst',
        nargs='?',
        default='mmdet_swinl.pth_groundingdino.pth',
        help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    # mmdet checkpoints store the weights under 'state_dict' rather than 'model'
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert(state_dict)
    torch.save(weight, args.dst)
    # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
    # sha = calculate_sha256(args.dst)
    # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
    # subprocess.Popen(['mv', args.dst, final_file])
    print(f'Done!!, save to {args.dst}')

if __name__ == '__main__':
    main()

# skip: dn_query_generator.label_embedding.weight

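A minimal sketch of driving this converter end to end, assuming mmengine is installed and the mmdet Swin-L checkpoint named in the argparse default is available locally:

    # python hf_model/mmdet2groundingdino_swinl.py <src.pth> <dst.pth>
    import torch

    converted = torch.load('mmdet_swinl.pth_groundingdino.pth', map_location='cpu')
    print(len(converted), 'tensors')     # remapped, GroundingDINO-style keys
    print(sorted(converted)[:5])         # e.g. backbone.0.*, bbox_embed.*, bert.* ...
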
hf_model/mmdet2groundingdino_swint.py
ADDED
@@ -0,0 +1,259 @@
# mmdet to groundingdino
import argparse
from collections import OrderedDict
import torch
from mmengine.runner import CheckpointLoader

# convert the functions from mmdet to groundingdino
def correct_unfold_reduction_order(x):
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
    x = x[:, [0, 2, 1, 3], :]
    x = x.reshape(out_channel, in_channel)
    return x

def correct_unfold_norm_order(x):
    in_channel = x.shape[0]
    x = x.reshape(in_channel // 4, 4).transpose(0, 1)
    x = x[[0, 2, 1, 3], :]
    x = x.reshape(in_channel)
    return x

def convert(ckpt):
    """Inverse mapping of checkpoint parameters to their original names."""
    # Create a dictionary to hold the reversed checkpoint
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v  # Start with the original value

        # Inverse rules based on the convert function (from specific to general)
        if k.startswith('decoder'):
            new_k = k.replace('decoder', 'module.transformer.decoder')
            if 'norms.2' in new_k:
                new_k = new_k.replace('norms.2', 'norm1')
            if 'norms.1' in new_k:
                new_k = new_k.replace('norms.1', 'catext_norm')
            if 'norms.0' in new_k:
                new_k = new_k.replace('norms.0', 'norm2')
            if 'norms.3' in new_k:
                new_k = new_k.replace('norms.3', 'norm3')
            if 'cross_attn_text' in new_k:
                new_k = new_k.replace('cross_attn_text', 'ca_text')
                new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
                new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
                new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
                new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
            if 'ffn.layers.0.0' in new_k:
                new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            if 'ffn.layers.1' in new_k:
                new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'self_attn.attn' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        # the trailing reg_layer_id 6 belongs to the encoder head, which keeps it separate from the decoder layers
        elif k.startswith('bbox_head.reg_branches.6'):
            if k.startswith('bbox_head.reg_branches.6.0'):
                new_k = k.replace('bbox_head.reg_branches.6.0',
                                  'module.transformer.enc_out_bbox_embed.layers.0')
            if k.startswith('bbox_head.reg_branches.6.2'):
                new_k = k.replace('bbox_head.reg_branches.6.2',
                                  'module.transformer.enc_out_bbox_embed.layers.1')
            if k.startswith('bbox_head.reg_branches.6.4'):
                new_k = k.replace('bbox_head.reg_branches.6.4',
                                  'module.transformer.enc_out_bbox_embed.layers.2')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('query_embedding'):
            new_k = k.replace('query_embedding', 'module.transformer.tgt_embed')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.reg_branches'):
            # mmdet drops part of the parameter name here, so the groundingdino checkpoint has to be consulted
            # groundingdino keeps two sets of parameters with identical values,
            # namely module.bbox_embed and module.transformer.decoder.bbox_embed,
            # so mmdet simply "merged" the two sets into one
            reg_layer_id = int(k.split('.')[2])
            linear_id = int(k.split('.')[3])
            weight_or_bias = k.split('.')[-1]
            new_k1 = 'module.transformer.decoder.bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
            new_k2 = 'module.bbox_embed.' + \
                str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias

            new_ckpt[new_k1] = new_v  # Add the key and value to the original checkpoint dict
            new_ckpt[new_k2] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.cls_branches.6'):
            # mmdet adds a bias term to contrastive_embed
            # the decoder layers are 0~5, so index 6 corresponds to enc_out.class_embed of the two-stage scheme
            new_k = 'module.transformer.enc_out_class_embed.bias'

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('bbox_head.cls_branches'):
            # mmdet adds a bias term to contrastive_embed
            new_k1 = 'module.transformer.decoder.class_embed.' + k[-6:]
            new_k2 = 'module.class_embed.' + k[-6:]

            new_ckpt[new_k1] = new_v  # Add the key and value to the original checkpoint dict
            new_ckpt[new_k2] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('memory_trans_'):
            if k.startswith('memory_trans_fc'):
                new_k = k.replace('memory_trans_fc', 'module.transformer.enc_output')
            elif k.startswith('memory_trans_norm'):
                new_k = k.replace('memory_trans_norm', 'module.transformer.enc_output_norm')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('encoder'):
            new_k = k.replace('encoder', 'module.transformer.encoder')
            new_k = new_k.replace('norms.0', 'norm1')
            new_k = new_k.replace('norms.1', 'norm2')
            new_k = new_k.replace('norms.2', 'norm3')
            new_k = new_k.replace('ffn.layers.0.0', 'linear1')
            new_k = new_k.replace('ffn.layers.1', 'linear2')
            if 'text_layers' in new_k:
                new_k = new_k.replace('self_attn.attn', 'self_attn')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('level_embed'):
            new_k = k.replace('level_embed', 'module.transformer.level_embed')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('neck.convs'):
            new_k = k.replace('neck.convs', 'module.input_proj')
            new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif 'neck.extra_convs.0' in k:
            new_k = k.replace('neck.extra_convs.0', 'neck.convs.3')
            new_k = new_k.replace('neck.convs', 'module.input_proj')
            new_k = new_k.replace('conv.weight', '0.weight')
            new_k = new_k.replace('conv.bias', '0.bias')
            new_k = new_k.replace('gn.weight', '1.weight')
            new_k = new_k.replace('gn.bias', '1.bias')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('text_feat_map'):
            new_k = k.replace('text_feat_map', 'module.feat_map')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('language_model.language_backbone.body.model'):
            new_k = k.replace('language_model.language_backbone.body.model', 'module.bert')

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        elif k.startswith('backbone'):
            new_k = k.replace('backbone', 'module.backbone.0')
            if 'patch_embed.projection' in new_k:
                new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
            elif 'drop_after_pos' in new_k:
                new_k = new_k.replace('drop_after_pos', 'pos_drop')

            if 'stages' in new_k:
                new_k = new_k.replace('stages', 'layers')
                if 'ffn.layers.0.0' in new_k:
                    new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
                elif 'ffn.layers.1' in new_k:
                    new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
                elif 'attn.w_msa' in new_k:
                    new_k = new_k.replace('attn.w_msa', 'attn')

            if 'downsample' in k:
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)

            new_ckpt[new_k] = new_v  # Add the key and value to the original checkpoint dict

        #########################################################################

        else:
            print('skip:', k)
            continue

        # if 'module.transformer.decoder.bbox_embed' in new_k:
        #     new_k = new_k.replace('module.transformer.decoder.bbox_embed', 'module.bbox_embed')
        # if new_k.startswith('module'):
        #     new_k = new_k.replace('module.', '')

    return new_ckpt

def main():
    parser = argparse.ArgumentParser(
        description='Convert keys to GroundingDINO style.')
    parser.add_argument(
        'src',
        nargs='?',
        default='grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741-e316e297.pth',
        help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument(
        'dst',
        nargs='?',
        default='check_mmdet_to_groundingdino.pth',
        help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    # mmdet checkpoints store the weights under 'state_dict' rather than 'model'
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    weight = convert(state_dict)
    torch.save(weight, args.dst)
    # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
    # sha = calculate_sha256(args.dst)
    # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
    # subprocess.Popen(['mv', args.dst, final_file])
    print(f'Done!!, save to {args.dst}')

if __name__ == '__main__':
    main()

# skip: dn_query_generator.label_embedding.weight

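The Swin-T variant mirrors the Swin-L script above, except that every destination key is written under a leading `module.` prefix (the naming produced when the target model is wrapped, for example by DistributedDataParallel) and the argparse defaults point at the Swin-T checkpoint. A small hedged helper, assuming you need to load the result into an unwrapped model:

    import torch

    def strip_module_prefix(state_dict):
        # drop a leading "module." from every key, when present
        return {(k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in state_dict.items()}

    weights = torch.load('check_mmdet_to_groundingdino.pth', map_location='cpu')
    weights = strip_module_prefix(weights)
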
hf_model/modeling_grounding_dino.py
ADDED
The diff for this file is too large to render. See raw diff.
utils.py
ADDED
@@ -0,0 +1,455 @@
import torch
import numpy as np
from transformers import GroundingDinoProcessor
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch


def prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor):
    gt_points_b = [np.array(points) / np.array(shapes)[::-1]]
    gt_points_b[0] = gt_points_b[0].squeeze(0)

    gt_points = [torch.from_numpy(img_points).float() for img_points in gt_points_b]
    gt_logits = [torch.zeros((img_points.shape[0], emb_size)) for img_points in gt_points]

    tokenized = llmdet_processor.tokenizer(caption[0], padding="longest", return_tensors="pt")
    end_idxes = [torch.where(ids == 1012)[0][-1] for ids in tokenized['input_ids']]
    for i, end_idx in enumerate(end_idxes):
        gt_logits[i][:, :end_idx] = 1.0
    caption_sizes = [idx + 2 for idx in end_idxes]

    targets = [{"points": p.to(device), "labels": l.to(device), "caption_size": c}
               for p, l, c in zip(gt_points, gt_logits, caption_sizes)]

    return targets


def post_process_grounded_object_detection(
        outputs,
        box_threshold: float = 0.4,
):
    # for the fine-tuning model, the box threshold should be set to 0.50
    logits, boxes = outputs.logits, outputs.pred_boxes

    probs = torch.sigmoid(logits)  # (batch_size, num_queries, 256)
    scores = torch.max(probs, dim=-1)[0]  # (batch_size, num_queries)

    results = []
    for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
        score = s[s > box_threshold]
        box = b[s > box_threshold]
        prob = p[s > box_threshold]
        results.append({"scores": score, "boxes": box})

    return results
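# --- Illustrative sketch (editor's addition) -----------------------------------
# post_process_grounded_object_detection only thresholds the per-query sigmoid
# scores; e.g. with dummy model outputs:
#
#   from types import SimpleNamespace
#   dummy = SimpleNamespace(logits=torch.randn(1, 900, 256),
#                           pred_boxes=torch.rand(1, 900, 4))
#   res = post_process_grounded_object_detection(dummy, box_threshold=0.4)
#   res[0]["scores"].shape, res[0]["boxes"].shape   # only the kept queries remain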
class collator:
    def __init__(self, processor=None, use_negative=True):
        model_id = "fushh7/llmdet_swin_tiny_hf"
        self.llmdet_processor = GroundingDinoProcessor.from_pretrained(model_id)
        self.use_negative = use_negative

    def __call__(self, batch):
        # assume batch size is 1
        example = batch[0]
        image = example['image']
        pil_image = example['image']
        w, h = image.size
        pos_caption = example['pos_caption']
        neg_caption = example['neg_caption']
        pos_points = example['pos_points']
        neg_points = example['neg_points']
        pos_count = example['pos_count']
        neg_count = example['neg_count']
        annotated_pos_count = example['annotated_pos_count']
        annotated_neg_count = example['annotated_neg_count']

        if 'type' in example:
            sample_type = example['type']
        else:
            sample_type = 'eval'
        category = example['category']
        image_name = "{}_{}_{}_{}_{}".format(category, pos_caption, neg_caption, pos_count, neg_count)
        pos_llm_det_inputs = self.llmdet_processor(images=image, text=pos_caption, return_tensors="pt", padding=True)
        neg_llm_det_inputs = self.llmdet_processor(images=image, text=neg_caption, return_tensors="pt", padding=True)
        pos_caption = [[pos_caption]]
        neg_caption = [[neg_caption]]
        shapes = [(w, h)]
        pos_points = [pos_points]
        neg_points = [neg_points]

        # exemplars
        if 'positive_exemplars' in example and 'negative_exemplars' in example and example[
                'positive_exemplars'] is not None and example['negative_exemplars'] is not None:
            pos_exemplars = example['positive_exemplars']
            neg_exemplars = example['negative_exemplars']
            img_width, img_height = pil_image.size  # PIL .size is (width, height)
            norm_pos_exemplars = []
            norm_neg_exemplars = []
            exemplar_valid = True
            for exemplars in pos_exemplars:
                tly, tlx, bry, brx = exemplars
                tlx = tlx / img_width
                tly = tly / img_height
                brx = brx / img_width
                bry = bry / img_height
                if tlx < 0 or tly < 0 or tlx > 1.0 or tly > 1.0:
                    exemplar_valid = False
                if brx < 0 or bry < 0 or brx > 1.0 or bry > 1.0:
                    exemplar_valid = False
                if tlx >= brx or tly >= bry:
                    exemplar_valid = False
                tlx = max(tlx, 0)
                tly = max(tly, 0)
                tly = min(tly, 1 - 1e-4)
                tlx = min(tlx, 1 - 1e-4)
                brx = min(brx, 1)
                bry = min(bry, 1)
                brx = max(brx, tlx)
                bry = max(bry, tly)
                assert tlx >= 0 and tly >= 0 and brx <= 1 and bry <= 1 and tlx <= brx and tly <= bry, f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
                norm_pos_exemplars.append([tlx, tly, brx, bry])
            for exemplars in neg_exemplars:
                tly, tlx, bry, brx = exemplars
                tlx = tlx / img_width
                tly = tly / img_height
                brx = brx / img_width
                bry = bry / img_height
                if tlx < 0 or tly < 0 or tlx > 1.0 or tly > 1.0:
                    exemplar_valid = False
                if brx < 0 or bry < 0 or brx > 1.0 or bry > 1.0:
                    exemplar_valid = False
                if tlx >= brx or tly >= bry:
                    exemplar_valid = False
                tlx = max(tlx, 0)
                tly = max(tly, 0)
                tly = min(tly, 1 - 1e-4)
                tlx = min(tlx, 1 - 1e-4)
                brx = min(brx, 1)
                bry = min(bry, 1)
                brx = max(brx, tlx)
                bry = max(bry, tly)
                assert tlx >= 0 and tly >= 0 and brx <= 1 and bry <= 1 and tlx <= brx and tly <= bry, f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
                norm_neg_exemplars.append([tlx, tly, brx, bry])

            if exemplar_valid:
                pos_exemplars = [torch.from_numpy(np.array(exemplars)).float() for exemplars in norm_pos_exemplars]
                neg_exemplars = [torch.from_numpy(np.array(exemplars)).float() for exemplars in norm_neg_exemplars]
                pos_exemplars = torch.stack(pos_exemplars)
                neg_exemplars = torch.stack(neg_exemplars)
                batch_dict = {
                    'pos_llm_det_inputs': pos_llm_det_inputs,
                    'neg_llm_det_inputs': neg_llm_det_inputs,
                    'pos_caption': pos_caption,
                    'neg_caption': neg_caption,
                    'shapes': shapes,
                    'pos_points': pos_points,
                    'neg_points': neg_points,
                    'pos_count': pos_count,
                    'neg_count': neg_count,
                    'annotated_pos_count': annotated_pos_count,
                    'annotated_neg_count': annotated_neg_count,
                    'image': pil_image,
                    'category': category,
                    'type': sample_type,
                    'pos_exemplars': pos_exemplars,
                    'neg_exemplars': neg_exemplars,
                    'image_name': image_name,
                }
            else:
                batch_dict = {
                    'pos_llm_det_inputs': pos_llm_det_inputs,
                    'neg_llm_det_inputs': neg_llm_det_inputs,
                    'pos_caption': pos_caption,
                    'neg_caption': neg_caption,
                    'shapes': shapes,
                    'pos_points': pos_points,
                    'neg_points': neg_points,
                    'pos_count': pos_count,
                    'neg_count': neg_count,
                    'annotated_pos_count': annotated_pos_count,
                    'annotated_neg_count': annotated_neg_count,
                    'image': pil_image,
                    'category': category,
                    'type': sample_type,
                    'image_name': image_name,
                }
        else:
            batch_dict = {
                'pos_llm_det_inputs': pos_llm_det_inputs,
                'neg_llm_det_inputs': neg_llm_det_inputs,
                'pos_caption': pos_caption,
                'neg_caption': neg_caption,
                'shapes': shapes,
                'pos_points': pos_points,
                'neg_points': neg_points,
                'pos_count': pos_count,
                'neg_count': neg_count,
                'annotated_pos_count': annotated_pos_count,
                'annotated_neg_count': annotated_neg_count,
                'image': pil_image,
                'category': category,
                'type': sample_type,
                'image_name': image_name,
            }

        return batch_dict


import torch.distributed as dist


def rank0_print(*args):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)


def build_dataset(data_args):
    from datasets import load_from_disk, concatenate_datasets

    categories = ["FOO", "FUN", "OFF", "OTR", "HOU"]

    if data_args.data_split not in categories:
        rank0_print(f"Warning: Invalid data_split '{data_args.data_split}'. Switching to 'all' mode.")
        data_args.data_split = "all"

    if data_args.data_split == "all":
        train_dataset = load_from_disk(data_args.train_data_path)
        train_dataset = concatenate_datasets(
            [train_dataset["FOO"], train_dataset["FUN"], train_dataset["OFF"], train_dataset["OTR"],
             train_dataset["HOU"]])

        val_dataset = load_from_disk(data_args.val_data_path)
        val_dataset = concatenate_datasets(
            [val_dataset["FOO"], val_dataset["FUN"], val_dataset["OFF"], val_dataset["OTR"], val_dataset["HOU"]])

        test_dataset = load_from_disk(data_args.test_data_path)
        test_dataset = concatenate_datasets(
            [test_dataset["FOO"], test_dataset["FUN"], test_dataset["OFF"], test_dataset["OTR"], test_dataset["HOU"]])

        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = concatenate_datasets(
            [weakly_supervised_data["FOO"], weakly_supervised_data["FUN"], weakly_supervised_data["OFF"],
             weakly_supervised_data["OTR"], weakly_supervised_data["HOU"]])

        rank0_print("Using 'all' mode: all categories for train/val/test")

    else:
        test_category = data_args.data_split
        train_categories = [cat for cat in categories if cat != test_category]
        train_dataset = load_from_disk(data_args.train_data_path)
        print(train_categories, train_dataset.keys())
        train_datasets = [train_dataset[cat] for cat in train_categories]
        train_dataset = concatenate_datasets(train_datasets)

        weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
        weakly_supervised_data = [weakly_supervised_data[cat] for cat in train_categories]
        weakly_supervised_data = concatenate_datasets(weakly_supervised_data)

        val_dataset = load_from_disk(data_args.val_data_path)
        val_dataset = val_dataset[test_category]

        test_dataset = load_from_disk(data_args.test_data_path)
        test_dataset = test_dataset[test_category]

        rank0_print(f"Cross-validation mode: using {train_categories} for train, {test_category} for val/test")

    rank0_print('train_dataset: ', len(train_dataset))
    rank0_print('val_dataset: ', len(val_dataset))
    rank0_print('test_dataset: ', len(test_dataset))
    rank0_print('weakly_supervised_data: ', len(weakly_supervised_data))

    return train_dataset, val_dataset, test_dataset, weakly_supervised_data


def generate_pseudo_density_map(points_norm: torch.Tensor,
                                output_size: tuple[int, int],
                                sigma: float = 4.0,
                                normalize: bool = True) -> torch.Tensor:
    device = points_norm.device
    H, W = output_size
    N = points_norm.shape[0]

    ys = torch.arange(H, device=device).float()
    xs = torch.arange(W, device=device).float()
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')  # (H, W)

    pts_px = points_norm.clone()
    pts_px[:, 0] *= (W - 1)  # x
    pts_px[:, 1] *= (H - 1)  # y

    dx = grid_x.unsqueeze(0) - pts_px[:, 0].view(-1, 1, 1)  # (N, H, W)
    dy = grid_y.unsqueeze(0) - pts_px[:, 1].view(-1, 1, 1)  # (N, H, W)
    dist2 = dx ** 2 + dy ** 2
    gaussians = torch.exp(-dist2 / (2 * sigma ** 2))  # (N, H, W)
    density_map = gaussians.sum(dim=0, keepdim=True)  # (1, H, W)

    if normalize and N > 0:
        density_map = density_map * (N / density_map.sum())

    return density_map.unsqueeze(0)
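# --- Illustrative sketch (editor's addition) ------------------------------------
# With normalize=True the pseudo density map integrates (sums) to the number of
# annotated points, which is what makes it usable as a counting target:
#
#   pts = torch.tensor([[0.25, 0.50], [0.75, 0.50]])    # (x, y) in [0, 1]
#   dm = generate_pseudo_density_map(pts, output_size=(64, 48), sigma=4.0)
#   dm.shape            # torch.Size([1, 1, 64, 48])
#   dm.sum().item()     # ~2.0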
def show_density_map(density_map: torch.Tensor,
                     points_norm: torch.Tensor | None = None,
                     figsize: tuple[int, int] = (6, 8),
                     cmap: str = "jet") -> None:
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape

    plt.figure(figsize=figsize)
    plt.imshow(dm, cmap=cmap, origin="upper")
    plt.colorbar(label="Density")

    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)

    plt.title(f"Density map (sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()


def show_image_with_density(pil_img: Image.Image,
                            density_map: torch.Tensor,
                            points_norm: torch.Tensor | None = None,
                            cmap: str = "jet",
                            alpha: float = 0.45,
                            figsize: tuple[int, int] = (6, 8)) -> None:
    dm = density_map.squeeze().detach().cpu().numpy()  # (H, W)
    H, W = dm.shape

    img_resized = pil_img.resize((W, H), Image.BILINEAR)  # or LANCZOS
    img_np = np.asarray(img_resized)

    plt.figure(figsize=figsize)
    plt.imshow(img_np, origin="upper")
    plt.imshow(dm, cmap=cmap, alpha=alpha, origin="upper")

    if points_norm is not None and points_norm.numel() > 0:
        pts = points_norm.detach().cpu().numpy()
        xs = pts[:, 0] * (W - 1)
        ys = pts[:, 1] * (H - 1)
        plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)

    plt.title(f"Overlay (density sum = {dm.sum():.2f})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()


def build_point_count_map(feat_maps: torch.Tensor,
                          pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
    assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
    B, H, W, _ = feat_maps.shape
    device = feat_maps.device

    count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)

    for b in range(B):
        pts = pts_norm_list[b].to(device).float()  # (Ni, 2)
        if pts.numel() == 0:
            continue

        idx_xy = (pts * torch.tensor([W, H], device=device)).long()
        idx_xy[..., 0].clamp_(0, W - 1)  # x
        idx_xy[..., 1].clamp_(0, H - 1)  # y

        lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0]  # (Ni,)
        one = torch.ones_like(lin_idx, dtype=torch.float32)

        flat = torch.zeros(H * W, dtype=torch.float32, device=device)
        flat.scatter_add_(0, lin_idx, one)
        count_map[b] = flat.view(H, W)

    return count_map


import torch
import torch.nn.functional as F


def extract_pos_tokens_single(feat_maps: torch.Tensor,
                              count_map: torch.Tensor):
    assert feat_maps.dim() == 4 and count_map.dim() == 3, "expected shapes (B,H,W,D) / (B,H,W)"
    B, H, W, D = feat_maps.shape
    assert B == 1, "this function assumes batch_size == 1"
    feat = feat_maps[0]  # (H,W,D)
    cnt = count_map[0]  # (H,W)
    pos_mask = cnt > 0  # Bool (H,W)
    if pos_mask.sum() == 0:
        empty = torch.empty(0, device=feat.device)
        return empty.reshape(0, D), empty.long()
    pos_tokens = feat[pos_mask]  # (N_pos, D)
    y_idx, x_idx = torch.nonzero(pos_mask, as_tuple=True)
    lin_index = y_idx * W + x_idx  # (N_pos,)
    return pos_tokens, lin_index


def filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg):
    pos_only_mask = ~torch.isin(lin_pos, lin_neg)
    neg_only_mask = ~torch.isin(lin_neg, lin_pos)
    return pos_tok[pos_only_mask], neg_tok[neg_only_mask]


# ------------------------------------------------------------
# 2) supervised contrastive loss
# ------------------------------------------------------------
def supcon_pos_neg(pos_tokens, neg_tokens, temperature=0.07):
    """
    pos_tokens : (Np, D) Pos token
    neg_tokens : (Nn, D) Neg token
    """
    if pos_tokens.numel() == 0 or neg_tokens.numel() == 0:
        return torch.tensor(0., device=pos_tokens.device, requires_grad=True)
    pos_tokens = F.normalize(pos_tokens, dim=-1)
    neg_tokens = F.normalize(neg_tokens, dim=-1)
    feats = torch.cat([pos_tokens, neg_tokens], dim=0)  # (N, D)
    labels = torch.cat([torch.zeros(len(pos_tokens), device=feats.device, dtype=torch.long),
                        torch.ones(len(neg_tokens), device=feats.device, dtype=torch.long)], dim=0)  # (N,)
    logits = feats @ feats.T / temperature  # (N, N)
    logits.fill_diagonal_(-1e4)
    mask_pos = labels.unsqueeze(0) == labels.unsqueeze(1)  # (N, N)
    mask_pos.fill_diagonal_(False)
    exp_logits = logits.exp()
    denom = exp_logits.sum(dim=1, keepdim=True)  # Σ_{a≠i} exp
    log_prob = logits - denom.log()  # log softmax
    loss_i = -(mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp_min(1)
    loss = loss_i.mean()
    return loss
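# --- Illustrative sketch (editor's addition) -------------------------------------
# One way the helpers above can compose (the actual training code may differ):
# rasterise the positive / negative points onto the feature grid, pull out the
# corresponding tokens, drop cells that contain both kinds of points, then
# contrast the two sets:
#
#   feat = torch.randn(1, 32, 32, 256)                      # (B, H, W, D)
#   pos_cnt = build_point_count_map(feat, [torch.rand(10, 2)])
#   neg_cnt = build_point_count_map(feat, [torch.rand(10, 2)])
#   pos_tok, lin_pos = extract_pos_tokens_single(feat, pos_cnt)
#   neg_tok, lin_neg = extract_pos_tokens_single(feat, neg_cnt)
#   pos_tok, neg_tok = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)
#   loss = supcon_pos_neg(pos_tok, neg_tok)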
def build_point_count_map(feat_maps: torch.Tensor,
                          pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
    assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
    B, H, W, _ = feat_maps.shape
    device = feat_maps.device

    count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)

    for b in range(B):
        pts = pts_norm_list[b].to(device).float()  # (Ni, 2)
        if pts.numel() == 0:
            continue

        idx_xy = (pts * torch.tensor([W, H], device=device)).long()
        idx_xy[..., 0].clamp_(0, W - 1)  # x
        idx_xy[..., 1].clamp_(0, H - 1)  # y

        lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0]  # (Ni,)
        one = torch.ones_like(lin_idx, dtype=torch.float32)

        flat = torch.zeros(H * W, dtype=torch.float32, device=device)
        flat.scatter_add_(0, lin_idx, one)
        count_map[b] = flat.view(H, W)

    return count_map