yifehuang97 committed
Commit 74af434 · 1 Parent(s): caa3cab
app.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image, ImageDraw
5
+ from transformers import GroundingDinoProcessor
6
+ from hf_model import CountEX
7
+ from utils import post_process_grounded_object_detection
8
+
9
+ # Global variables for model and processor
10
+ model = None
11
+ processor = None
12
+ device = None
13
+
14
+
15
+ def load_model():
16
+ """Load model and processor once at startup"""
17
+ global model, processor, device
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ # Load model - change path for HF Spaces
22
+ model_id = "BBVisual/CountEX-KC" # Change to your HF model repo
23
+ model = CountEX.from_pretrained(model_id)
24
+ model = model.to(torch.bfloat16)
25
+ model = model.to(device)
26
+ model.eval()
27
+
28
+ # Load processor
29
+ processor_id = "fushh7/llmdet_swin_tiny_hf"
30
+ processor = GroundingDinoProcessor.from_pretrained(processor_id)
31
+
32
+ return model, processor, device
33
+
34
+
35
+ def count_objects(image, pos_caption, neg_caption, box_threshold, point_radius, point_color):
36
+ """
37
+ Main inference function for counting objects
38
+
39
+ Args:
40
+ image: Input PIL Image
41
+ pos_caption: Positive prompt (objects to count)
42
+ neg_caption: Negative prompt (objects to exclude)
43
+ box_threshold: Detection confidence threshold
44
+ point_radius: Radius of visualization points
45
+ point_color: Color of visualization points
46
+
47
+ Returns:
48
+ Annotated image and count
49
+ """
50
+ global model, processor, device
51
+
52
+ if model is None:
53
+ load_model()
54
+
55
+ # Ensure image is RGB
56
+ if image.mode != "RGB":
57
+ image = image.convert("RGB")
58
+
59
+ # Ensure captions end with period
60
+ if not pos_caption.endswith('.'):
61
+ pos_caption = pos_caption + '.'
62
+ if neg_caption and not neg_caption.endswith('.'):
63
+ neg_caption = neg_caption + '.'
64
+
65
+ # Process positive caption
66
+ pos_inputs = processor(
67
+ images=image,
68
+ text=pos_caption,
69
+ return_tensors="pt",
70
+ padding=True
71
+ )
72
+ pos_inputs = pos_inputs.to(device)
73
+ pos_inputs['pixel_values'] = pos_inputs['pixel_values'].to(torch.bfloat16)
74
+
75
+ # Process negative caption if provided
76
+ use_neg = bool(neg_caption and neg_caption.strip() and neg_caption != '.')
77
+
78
+ if use_neg:
79
+ neg_inputs = processor(
80
+ images=image,
81
+ text=neg_caption,
82
+ return_tensors="pt",
83
+ padding=True
84
+ )
85
+ neg_inputs = {k: v.to(device) for k, v in neg_inputs.items()}
86
+ neg_inputs['pixel_values'] = neg_inputs['pixel_values'].to(torch.bfloat16)
87
+
88
+ # Add negative inputs to positive inputs dict
89
+ pos_inputs['neg_token_type_ids'] = neg_inputs['token_type_ids']
90
+ pos_inputs['neg_attention_mask'] = neg_inputs['attention_mask']
91
+ pos_inputs['neg_pixel_mask'] = neg_inputs['pixel_mask']
92
+ pos_inputs['neg_pixel_values'] = neg_inputs['pixel_values']
93
+ pos_inputs['neg_input_ids'] = neg_inputs['input_ids']
94
+ pos_inputs['use_neg'] = True
95
+ else:
96
+ pos_inputs['use_neg'] = False
97
+
98
+ # Run inference
99
+ with torch.no_grad():
100
+ outputs = model(**pos_inputs)
101
+
102
+ # Post-process outputs
103
+ outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
104
+ outputs["pred_logits"] = outputs["logits"]
105
+
106
+ # Use custom threshold if provided, otherwise use model default
107
+ threshold = box_threshold if box_threshold > 0 else model.box_threshold
108
+ results = post_process_grounded_object_detection(outputs, box_threshold=threshold)[0]
109
+
110
+ # Extract points
111
+ boxes = results["boxes"]
112
+ boxes = [box.tolist() for box in boxes]
113
+ points = [[box[0], box[1]] for box in boxes]
114
+
115
+ # Visualize results
116
+ img_w, img_h = image.size
117
+ img_draw = image.copy()
118
+ draw = ImageDraw.Draw(img_draw)
119
+
120
+ for point in points:
121
+ x = point[0] * img_w
122
+ y = point[1] * img_h
123
+ draw.ellipse(
124
+ [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
125
+ fill=point_color
126
+ )
127
+
128
+ count = len(points)
129
+
130
+ return img_draw, f"Count: {count}"
131
+
132
+
133
+ # Create Gradio interface
134
+ def create_demo():
135
+ with gr.Blocks(title="CountEX: Discriminative Visual Counting") as demo:
136
+ gr.Markdown("""
137
+ # CountEX: Discriminative Visual Counting
138
+
139
+ Count specific objects in images using positive and negative text prompts.
140
+
141
+ **Positive Prompt**: Describe what you want to count (e.g., "Green Apple")
142
+
143
+ **Negative Prompt**: Describe what you want to exclude (e.g., "Red Apple")
144
+ """)
145
+
146
+ with gr.Row():
147
+ with gr.Column(scale=1):
148
+ input_image = gr.Image(type="pil", label="Input Image")
149
+
150
+ pos_caption = gr.Textbox(
151
+ label="Positive Prompt",
152
+ placeholder="e.g., Green Apple",
153
+ value="Green Apple"
154
+ )
155
+
156
+ neg_caption = gr.Textbox(
157
+ label="Negative Prompt (optional)",
158
+ placeholder="e.g., Red Apple",
159
+ value="Red Apple"
160
+ )
161
+
162
+ with gr.Accordion("Advanced Settings", open=False):
163
+ box_threshold = gr.Slider(
164
+ minimum=0.0,
165
+ maximum=1.0,
166
+ value=0.0,
167
+ step=0.01,
168
+ label="Detection Threshold (0 = use model default)"
169
+ )
170
+
171
+ point_radius = gr.Slider(
172
+ minimum=1,
173
+ maximum=20,
174
+ value=5,
175
+ step=1,
176
+ label="Point Radius"
177
+ )
178
+
179
+ point_color = gr.Dropdown(
180
+ choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
181
+ value="blue",
182
+ label="Point Color"
183
+ )
184
+
185
+ submit_btn = gr.Button("Count Objects", variant="primary")
186
+
187
+ with gr.Column(scale=1):
188
+ output_image = gr.Image(type="pil", label="Result")
189
+ count_output = gr.Textbox(label="Count Result")
190
+
191
+ # Example images
192
+ # gr.Examples(
193
+ # examples=[
194
+ # ["examples/apples.jpg", "Green Apple", "Red Apple"],
195
+ # ["examples/cars.jpg", "Red Car", "Blue Car"],
196
+ # ["examples/people.jpg", "Person wearing hat", "Person without hat"],
197
+ # ],
198
+ # inputs=[input_image, pos_caption, neg_caption],
199
+ # outputs=[output_image, count_output],
200
+ # fn=count_objects,
201
+ # cache_examples=False,
202
+ # )
203
+
204
+ submit_btn.click(
205
+ fn=count_objects,
206
+ inputs=[input_image, pos_caption, neg_caption, box_threshold, point_radius, point_color],
207
+ outputs=[output_image, count_output]
208
+ )
209
+
210
+ return demo
211
+
212
+
213
+ if __name__ == "__main__":
214
+ # Load model at startup
215
+ print("Loading model...")
216
+ load_model()
217
+ print("Model loaded!")
218
+
219
+ # Create and launch demo
220
+ demo = create_demo()
221
+ demo.launch()
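
For quick reference, the sketch below mirrors the demo's count_objects() flow as a standalone script, without the Gradio UI. It is illustrative only and not part of this commit: it assumes this repo's hf_model package and utils.py are importable, and "apples.jpg" is a hypothetical local image path.

import torch
from PIL import Image
from transformers import GroundingDinoProcessor
from hf_model import CountEX
from utils import post_process_grounded_object_detection

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CountEX.from_pretrained("BBVisual/CountEX-KC").to(torch.bfloat16).to(device).eval()
processor = GroundingDinoProcessor.from_pretrained("fushh7/llmdet_swin_tiny_hf")

image = Image.open("apples.jpg").convert("RGB")  # hypothetical local image
pos = processor(images=image, text="Green Apple.", return_tensors="pt").to(device)
neg = processor(images=image, text="Red Apple.", return_tensors="pt").to(device)
pos["pixel_values"] = pos["pixel_values"].to(torch.bfloat16)
neg["pixel_values"] = neg["pixel_values"].to(torch.bfloat16)

# attach the negative branch exactly as the demo does
pos["neg_input_ids"] = neg["input_ids"]
pos["neg_token_type_ids"] = neg["token_type_ids"]
pos["neg_attention_mask"] = neg["attention_mask"]
pos["neg_pixel_mask"] = neg["pixel_mask"]
pos["neg_pixel_values"] = neg["pixel_values"]
pos["use_neg"] = True

with torch.no_grad():
    outputs = model(**pos)

# the post-processor reads pred_points / pred_logits
outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
outputs["pred_logits"] = outputs["logits"]
results = post_process_grounded_object_detection(outputs, box_threshold=model.box_threshold)[0]
print("Count:", len(results["boxes"]))
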
hf_model/CountEX.py ADDED
@@ -0,0 +1,543 @@
1
+ # coding=utf-8
2
+ """
3
+ Negative Grounding DINO Model for Object Detection with Negative Caption Support.
4
+
5
+ This module extends the original GroundingDinoForObjectDetection to support negative captions
6
+ for improved object detection performance.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing import Dict, List, Optional, Tuple, Union
12
+ from transformers.modeling_outputs import ModelOutput
13
+ import torch.nn.functional as F
14
+ from .modeling_grounding_dino import (
15
+ GroundingDinoForObjectDetection,
16
+ GroundingDinoObjectDetectionOutput,
17
+ GroundingDinoEncoderOutput,
18
+ )
19
+
20
+
21
+ # density_fpn_head.py
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ def _bilinear(x, size):
28
+ return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
29
+
30
+
31
+ class DensityFPNHead(nn.Module):
32
+ def __init__(self,
33
+ in_channels: int = 512,
34
+ mid_channels: int = 128,
35
+ act_layer=nn.ReLU,
36
+ norm_layer=nn.BatchNorm2d):
37
+ super().__init__()
38
+
39
+ # ---- 1×1 lateral convs (P3–P6) ----
40
+ self.lateral = nn.ModuleList([
41
+ nn.Conv2d(in_channels, mid_channels, 1) for _ in range(4)
42
+ ])
43
+
44
+ # ---- smooth convs after add ----
45
+ self.smooth = nn.ModuleList([
46
+ nn.Sequential(
47
+ nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
48
+ norm_layer(mid_channels),
49
+ act_layer(inplace=True),
50
+ ) for _ in range(3) # P6→P5, P5→P4, P4→P3
51
+ ])
52
+
53
+ self.up_blocks = nn.ModuleList([
54
+ nn.Sequential(
55
+ act_layer(inplace=True),
56
+ nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False),
57
+ norm_layer(mid_channels),
58
+ act_layer(inplace=True),
59
+ ) for _ in range(3) # 167×94 → … → 1336×752
60
+ ])
61
+
62
+ # ---- output 3×3 conv -> 1 ----
63
+ self.out_conv = nn.Conv2d(mid_channels, 1, 3, padding=1, bias=False)
64
+
65
+ def forward(self, feats):
66
+ assert len(feats) == 4, "Expect feats list = [P3,P4,P5,P6]"
67
+
68
+ # lateral 1×1
69
+ lat = [l(f) for l, f in zip(self.lateral, feats)]
70
+
71
+ # top-down FPN fusion
72
+ x = lat[-1] # P6
73
+ for i in range(3)[::-1]: # P5,P4,P3
74
+ x = _bilinear(x, lat[i].shape[-2:])
75
+ x = x + lat[i]
76
+ x = self.smooth[i](x)
77
+
78
+ # three-stage upsample + conv
79
+ for up in self.up_blocks:
80
+ h, w = x.shape[-2], x.shape[-1]
81
+ x = _bilinear(x, (h * 2, w * 2))
82
+ x = up(x)
83
+
84
+ x = self.out_conv(x)
85
+ return F.relu(x)
86
+
87
+
88
+ import torch
89
+ import torch.nn as nn
90
+ import torch.nn.functional as F
91
+
92
+ def l2norm(x, dim=-1, eps=1e-6):
93
+ return x / (x.norm(dim=dim, keepdim=True) + eps)
94
+
95
+ # -----------------------------------
96
+ # 1) CommonFinderSimple
97
+ # learn r "common prototypes" that represent what the positive and negative queries share
98
+ # nothing fancy: just MHA pooling + two light regularizations (shareability + diversity)
99
+ # -----------------------------------
100
+ class CommonFinderSimple(nn.Module):
101
+ """
102
+ Inputs:
103
+ Q_pos: [B, K, D]
104
+ Q_neg: [B, K, D]
105
+ Returns:
106
+ C_rows: [B, r, D] # r common prototypes copied across the batch (unit-normalized)
107
+ loss: scalar # small regularization: shareability + diversity
108
+ stats: dict
109
+ """
110
+ def __init__(self, d_model=256, r=64, nhead=4,
111
+ share_w=0.02, div_w=0.02, ln_after=False):
112
+ super().__init__()
113
+ self.r = r
114
+ self.share_w = share_w
115
+ self.div_w = div_w
116
+
117
+ proto = torch.randn(r, d_model)
118
+ self.proto = nn.Parameter(l2norm(proto, -1)) # r×D learnable "core queries"
119
+ self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
120
+ self.post = nn.Linear(d_model, d_model)
121
+ self.ln = nn.LayerNorm(d_model) if ln_after else nn.Identity()
122
+
123
+ def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
124
+ B, K, D = Q_pos.shape
125
+ seeds = self.proto[None].expand(B, -1, -1).contiguous() # [B,r,D]
126
+ X = torch.cat([Q_pos, Q_neg], dim=1) # [B,2K,D]
127
+
128
+ # use seeds to do one attention pooling on positive and negative sets, get r "common prototypes"
129
+ C, _ = self.attn(query=seeds, key=X, value=X) # [B,r,D]
130
+ C = l2norm(self.ln(self.post(C)), -1) # unit-normalize
131
+
132
+ # ---- Simple regularization: encourage C to be close to both Q_pos and Q_neg, and diverse from each other ----
133
+ # Shareability: average of maximum cosine similarity between C and Q_pos/Q_neg
134
+ cos_pos = torch.einsum('brd,bkd->brk', C, l2norm(Q_pos, -1)) # [B,r,K]
135
+ cos_neg = torch.einsum('brd,bkd->brk', C, l2norm(Q_neg, -1))
136
+ share_term = -(cos_pos.amax(dim=-1).mean() + cos_neg.amax(dim=-1).mean())
137
+
138
+ # Diversity: cosine between C should not collapse
139
+ C0 = l2norm(self.proto, -1) # [r,D]
140
+ gram = C0 @ C0.t() # [r,r]
141
+ div_term = (gram - torch.eye(self.r, device=gram.device)).pow(2).mean()
142
+
143
+ loss = self.share_w * share_term + self.div_w * div_term
144
+ stats = {
145
+ 'share_term': share_term.detach(),
146
+ 'div_term': div_term.detach(),
147
+ 'mean_cos_pos': cos_pos.mean().detach(),
148
+ 'mean_cos_neg': cos_neg.mean().detach()
149
+ }
150
+ return C, loss, stats
151
+
152
+
153
+ # -----------------------------------
154
+ # 2) NegExclusiveSimple
155
+ # Remove "common" information from negative queries: two simple strategies can be used independently or together
156
+ # (A) Soft removal: subtract the projection onto C (residual keeps non-common)
157
+ # (B) Filtering: only keep the Top-M negative samples least similar to C
158
+ # -----------------------------------
159
+ class NegExclusiveSimple(nn.Module):
160
+ """
161
+ Inputs:
162
+ Q_neg: [B,K,D]
163
+ C_rows: [B,r,D] # common prototypes
164
+ Args:
165
+ mode: 'residual' | 'filter' | 'both'
166
+ M: Top-M for 'filter'
167
+ thresh: Filter threshold (max_cos_neg < thresh to keep), None means only use Top-M
168
+ Returns:
169
+ neg_refs: [B, M_or_K, D] # as negative reference (for next fusion)
170
+ aux: dict
171
+ """
172
+ def __init__(self, mode='residual', M=16, thresh=None):
173
+ super().__init__()
174
+ assert mode in ('residual', 'filter', 'both')
175
+ self.mode = mode
176
+ self.M = M
177
+ self.thresh = thresh
178
+
179
+ def forward(self, Q_neg: torch.Tensor, C_rows: torch.Tensor):
180
+ B, K, D = Q_neg.shape
181
+ r = C_rows.size(1)
182
+ Qn = l2norm(Q_neg, -1)
183
+ C = l2norm(C_rows, -1)
184
+
185
+ sim = torch.einsum('bkd,brd->bkr', Qn, C).amax(dim=-1) # [B,K]
186
+
187
+ outputs = {}
188
+ if self.mode in ('residual', 'both'):
189
+ # proj = (Q · C^T) C -> [B,K,D]; first weight [B,K,r], then multiply C [B,r,D]
190
+ w = torch.einsum('bkd,brd->bkr', Qn, C) # [B,K,r]
191
+ proj = torch.einsum('bkr,brd->bkd', w, C) # [B,K,D]
192
+ neg_resid = l2norm(Qn - proj, -1) # non-common residual
193
+ outputs['residual'] = neg_resid
194
+
195
+ if self.mode in ('filter', 'both'):
196
+ excl_score = 1.0 - sim # large = away from common
197
+ if self.thresh is not None:
198
+ mask = (sim < self.thresh).float()
199
+ excl_score = excl_score * mask + (-1e4) * (1 - mask)
200
+ M = min(self.M, K)
201
+ topv, topi = torch.topk(excl_score, k=M, dim=1) # [B,M]
202
+ neg_top = torch.gather(Qn, 1, topi.unsqueeze(-1).expand(-1, -1, D))
203
+ outputs['filtered'] = neg_top
204
+
205
+ if self.mode == 'residual':
206
+ neg_refs = outputs['residual']
207
+ elif self.mode == 'filter':
208
+ neg_refs = outputs['filtered']
209
+ else:
210
+ R = outputs['residual'] # [B,K,D]
211
+ excl_score = 1.0 - sim
212
+ M = min(self.M, K)
213
+ topv, topi = torch.topk(excl_score, k=M, dim=1)
214
+ neg_refs = torch.gather(R, 1, topi.unsqueeze(-1).expand(-1, -1, D)) # [B,M,D]
215
+
216
+ aux = {
217
+ 'mean_sim_to_common': sim.mean().detach(),
218
+ 'kept_M': neg_refs.size(1)
219
+ }
220
+ return neg_refs, aux
221
+
222
+ import torch
223
+ import torch.nn as nn
224
+ import torch.nn.functional as F
225
+
226
+ def l2norm(x, dim=-1, eps=1e-6):
227
+ return x / (x.norm(dim=dim, keepdim=True) + eps)
228
+
229
+ class FusionNoGate(nn.Module):
230
+ """
231
+ Direct fusion (no gating): fuse neg_ref into Q_pos via one cross-attn.
232
+ Variants:
233
+ - 'residual_sub': Q_new = Q_pos - scale * LN(Z)
234
+ - 'residual_add': Q_new = Q_pos + scale * LN(Z)
235
+ - 'concat_linear': Q_new = Q_pos + Linear([Q_pos; Z])
236
+ """
237
+ def __init__(self, d_model=256, nhead=4, fusion_mode='residual_sub',
238
+ init_scale=0.2, dropout_p=0.0):
239
+ super().__init__()
240
+ assert fusion_mode in ('residual_sub', 'residual_add', 'concat_linear')
241
+ self.fusion_mode = fusion_mode
242
+ self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
243
+ self.ln_z = nn.LayerNorm(d_model)
244
+ self.drop = nn.Dropout(dropout_p) if dropout_p > 0 else nn.Identity()
245
+ self.scale = nn.Parameter(torch.tensor(float(init_scale)))
246
+ if fusion_mode == 'concat_linear':
247
+ self.mix = nn.Linear(2 * d_model, d_model)
248
+ nn.init.zeros_(self.mix.weight)
249
+ nn.init.zeros_(self.mix.bias)
250
+
251
+ def forward(self, Q_pos: torch.Tensor, neg_ref: torch.Tensor):
252
+ """
253
+ Q_pos: [B, K, D]
254
+ neg_ref: [B, M, D]
255
+ return: Q_new [B, K, D], stats dict
256
+ """
257
+ B, K, D = Q_pos.shape
258
+ M = neg_ref.size(1)
259
+ if M == 0:
260
+ return Q_pos, {'kept': 0, 'scale': self.scale.detach()}
261
+
262
+ # 1) Cross-attention:
263
+ Z, attn_w = self.attn(query=Q_pos, key=neg_ref, value=neg_ref) # Z:[B,K,D]
264
+ Z = self.ln_z(Z)
265
+ Z = self.drop(Z)
266
+
267
+ # 2) fuse without gating
268
+ if self.fusion_mode == 'residual_sub':
269
+ Q_new = Q_pos - self.scale * Z
270
+ # print("z: ", Z.sum())
271
+ # print(torch.abs(Q_new - Q_pos).sum())
272
+ elif self.fusion_mode == 'residual_add':
273
+ Q_new = Q_pos + self.scale * Z
274
+ else: # 'concat_linear'
275
+ fused = torch.cat([Q_pos, Z], dim=-1) # [B,K,2D]
276
+ delta = self.mix(fused) # [B,K,D]
277
+ Q_new = Q_pos + delta
278
+
279
+ stats = {
280
+ 'kept': M,
281
+ 'attn_mean': attn_w.mean().detach(),
282
+ 'fusion_scale': self.scale.detach()
283
+ }
284
+ return Q_new, stats
285
+
286
+ class QuerySideNegNaive(nn.Module):
287
+ def __init__(self, d_model=256, r=64, M=64, nhead=4,
288
+ excl_mode='both', excl_thresh=0.5, gamma_max=0.7,
289
+ share_w=0.02, div_w=0.02):
290
+ super().__init__()
291
+ self.common = CommonFinderSimple(d_model, r, nhead, share_w, div_w)
292
+ self.excl = NegExclusiveSimple(mode=excl_mode, M=M, thresh=excl_thresh)
293
+ self.fuse = FusionNoGate(d_model=d_model,
294
+ nhead=4,
295
+ fusion_mode='residual_sub', # or 'concat_linear'
296
+ init_scale=0.25,
297
+ dropout_p=0.1)
298
+
299
+ def forward(self, Q_pos: torch.Tensor, Q_neg: torch.Tensor):
300
+ C_rows, l_common, common_stats = self.common(Q_pos, Q_neg)
301
+ neg_refs, excl_stats = self.excl(Q_neg, C_rows)
302
+ Q_new, fuse_stats = self.fuse(Q_pos, neg_refs)
303
+ loss = l_common
304
+ stats = {}
305
+ stats.update(common_stats); stats.update(excl_stats); stats.update(fuse_stats)
306
+ return Q_new, loss, stats
307
+
308
+ def set_fusion_scale(self, scale: float):
309
+ del self.fuse.scale
310
+ self.fuse.scale = nn.Parameter(torch.tensor(scale))
311
+
312
+
313
+ class CountEX(GroundingDinoForObjectDetection):
314
+ """
315
+ Grounding DINO Model with negative caption support for improved object detection.
316
+
317
+ This model extends the original GroundingDinoForObjectDetection by adding
318
+ support for negative captions, which helps improve detection accuracy by
319
+ learning what NOT to detect.
320
+ """
321
+
322
+ def __init__(self, config):
323
+ super().__init__(config)
324
+
325
+ # Initialize negative fusion modules directly in __init__
326
+ self.query_side_neg_pipeline = QuerySideNegNaive()
327
+ self.density_head = DensityFPNHead()
328
+ self.config = config
329
+ self.box_threshold = getattr(config, 'box_threshold', 0.4)
330
+
331
+ def forward(
332
+ self,
333
+ pixel_values: torch.FloatTensor,
334
+ input_ids: torch.LongTensor,
335
+ token_type_ids: torch.LongTensor = None,
336
+ attention_mask: torch.LongTensor = None,
337
+ pixel_mask: Optional[torch.BoolTensor] = None,
338
+ encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
339
+ output_attentions: Optional[bool] = None,
340
+ output_hidden_states: Optional[bool] = None,
341
+ return_dict: Optional[bool] = None,
342
+ labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
343
+ # Negative prompt parameters
344
+ neg_pixel_values: Optional[torch.FloatTensor] = None,
345
+ neg_input_ids: Optional[torch.LongTensor] = None,
346
+ neg_token_type_ids: Optional[torch.LongTensor] = None,
347
+ neg_attention_mask: Optional[torch.LongTensor] = None,
348
+ neg_pixel_mask: Optional[torch.BoolTensor] = None,
349
+ **kwargs,
350
+ ):
351
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
+ use_neg = kwargs.get('use_neg', True)
353
+ # Get positive outputs
354
+ pos_kwargs = {
355
+ 'exemplars': kwargs.get('pos_exemplars', None),
356
+ }
357
+ outputs = self.model(
358
+ pixel_values=pixel_values,
359
+ input_ids=input_ids,
360
+ token_type_ids=token_type_ids,
361
+ attention_mask=attention_mask,
362
+ pixel_mask=pixel_mask,
363
+ encoder_outputs=encoder_outputs,
364
+ output_attentions=output_attentions,
365
+ output_hidden_states=output_hidden_states,
366
+ return_dict=return_dict,
367
+ **pos_kwargs,
368
+ )
369
+
370
+ spatial_shapes = outputs.spatial_shapes
371
+ token_num = 0
372
+ token_num_list = [0]
373
+ for i in range(len(spatial_shapes)):
374
+ token_num += spatial_shapes[i][0] * spatial_shapes[i][1]
375
+ token_num_list.append(token_num.item())
376
+
377
+ positive_feature_maps = []
378
+ encoder_last_hidden_state_vision = outputs.encoder_last_hidden_state_vision
379
+ for i in range(len(spatial_shapes)):
380
+ feature_map = encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i+1], :]
381
+ spatial_shape = spatial_shapes[i]
382
+ b, t, d = feature_map.shape
383
+ feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
384
+ positive_feature_maps.append(feature_map)
385
+
386
+ # Get negative outputs
387
+ neg_kwargs = {
388
+ 'exemplars': kwargs.get('neg_exemplars', None),
389
+ }
390
+ # print(kwargs)
391
+ neg_outputs = self.model(
392
+ pixel_values=neg_pixel_values,
393
+ input_ids=neg_input_ids,
394
+ token_type_ids=neg_token_type_ids,
395
+ attention_mask=neg_attention_mask,
396
+ pixel_mask=neg_pixel_mask,
397
+ encoder_outputs=encoder_outputs,
398
+ output_attentions=output_attentions,
399
+ output_hidden_states=output_hidden_states,
400
+ return_dict=return_dict,
401
+ **neg_kwargs,
402
+ )
403
+
404
+ neg_encoder_last_hidden_state_vision = neg_outputs.encoder_last_hidden_state_vision
405
+ neg_positive_feature_maps = []
406
+ for i in range(len(spatial_shapes)):
407
+ feature_map = neg_encoder_last_hidden_state_vision[:, token_num_list[i]:token_num_list[i+1], :]
408
+ spatial_shape = spatial_shapes[i]
409
+ b, t, d = feature_map.shape
410
+ feature_map = feature_map.reshape(b, spatial_shape[0], spatial_shape[1], d)
411
+ neg_positive_feature_maps.append(feature_map)
412
+
413
+ if return_dict:
414
+ hidden_states = outputs.intermediate_hidden_states
415
+ neg_hidden_states = neg_outputs.intermediate_hidden_states
416
+ else:
417
+ hidden_states = outputs[2]
418
+ neg_hidden_states = neg_outputs[2]
419
+
420
+ idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
421
+ enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
422
+ hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
423
+ init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
424
+ inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]
425
+
426
+ # drop the exemplar tokens if used
427
+ pos_exemplars = pos_kwargs.get('exemplars', None)
428
+ neg_exemplars = neg_kwargs.get('exemplars', None)
429
+ if pos_exemplars is not None or neg_exemplars is not None or attention_mask.shape[1] != enc_text_hidden_state.shape[1]:
430
+ enc_text_hidden_state = enc_text_hidden_state[:, :enc_text_hidden_state.shape[1] - 3, :]
431
+
432
+ # class logits + predicted bounding boxes
433
+ outputs_classes = []
434
+ outputs_coords = []
435
+
436
+ # Apply negative fusion
437
+ if use_neg:
438
+ # print("Using negative fusions")
439
+ #neg_hidden_states = self.negative_semantic_extractor(neg_hidden_states)
440
+ #hidden_states = self.negative_fusion_module(hidden_states, neg_hidden_states)
441
+ hidden_states = hidden_states.squeeze(0)
442
+ neg_hidden_states = neg_hidden_states.squeeze(0)
443
+ hidden_states, extra_loss, logs = self.query_side_neg_pipeline(hidden_states, neg_hidden_states)
444
+ hidden_states = hidden_states.unsqueeze(0)
445
+ neg_hidden_states = neg_hidden_states.unsqueeze(0)
446
+ # print("extra_loss: ", extra_loss)
447
+ else:
448
+ # print("Not using negative fusions")
449
+ extra_loss = None
450
+ logs = None
451
+ # print("Not using negative fusion")
452
+ # print("extra_loss: ", extra_loss)
453
+
454
+ # predict class and bounding box deltas for each stage
455
+ num_levels = hidden_states.shape[1]
456
+ for level in range(num_levels):
457
+ if level == 0:
458
+ reference = init_reference_points
459
+ else:
460
+ reference = inter_references_points[:, level - 1]
461
+ reference = torch.special.logit(reference, eps=1e-5)
462
+
463
+ # print("hidden_states[:, level]: ", hidden_states[:, level].shape)
464
+ # print("enc_text_hidden_state: ", enc_text_hidden_state.shape)
465
+ # print("attention_mask: ", attention_mask.shape)
466
+
467
+ assert attention_mask.shape[1] == enc_text_hidden_state.shape[1], "Attention mask and text hidden state have different lengths: {} != {}".format(attention_mask.shape[1], enc_text_hidden_state.shape[1])
468
+ outputs_class = self.class_embed[level](
469
+ vision_hidden_state=hidden_states[:, level],
470
+ text_hidden_state=enc_text_hidden_state,
471
+ text_token_mask=attention_mask.bool(),
472
+ )
473
+ delta_bbox = self.bbox_embed[level](hidden_states[:, level])
474
+
475
+ reference_coordinates = reference.shape[-1]
476
+ if reference_coordinates == 4:
477
+ outputs_coord_logits = delta_bbox + reference
478
+ elif reference_coordinates == 2:
479
+ delta_bbox[..., :2] += reference
480
+ outputs_coord_logits = delta_bbox
481
+ else:
482
+ raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
483
+ outputs_coord = outputs_coord_logits.sigmoid()
484
+ outputs_classes.append(outputs_class)
485
+ outputs_coords.append(outputs_coord)
486
+ outputs_class = torch.stack(outputs_classes)
487
+ outputs_coord = torch.stack(outputs_coords)
488
+
489
+ logits = outputs_class[-1]
490
+ pred_boxes = outputs_coord[-1]
491
+
492
+ loss, loss_dict, auxiliary_outputs = None, None, None
493
+ if not return_dict:
494
+ if auxiliary_outputs is not None:
495
+ output = (logits, pred_boxes) + auxiliary_outputs + outputs
496
+ else:
497
+ output = (logits, pred_boxes) + outputs
498
+ tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
499
+
500
+ return tuple_outputs
501
+
502
+ all_feats = []
503
+ for pf, npf in zip(positive_feature_maps, neg_positive_feature_maps):
504
+ pf = pf.permute(0, 3, 1, 2)
505
+ npf = npf.permute(0, 3, 1, 2)
506
+ all_feats.append(torch.cat([pf, npf], dim=1))
507
+
508
+
509
+ # pos_feat = positive_feature_maps[0].permute(0, 3, 1, 2)
510
+ # neg_feat = neg_positive_feature_maps[0].permute(0, 3, 1, 2)
511
+ # pos_minus_neg_feat = F.relu(pos_feat - neg_feat)
512
+ # density_feat_map = torch.cat([pos_feat, neg_feat, pos_minus_neg_feat], dim=1)
513
+ # density_feat_map = torch.cat([pos_feat, neg_feat], dim=1)
514
+ density_map_pred = self.density_head(all_feats)
515
+
516
+ dict_outputs = GroundingDinoObjectDetectionOutput(
517
+ loss=loss,
518
+ loss_dict=loss_dict,
519
+ logits=logits,
520
+ pred_boxes=pred_boxes,
521
+ last_hidden_state=outputs.last_hidden_state,
522
+ auxiliary_outputs=auxiliary_outputs,
523
+ decoder_hidden_states=outputs.decoder_hidden_states,
524
+ decoder_attentions=outputs.decoder_attentions,
525
+ encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
526
+ encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
527
+ encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
528
+ encoder_text_hidden_states=outputs.encoder_text_hidden_states,
529
+ encoder_attentions=outputs.encoder_attentions,
530
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
531
+ intermediate_reference_points=outputs.intermediate_reference_points,
532
+ init_reference_points=outputs.init_reference_points,
533
+ enc_outputs_class=outputs.enc_outputs_class,
534
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
535
+ spatial_shapes=outputs.spatial_shapes,
536
+ positive_feature_maps=positive_feature_maps,
537
+ negative_feature_maps=neg_positive_feature_maps,
538
+ density_map_pred=density_map_pred,
539
+ extra_loss=extra_loss,
540
+ extra_logs=logs,
541
+ )
542
+
543
+ return dict_outputs
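
The query-side pipeline above (CommonFinderSimple → NegExclusiveSimple → FusionNoGate) only re-weights query embeddings, so it can be sanity-checked in isolation. Below is a hypothetical shape check, not part of this commit, using random tensors and the default hyperparameters (d_model=256, r=64, M=64); it assumes the hf_model package is importable.

import torch
from hf_model.CountEX import QuerySideNegNaive

pipeline = QuerySideNegNaive().eval()
Q_pos = torch.randn(1, 900, 256)  # [B, K, D] positive decoder queries
Q_neg = torch.randn(1, 900, 256)  # [B, K, D] negative decoder queries

with torch.no_grad():
    Q_new, reg_loss, stats = pipeline(Q_pos, Q_neg)

print(Q_new.shape)        # torch.Size([1, 900, 256]) -- query shape is unchanged
print(float(reg_loss))    # shareability + diversity regularizer (used at train time)
print(stats["kept_M"])    # number of negative references kept, min(M, K) = 64
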
hf_model/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # coding=utf-8
2
+ """
3
+ HF Model package for Grounding DINO with negative caption support.
4
+ """
5
+
6
+ from .modeling_grounding_dino import (
7
+ GroundingDinoForObjectDetection,
8
+ )
9
+
10
+ from .CountEX import (
11
+ CountEX
12
+ )
13
+
14
+ __all__ = [
15
+ "CountEX",
16
+ ]
hf_model/mmdet2groundingdino_swinb.py ADDED
@@ -0,0 +1,259 @@
1
+ # mmdet to groundingdino
2
+ import argparse
3
+ from collections import OrderedDict
4
+ import torch
5
+ from mmengine.runner import CheckpointLoader
6
+
7
+ # convert the functions from mmdet to groundingdino
8
+ def correct_unfold_reduction_order(x):
9
+ out_channel, in_channel = x.shape
10
+ x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
11
+ x = x[:, [0, 2, 1, 3], :]
12
+ x = x.reshape(out_channel, in_channel)
13
+ return x
14
+
15
+ def correct_unfold_norm_order(x):
16
+ in_channel = x.shape[0]
17
+ x = x.reshape(in_channel // 4, 4).transpose(0, 1)
18
+ x = x[[0, 2, 1, 3], :]
19
+ x = x.reshape(in_channel)
20
+ return x
21
+
22
+ def convert(ckpt):
23
+ """Inverse mapping of checkpoint parameters to their original names."""
24
+ # Create a dictionary to hold the reversed checkpoint
25
+ new_ckpt = OrderedDict()
26
+
27
+ for k, v in list(ckpt.items()):
28
+ new_v = v # Start with the original value
29
+
30
+ # Inverse rules based on the convert function (from specific to general)
31
+ if k.startswith('decoder'):
32
+ new_k = k.replace('decoder', 'transformer.decoder')
33
+ if 'norms.2' in new_k:
34
+ new_k = new_k.replace('norms.2', 'norm1')
35
+ if 'norms.1' in new_k:
36
+ new_k = new_k.replace('norms.1', 'catext_norm')
37
+ if 'norms.0' in new_k:
38
+ new_k = new_k.replace('norms.0', 'norm2')
39
+ if 'norms.3' in new_k:
40
+ new_k = new_k.replace('norms.3', 'norm3')
41
+ if 'cross_attn_text' in new_k:
42
+ new_k = new_k.replace('cross_attn_text', 'ca_text')
43
+ new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
44
+ new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
45
+ new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
46
+ new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
47
+ if 'ffn.layers.0.0' in new_k:
48
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
49
+ if 'ffn.layers.1' in new_k:
50
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
51
+ if 'self_attn.attn' in new_k:
52
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
53
+
54
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
55
+
56
+ #########################################################################
57
+
58
+ # the last reg_layer_id of the encoder part is 6, which distinguishes it from the decoder
59
+ elif k.startswith('bbox_head.reg_branches.6'):
60
+ if k.startswith('bbox_head.reg_branches.6.0'):
61
+ new_k = k.replace('bbox_head.reg_branches.6.0',
62
+ 'transformer.enc_out_bbox_embed.layers.0')
63
+ if k.startswith('bbox_head.reg_branches.6.2'):
64
+ new_k = k.replace('bbox_head.reg_branches.6.2',
65
+ 'transformer.enc_out_bbox_embed.layers.1')
66
+ if k.startswith('bbox_head.reg_branches.6.4'):
67
+ new_k = k.replace('bbox_head.reg_branches.6.4',
68
+ 'transformer.enc_out_bbox_embed.layers.2')
69
+
70
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
71
+
72
+ #########################################################################
73
+
74
+ elif k.startswith('query_embedding'):
75
+ new_k = k.replace('query_embedding', 'transformer.tgt_embed')
76
+
77
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
78
+
79
+ #########################################################################
80
+
81
+ elif k.startswith('bbox_head.reg_branches'):
82
+ # mmdet omits part of the parameter name; check the groundingdino checkpoint
83
+ # groundingdino has two groups of parameters with identical values,
84
+ # namely bbox_embed and transformer.decoder.embed,
85
+ # so mmdet simply "merged" the two groups of parameters
86
+ reg_layer_id = int(k.split('.')[2])
87
+ linear_id = int(k.split('.')[3])
88
+ weight_or_bias = k.split('.')[-1]
89
+ new_k1 = 'transformer.decoder.bbox_embed.' + \
90
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
91
+ new_k2 = 'bbox_embed.' + \
92
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
93
+
94
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
95
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
96
+
97
+ #########################################################################
98
+
99
+ elif k.startswith('bbox_head.cls_branches.6'):
100
+ # mmdet adds a bias term in contrastive_embed
101
+ # but the decoder uses indices 0~5, so index 6 corresponds to enc_out.class_embed after two-stage fine-tuning
102
+ new_k = 'transformer.enc_out_class_embed.bias'
103
+
104
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
105
+
106
+ #########################################################################
107
+
108
+ elif k.startswith('bbox_head.cls_branches'):
109
+ # mmdet adds a bias term in contrastive_embed
110
+ new_k1 = 'transformer.decoder.class_embed.' + k[-6:]
111
+ new_k2 = 'class_embed.' + k[-6:]
112
+
113
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
114
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
115
+
116
+ #########################################################################
117
+
118
+ elif k.startswith('memory_trans_'):
119
+ if k.startswith('memory_trans_fc'):
120
+ new_k = k.replace('memory_trans_fc', 'transformer.enc_output')
121
+ elif k.startswith('memory_trans_norm'):
122
+ new_k = k.replace('memory_trans_norm', 'transformer.enc_output_norm')
123
+
124
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
125
+
126
+ #########################################################################
127
+
128
+ elif k.startswith('encoder'):
129
+ new_k = k.replace('encoder', 'transformer.encoder')
130
+ new_k = new_k.replace('norms.0', 'norm1')
131
+ new_k = new_k.replace('norms.1', 'norm2')
132
+ new_k = new_k.replace('norms.2', 'norm3')
133
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
134
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
135
+ if 'text_layers' in new_k:
136
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
137
+
138
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
139
+
140
+ #########################################################################
141
+
142
+ elif k.startswith('level_embed'):
143
+ new_k = k.replace('level_embed', 'transformer.level_embed')
144
+
145
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
146
+
147
+ #########################################################################
148
+
149
+ elif k.startswith('neck.convs'):
150
+ new_k = k.replace('neck.convs', 'input_proj')
151
+ new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
152
+ new_k = new_k.replace('conv.weight', '0.weight')
153
+ new_k = new_k.replace('conv.bias', '0.bias')
154
+ new_k = new_k.replace('gn.weight', '1.weight')
155
+ new_k = new_k.replace('gn.bias', '1.bias')
156
+
157
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
158
+
159
+ #########################################################################
160
+
161
+ elif 'neck.extra_convs.0' in k:
162
+ new_k = k.replace('neck.extra_convs.0', 'neck.convs.3')
163
+ new_k = new_k.replace('neck.convs', 'input_proj')
164
+ new_k = new_k.replace('conv.weight', '0.weight')
165
+ new_k = new_k.replace('conv.bias', '0.bias')
166
+ new_k = new_k.replace('gn.weight', '1.weight')
167
+ new_k = new_k.replace('gn.bias', '1.bias')
168
+
169
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
170
+
171
+ #########################################################################
172
+
173
+ elif k.startswith('text_feat_map'):
174
+ new_k = k.replace('text_feat_map', 'feat_map')
175
+
176
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
177
+
178
+ #########################################################################
179
+
180
+ elif k.startswith('language_model.language_backbone.body.model'):
181
+ new_k = k.replace('language_model.language_backbone.body.model', 'bert')
182
+
183
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
184
+
185
+ #########################################################################
186
+
187
+ elif k.startswith('backbone'):
188
+ new_k = k.replace('backbone', 'backbone.0')
189
+ if 'patch_embed.projection' in new_k:
190
+ new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
191
+ elif 'drop_after_pos' in new_k:
192
+ new_k = new_k.replace('drop_after_pos', 'pos_drop')
193
+
194
+ if 'stages' in new_k:
195
+ new_k = new_k.replace('stages', 'layers')
196
+ if 'ffn.layers.0.0' in new_k:
197
+ new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
198
+ elif 'ffn.layers.1' in new_k:
199
+ new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
200
+ elif 'attn.w_msa' in new_k:
201
+ new_k = new_k.replace('attn.w_msa', 'attn')
202
+
203
+ if 'downsample' in k:
204
+ if 'reduction.' in k:
205
+ new_v = correct_unfold_reduction_order(v)
206
+ elif 'norm.' in k:
207
+ new_v = correct_unfold_norm_order(v)
208
+
209
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
210
+
211
+ #########################################################################
212
+
213
+ else:
214
+ print('skip:', k)
215
+ continue
216
+
217
+ # if 'transformer.decoder.bbox_embed' in new_k:
218
+ # new_k = new_k.replace('transformer.decoder.bbox_embed', 'bbox_embed')
219
+ # if new_k.startswith('module.'):
220
+ # new_k = new_k.replace('module.', '')
221
+
222
+ return new_ckpt
223
+
224
+ def main():
225
+ parser = argparse.ArgumentParser(
226
+ description='Convert keys to GroundingDINO style.')
227
+ parser.add_argument(
228
+ 'src',
229
+ nargs='?',
230
+ default='grounding_dino_swin-b_pretrain_all-f9818a7c.pth',
231
+ help='src model path or url')
232
+ # The dst path must be a full path of the new checkpoint.
233
+ parser.add_argument(
234
+ 'dst',
235
+ nargs='?',
236
+ default='mmdet_swinb_cogcoor.pth_groundingdino.pth',
237
+ help='save path')
238
+ args = parser.parse_args()
239
+
240
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
241
+
242
+ # mmdet stores 'state_dict' rather than 'model'
243
+ if 'state_dict' in checkpoint:
244
+ state_dict = checkpoint['state_dict']
245
+ else:
246
+ state_dict = checkpoint
247
+
248
+ weight = convert(state_dict)
249
+ torch.save(weight, args.dst)
250
+ # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
251
+ # sha = calculate_sha256(args.dst)
252
+ # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
253
+ # subprocess.Popen(['mv', args.dst, final_file])
254
+ print(f'Done! Saved to {args.dst}')
255
+
256
+ if __name__ == '__main__':
257
+ main()
258
+
259
+ # skip: dn_query_generator.label_embedding.weight
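
The same conversion can also be driven in-process instead of via the CLI defaults above. A hypothetical sketch, not part of this commit, assuming mmengine is installed and the hf_model package is importable; the paths are the script's own defaults.

import torch
from mmengine.runner import CheckpointLoader
from hf_model.mmdet2groundingdino_swinb import convert

checkpoint = CheckpointLoader.load_checkpoint(
    "grounding_dino_swin-b_pretrain_all-f9818a7c.pth", map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)  # mmdet stores weights under 'state_dict'
torch.save(convert(state_dict), "mmdet_swinb_cogcoor.pth_groundingdino.pth")
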
hf_model/mmdet2groundingdino_swinl.py ADDED
@@ -0,0 +1,259 @@
1
+ # mmdet to groundingdino
2
+ import argparse
3
+ from collections import OrderedDict
4
+ import torch
5
+ from mmengine.runner import CheckpointLoader
6
+
7
+ # convert the functions from mmdet to groundingdino
8
+ def correct_unfold_reduction_order(x):
9
+ out_channel, in_channel = x.shape
10
+ x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
11
+ x = x[:, [0, 2, 1, 3], :]
12
+ x = x.reshape(out_channel, in_channel)
13
+ return x
14
+
15
+ def correct_unfold_norm_order(x):
16
+ in_channel = x.shape[0]
17
+ x = x.reshape(in_channel // 4, 4).transpose(0, 1)
18
+ x = x[[0, 2, 1, 3], :]
19
+ x = x.reshape(in_channel)
20
+ return x
21
+
22
+ def convert(ckpt):
23
+ """Inverse mapping of checkpoint parameters to their original names."""
24
+ # Create a dictionary to hold the reversed checkpoint
25
+ new_ckpt = OrderedDict()
26
+
27
+ for k, v in list(ckpt.items()):
28
+ new_v = v # Start with the original value
29
+
30
+ # Inverse rules based on the convert function (from specific to general)
31
+ if k.startswith('decoder'):
32
+ new_k = k.replace('decoder', 'transformer.decoder')
33
+ if 'norms.2' in new_k:
34
+ new_k = new_k.replace('norms.2', 'norm1')
35
+ if 'norms.1' in new_k:
36
+ new_k = new_k.replace('norms.1', 'catext_norm')
37
+ if 'norms.0' in new_k:
38
+ new_k = new_k.replace('norms.0', 'norm2')
39
+ if 'norms.3' in new_k:
40
+ new_k = new_k.replace('norms.3', 'norm3')
41
+ if 'cross_attn_text' in new_k:
42
+ new_k = new_k.replace('cross_attn_text', 'ca_text')
43
+ new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
44
+ new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
45
+ new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
46
+ new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
47
+ if 'ffn.layers.0.0' in new_k:
48
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
49
+ if 'ffn.layers.1' in new_k:
50
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
51
+ if 'self_attn.attn' in new_k:
52
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
53
+
54
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
55
+
56
+ #########################################################################
57
+
58
+ # the last reg_layer_id of the encoder part is 6, which distinguishes it from the decoder
59
+ elif k.startswith('bbox_head.reg_branches.6'):
60
+ if k.startswith('bbox_head.reg_branches.6.0'):
61
+ new_k = k.replace('bbox_head.reg_branches.6.0',
62
+ 'transformer.enc_out_bbox_embed.layers.0')
63
+ if k.startswith('bbox_head.reg_branches.6.2'):
64
+ new_k = k.replace('bbox_head.reg_branches.6.2',
65
+ 'transformer.enc_out_bbox_embed.layers.1')
66
+ if k.startswith('bbox_head.reg_branches.6.4'):
67
+ new_k = k.replace('bbox_head.reg_branches.6.4',
68
+ 'transformer.enc_out_bbox_embed.layers.2')
69
+
70
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
71
+
72
+ #########################################################################
73
+
74
+ elif k.startswith('query_embedding'):
75
+ new_k = k.replace('query_embedding', 'transformer.tgt_embed')
76
+
77
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
78
+
79
+ #########################################################################
80
+
81
+ elif k.startswith('bbox_head.reg_branches'):
82
+ # mmdet omits part of the parameter name; check the groundingdino checkpoint
83
+ # groundingdino has two groups of parameters with identical values,
84
+ # namely bbox_embed and transformer.decoder.embed,
85
+ # so mmdet simply "merged" the two groups of parameters
86
+ reg_layer_id = int(k.split('.')[2])
87
+ linear_id = int(k.split('.')[3])
88
+ weight_or_bias = k.split('.')[-1]
89
+ new_k1 = 'transformer.decoder.bbox_embed.' + \
90
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
91
+ new_k2 = 'bbox_embed.' + \
92
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
93
+
94
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
95
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
96
+
97
+ #########################################################################
98
+
99
+ elif k.startswith('bbox_head.cls_branches.6'):
100
+ # mmdet adds a bias term in contrastive_embed
101
+ # but the decoder uses indices 0~5, so index 6 corresponds to enc_out.class_embed after two-stage fine-tuning
102
+ new_k = 'transformer.enc_out_class_embed.bias'
103
+
104
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
105
+
106
+ #########################################################################
107
+
108
+ elif k.startswith('bbox_head.cls_branches'):
109
+ # mmdet adds a bias term in contrastive_embed
110
+ new_k1 = 'transformer.decoder.class_embed.' + k[-6:]
111
+ new_k2 = 'class_embed.' + k[-6:]
112
+
113
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
114
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
115
+
116
+ #########################################################################
117
+
118
+ elif k.startswith('memory_trans_'):
119
+ if k.startswith('memory_trans_fc'):
120
+ new_k = k.replace('memory_trans_fc', 'transformer.enc_output')
121
+ elif k.startswith('memory_trans_norm'):
122
+ new_k = k.replace('memory_trans_norm', 'transformer.enc_output_norm')
123
+
124
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
125
+
126
+ #########################################################################
127
+
128
+ elif k.startswith('encoder'):
129
+ new_k = k.replace('encoder', 'transformer.encoder')
130
+ new_k = new_k.replace('norms.0', 'norm1')
131
+ new_k = new_k.replace('norms.1', 'norm2')
132
+ new_k = new_k.replace('norms.2', 'norm3')
133
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
134
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
135
+ if 'text_layers' in new_k:
136
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
137
+
138
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
139
+
140
+ #########################################################################
141
+
142
+ elif k.startswith('level_embed'):
143
+ new_k = k.replace('level_embed', 'transformer.level_embed')
144
+
145
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
146
+
147
+ #########################################################################
148
+
149
+ elif k.startswith('neck.convs'):
150
+ new_k = k.replace('neck.convs', 'input_proj')
151
+ new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
152
+ new_k = new_k.replace('conv.weight', '0.weight')
153
+ new_k = new_k.replace('conv.bias', '0.bias')
154
+ new_k = new_k.replace('gn.weight', '1.weight')
155
+ new_k = new_k.replace('gn.bias', '1.bias')
156
+
157
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
158
+
159
+ #########################################################################
160
+
161
+ elif 'neck.extra_convs.0' in k:
162
+ new_k = k.replace('neck.extra_convs.0', 'neck.convs.4')
163
+ new_k = new_k.replace('neck.convs', 'input_proj')
164
+ new_k = new_k.replace('conv.weight', '0.weight')
165
+ new_k = new_k.replace('conv.bias', '0.bias')
166
+ new_k = new_k.replace('gn.weight', '1.weight')
167
+ new_k = new_k.replace('gn.bias', '1.bias')
168
+
169
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
170
+
171
+ #########################################################################
172
+
173
+ elif k.startswith('text_feat_map'):
174
+ new_k = k.replace('text_feat_map', 'feat_map')
175
+
176
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
177
+
178
+ #########################################################################
179
+
180
+ elif k.startswith('language_model.language_backbone.body.model'):
181
+ new_k = k.replace('language_model.language_backbone.body.model', 'bert')
182
+
183
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
184
+
185
+ #########################################################################
186
+
187
+ elif k.startswith('backbone'):
188
+ new_k = k.replace('backbone', 'backbone.0')
189
+ if 'patch_embed.projection' in new_k:
190
+ new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
191
+ elif 'drop_after_pos' in new_k:
192
+ new_k = new_k.replace('drop_after_pos', 'pos_drop')
193
+
194
+ if 'stages' in new_k:
195
+ new_k = new_k.replace('stages', 'layers')
196
+ if 'ffn.layers.0.0' in new_k:
197
+ new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
198
+ elif 'ffn.layers.1' in new_k:
199
+ new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
200
+ elif 'attn.w_msa' in new_k:
201
+ new_k = new_k.replace('attn.w_msa', 'attn')
202
+
203
+ if 'downsample' in k:
204
+ if 'reduction.' in k:
205
+ new_v = correct_unfold_reduction_order(v)
206
+ elif 'norm.' in k:
207
+ new_v = correct_unfold_norm_order(v)
208
+
209
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
210
+
211
+ #########################################################################
212
+
213
+ else:
214
+ print('skip:', k)
215
+ continue
216
+
217
+ # if 'transformer.decoder.bbox_embed' in new_k:
218
+ # new_k = new_k.replace('transformer.decoder.bbox_embed', 'bbox_embed')
219
+ # if new_k.startswith('module.'):
220
+ # new_k = new_k.replace('module.', '')
221
+
222
+ return new_ckpt
223
+
224
+ def main():
225
+ parser = argparse.ArgumentParser(
226
+ description='Convert keys to GroundingDINO style.')
227
+ parser.add_argument(
228
+ 'src',
229
+ nargs='?',
230
+ default='grounding_dino_swin-l_pretrain_all-56d69e78.pth',
231
+ help='src model path or url')
232
+ # The dst path must be a full path of the new checkpoint.
233
+ parser.add_argument(
234
+ 'dst',
235
+ nargs='?',
236
+ default='mmdet_swinl.pth_groundingdino.pth',
237
+ help='save path')
238
+ args = parser.parse_args()
239
+
240
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
241
+
242
+ # mmdet stores 'state_dict' rather than 'model'
243
+ if 'state_dict' in checkpoint:
244
+ state_dict = checkpoint['state_dict']
245
+ else:
246
+ state_dict = checkpoint
247
+
248
+ weight = convert(state_dict)
249
+ torch.save(weight, args.dst)
250
+ # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
251
+ # sha = calculate_sha256(args.dst)
252
+ # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
253
+ # subprocess.Popen(['mv', args.dst, final_file])
254
+ print(f'Done! Saved to {args.dst}')
255
+
256
+ if __name__ == '__main__':
257
+ main()
258
+
259
+ # skip: dn_query_generator.label_embedding.weight
hf_model/mmdet2groundingdino_swint.py ADDED
@@ -0,0 +1,259 @@
1
+ # mmdet to groundingdino
2
+ import argparse
3
+ from collections import OrderedDict
4
+ import torch
5
+ from mmengine.runner import CheckpointLoader
6
+
7
+ # convert the functions from mmdet to groundingdino
8
+ def correct_unfold_reduction_order(x):
9
+ out_channel, in_channel = x.shape
10
+ x = x.reshape(out_channel, in_channel // 4, 4).transpose(1, 2)
11
+ x = x[:, [0, 2, 1, 3], :]
12
+ x = x.reshape(out_channel, in_channel)
13
+ return x
14
+
15
+ def correct_unfold_norm_order(x):
16
+ in_channel = x.shape[0]
17
+ x = x.reshape(in_channel // 4, 4).transpose(0, 1)
18
+ x = x[[0, 2, 1, 3], :]
19
+ x = x.reshape(in_channel)
20
+ return x
21
+
22
+ def convert(ckpt):
23
+ """Inverse mapping of checkpoint parameters to their original names."""
24
+ # Create a dictionary to hold the reversed checkpoint
25
+ new_ckpt = OrderedDict()
26
+
27
+ for k, v in list(ckpt.items()):
28
+ new_v = v # Start with the original value
29
+
30
+ # Inverse rules based on the convert function (from specific to general)
31
+ if k.startswith('decoder'):
32
+ new_k = k.replace('decoder', 'module.transformer.decoder')
33
+ if 'norms.2' in new_k:
34
+ new_k = new_k.replace('norms.2', 'norm1')
35
+ if 'norms.1' in new_k:
36
+ new_k = new_k.replace('norms.1', 'catext_norm')
37
+ if 'norms.0' in new_k:
38
+ new_k = new_k.replace('norms.0', 'norm2')
39
+ if 'norms.3' in new_k:
40
+ new_k = new_k.replace('norms.3', 'norm3')
41
+ if 'cross_attn_text' in new_k:
42
+ new_k = new_k.replace('cross_attn_text', 'ca_text')
43
+ new_k = new_k.replace('attn.in_proj_weight', 'in_proj_weight')
44
+ new_k = new_k.replace('attn.in_proj_bias', 'in_proj_bias')
45
+ new_k = new_k.replace('attn.out_proj.weight', 'out_proj.weight')
46
+ new_k = new_k.replace('attn.out_proj.bias', 'out_proj.bias')
47
+ if 'ffn.layers.0.0' in new_k:
48
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
49
+ if 'ffn.layers.1' in new_k:
50
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
51
+ if 'self_attn.attn' in new_k:
52
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
53
+
54
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
55
+
56
+ #########################################################################
57
+
58
+ # the last reg_layer_id of the encoder part is 6, which distinguishes it from the decoder
59
+ elif k.startswith('bbox_head.reg_branches.6'):
60
+ if k.startswith('bbox_head.reg_branches.6.0'):
61
+ new_k = k.replace('bbox_head.reg_branches.6.0',
62
+ 'module.transformer.enc_out_bbox_embed.layers.0')
63
+ if k.startswith('bbox_head.reg_branches.6.2'):
64
+ new_k = k.replace('bbox_head.reg_branches.6.2',
65
+ 'module.transformer.enc_out_bbox_embed.layers.1')
66
+ if k.startswith('bbox_head.reg_branches.6.4'):
67
+ new_k = k.replace('bbox_head.reg_branches.6.4',
68
+ 'module.transformer.enc_out_bbox_embed.layers.2')
69
+
70
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
71
+
72
+ #########################################################################
73
+
74
+ elif k.startswith('query_embedding'):
75
+ new_k = k.replace('query_embedding', 'module.transformer.tgt_embed')
76
+
77
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
78
+
79
+ #########################################################################
80
+
81
+ elif k.startswith('bbox_head.reg_branches'):
82
+ # mmdet omits part of the parameter name; check the groundingdino checkpoint
83
+ # groundingdino has two groups of parameters with identical values,
84
+ # namely module.bbox_embed and module.transformer.decoder.embed,
85
+ # so mmdet simply "merged" the two groups of parameters
86
+ reg_layer_id = int(k.split('.')[2])
87
+ linear_id = int(k.split('.')[3])
88
+ weight_or_bias = k.split('.')[-1]
89
+ new_k1 = 'module.transformer.decoder.bbox_embed.' + \
90
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
91
+ new_k2 = 'module.bbox_embed.' + \
92
+ str(reg_layer_id) + '.layers.' + str(linear_id // 2) + '.' + weight_or_bias
93
+
94
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
95
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
96
+
97
+ #########################################################################
98
+
99
+ elif k.startswith('bbox_head.cls_branches.6'):
100
+ # mmdet adds a bias term in contrastive_embed
101
+ # the decoder branches are indexed 0-5, so index 6 corresponds to the enc_out.class_embed used after two-stage fine-tuning
102
+ new_k = 'module.transformer.enc_out_class_embed.bias'
103
+
104
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
105
+
106
+ #########################################################################
107
+
108
+ elif k.startswith('bbox_head.cls_branches'):
109
+ # mmdet adds a bias term in contrastive_embed
110
+ new_k1 = 'module.transformer.decoder.class_embed.' + k[-6:]
111
+ new_k2 = 'module.class_embed.' + k[-6:]
112
+
113
+ new_ckpt[new_k1] = new_v # Add the key and value to the original checkpoint dict
114
+ new_ckpt[new_k2] = new_v # Add the key and value to the original checkpoint dict
115
+
116
+ #########################################################################
117
+
118
+ elif k.startswith('memory_trans_'):
119
+ if k.startswith('memory_trans_fc'):
120
+ new_k = k.replace('memory_trans_fc', 'module.transformer.enc_output')
121
+ elif k.startswith('memory_trans_norm'):
122
+ new_k = k.replace('memory_trans_norm', 'module.transformer.enc_output_norm')
123
+
124
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
125
+
126
+ #########################################################################
127
+
128
+ elif k.startswith('encoder'):
129
+ new_k = k.replace('encoder', 'module.transformer.encoder')
130
+ new_k = new_k.replace('norms.0', 'norm1')
131
+ new_k = new_k.replace('norms.1', 'norm2')
132
+ new_k = new_k.replace('norms.2', 'norm3')
133
+ new_k = new_k.replace('ffn.layers.0.0', 'linear1')
134
+ new_k = new_k.replace('ffn.layers.1', 'linear2')
135
+ if 'text_layers' in new_k:
136
+ new_k = new_k.replace('self_attn.attn', 'self_attn')
137
+
138
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
139
+
140
+ #########################################################################
141
+
142
+ elif k.startswith('level_embed'):
143
+ new_k = k.replace('level_embed', 'module.transformer.level_embed')
144
+
145
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
146
+
147
+ #########################################################################
148
+
149
+ elif k.startswith('neck.convs'):
150
+ new_k = k.replace('neck.convs', 'module.input_proj')
151
+ new_k = new_k.replace('neck.extra_convs.0', 'neck.convs.3')
152
+ new_k = new_k.replace('conv.weight', '0.weight')
153
+ new_k = new_k.replace('conv.bias', '0.bias')
154
+ new_k = new_k.replace('gn.weight', '1.weight')
155
+ new_k = new_k.replace('gn.bias', '1.bias')
156
+
157
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
158
+
159
+ #########################################################################
160
+
161
+ elif 'neck.extra_convs.0' in k:
162
+ new_k = k.replace('neck.extra_convs.0', 'neck.convs.3')
163
+ new_k = new_k.replace('neck.convs', 'module.input_proj')
164
+ new_k = new_k.replace('conv.weight', '0.weight')
165
+ new_k = new_k.replace('conv.bias', '0.bias')
166
+ new_k = new_k.replace('gn.weight', '1.weight')
167
+ new_k = new_k.replace('gn.bias', '1.bias')
168
+
169
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
170
+
171
+ #########################################################################
172
+
173
+ elif k.startswith('text_feat_map'):
174
+ new_k = k.replace('text_feat_map', 'module.feat_map')
175
+
176
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
177
+
178
+ #########################################################################
179
+
180
+ elif k.startswith('language_model.language_backbone.body.model'):
181
+ new_k = k.replace('language_model.language_backbone.body.model', 'module.bert')
182
+
183
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
184
+
185
+ #########################################################################
186
+
187
+ elif k.startswith('backbone'):
188
+ new_k = k.replace('backbone', 'module.backbone.0')
189
+ if 'patch_embed.projection' in new_k:
190
+ new_k = new_k.replace('patch_embed.projection', 'patch_embed.proj')
191
+ elif 'drop_after_pos' in new_k:
192
+ new_k = new_k.replace('drop_after_pos', 'pos_drop')
193
+
194
+ if 'stages' in new_k:
195
+ new_k = new_k.replace('stages', 'layers')
196
+ if 'ffn.layers.0.0' in new_k:
197
+ new_k = new_k.replace('ffn.layers.0.0', 'mlp.fc1')
198
+ elif 'ffn.layers.1' in new_k:
199
+ new_k = new_k.replace('ffn.layers.1', 'mlp.fc2')
200
+ elif 'attn.w_msa' in new_k:
201
+ new_k = new_k.replace('attn.w_msa', 'attn')
202
+
203
+ if 'downsample' in k:
204
+ if 'reduction.' in k:
205
+ new_v = correct_unfold_reduction_order(v)
206
+ elif 'norm.' in k:
207
+ new_v = correct_unfold_norm_order(v)
208
+
209
+ new_ckpt[new_k] = new_v # Add the key and value to the original checkpoint dict
210
+
211
+ #########################################################################
212
+
213
+ else:
214
+ print('skip:', k)
215
+ continue
216
+
217
+ # if 'module.transformer.decoder.bbox_embed' in new_k:
218
+ # new_k = new_k.replace('module.transformer.decoder.bbox_embed', 'module.bbox_embed')
219
+ # if new_k.startswith('module'):
220
+ # new_k = new_k.replace('module.', '')
221
+
222
+ return new_ckpt
223
+
224
+ def main():
225
+ parser = argparse.ArgumentParser(
226
+ description='Convert keys to GroundingDINO style.')
227
+ parser.add_argument(
228
+ 'src',
229
+ nargs='?',
230
+ default='grounding_dino_swin-t_pretrain_obj365_goldg_v3det_20231218_095741-e316e297.pth',
231
+ help='src model path or url')
232
+ # The dst path must be a full path of the new checkpoint.
233
+ parser.add_argument(
234
+ 'dst',
235
+ nargs='?',
236
+ default='check_mmdet_to_groundingdino.pth',
237
+ help='save path')
238
+ args = parser.parse_args()
239
+
240
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
241
+
242
+ # mmdet checkpoints store the weights under 'state_dict' rather than 'model'
243
+ if 'state_dict' in checkpoint:
244
+ state_dict = checkpoint['state_dict']
245
+ else:
246
+ state_dict = checkpoint
247
+
248
+ weight = convert(state_dict)
249
+ torch.save(weight, args.dst)
250
+ # sha = subprocess.check_output(['sha256sum', args.dst]).decode()
251
+ # sha = calculate_sha256(args.dst)
252
+ # final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
253
+ # subprocess.Popen(['mv', args.dst, final_file])
254
+ print(f'Done! Saved to {args.dst}')
255
+
256
+ if __name__ == '__main__':
257
+ main()
258
+
259
+ # skip: dn_query_generator.label_embedding.weight
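For readers who want to sanity-check the mapping above, a minimal, hypothetical round trip might look like the sketch below; the checkpoint filenames are placeholders, and the final assertion only reflects the fact that every rule in convert() emits keys under the 'module.' namespace.

# Hypothetical usage of convert(); paths are placeholders.
import torch

ckpt = torch.load('grounding_dino_mmdet.pth', map_location='cpu', weights_only=False)
state_dict = ckpt.get('state_dict', ckpt)      # mmdet stores weights under 'state_dict'

converted = convert(state_dict)                # mmdet names -> GroundingDINO-style names
assert all(k.startswith('module.') for k in converted)
torch.save(converted, 'groundingdino_style.pth')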
hf_model/modeling_grounding_dino.py ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,455 @@
1
+ import torch
2
+ import numpy as np
3
+ from transformers import GroundingDinoProcessor
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import torch
8
+
9
+
10
+ def prepare_targets(points, caption, shapes, emb_size, device, llmdet_processor):
11
+ gt_points_b = [np.array(points) / np.array(shapes)[::-1]]
12
+ gt_points_b[0] = gt_points_b[0].squeeze(0)
13
+
14
+ gt_points = [torch.from_numpy(img_points).float() for img_points in gt_points_b]
15
+ gt_logits = [torch.zeros((img_points.shape[0], emb_size)) for img_points in gt_points]
16
+
17
+ tokenized = llmdet_processor.tokenizer(caption[0], padding="longest", return_tensors="pt")
18
+ end_idxes = [torch.where(ids == 1012)[0][-1] for ids in tokenized['input_ids']]  # 1012 is the BERT token id for '.'
19
+ for i, end_idx in enumerate(end_idxes):
20
+ gt_logits[i][:, :end_idx] = 1.0
21
+ caption_sizes = [idx + 2 for idx in end_idxes]
22
+
23
+ targets = [{"points": p.to(device), "labels": l.to(device), "caption_size": c}
24
+ for p, l, c in zip(gt_points, gt_logits, caption_sizes)]
25
+
26
+ return targets
27
+
28
+
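As a rough, illustrative call of prepare_targets (the coordinates, caption, and emb_size below are invented; the nesting of points, caption, and shapes follows the collator defined further down):

# Hypothetical inputs for a single image.
from transformers import GroundingDinoProcessor

proc = GroundingDinoProcessor.from_pretrained("fushh7/llmdet_swin_tiny_hf")
points = [[[120.0, 80.0], [300.0, 210.0]]]      # pixel (x, y) points for one image
caption = [["red apple."]]
shapes = [(640, 480)]                            # (width, height)
targets = prepare_targets(points, caption, shapes, emb_size=256,
                          device="cpu", llmdet_processor=proc)
print(targets[0]["points"].shape, targets[0]["caption_size"])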
29
+ def post_process_grounded_object_detection(
30
+ outputs,
31
+ box_threshold: float = 0.4,
32
+ ):
33
+ # for the fine-tuned model, the box threshold should be set to 0.50
34
+ logits, boxes = outputs.logits, outputs.pred_boxes
35
+
36
+ probs = torch.sigmoid(logits) # (batch_size, num_queries, 256)
37
+ scores = torch.max(probs, dim=-1)[0] # (batch_size, num_queries)
38
+
39
+ results = []
40
+ for idx, (s, b, p) in enumerate(zip(scores, boxes, probs)):
41
+ score = s[s > box_threshold]
42
+ box = b[s > box_threshold]
43
+ prob = p[s > box_threshold]
44
+ results.append({"scores": score, "boxes": box})
45
+
46
+ return results
47
+
48
+
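A self-contained sketch of the post-processing above, using randomly generated tensors shaped like the model outputs (the query count and embedding size are made-up values):

# Dummy outputs carrying .logits and .pred_boxes, as post_process_grounded_object_detection expects.
from types import SimpleNamespace
import torch

dummy = SimpleNamespace(
    logits=torch.randn(1, 900, 256),    # (batch, num_queries, text_dim); values are random
    pred_boxes=torch.rand(1, 900, 4),   # normalized cxcywh boxes
)
results = post_process_grounded_object_detection(dummy, box_threshold=0.4)
print(len(results[0]["boxes"]), "boxes kept above the threshold")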
49
+ class collator:
50
+ def __init__(self, processor=None, use_negative=True):
51
+ model_id = "fushh7/llmdet_swin_tiny_hf"
52
+ self.llmdet_processor = GroundingDinoProcessor.from_pretrained(model_id)
53
+ self.use_negative = use_negative
54
+
55
+ def __call__(self, batch):
56
+ # assume batch size is 1
57
+ example = batch[0]
58
+ image = example['image']
59
+ pil_image = example['image']
60
+ w, h = image.size
61
+ pos_caption = example['pos_caption']
62
+ neg_caption = example['neg_caption']
63
+ pos_points = example['pos_points']
64
+ neg_points = example['neg_points']
65
+ pos_count = example['pos_count']
66
+ neg_count = example['neg_count']
67
+ annotated_pos_count = example['annotated_pos_count']
68
+ annotated_neg_count = example['annotated_neg_count']
69
+
70
+ if 'type' in example:
71
+ sample_type = example['type']
72
+ else:
73
+ sample_type = 'eval'
74
+ category = example['category']
75
+ image_name = "{}_{}_{}_{}_{}".format(category, pos_caption, neg_caption, pos_count, neg_count)
76
+ pos_llm_det_inputs = self.llmdet_processor(images=image, text=pos_caption, return_tensors="pt", padding=True)
77
+ neg_llm_det_inputs = self.llmdet_processor(images=image, text=neg_caption, return_tensors="pt", padding=True)
78
+ pos_caption = [[pos_caption]]
79
+ neg_caption = [[neg_caption]]
80
+ shapes = [(w, h)]
81
+ pos_points = [pos_points]
82
+ neg_points = [neg_points]
83
+
84
+ # exemplars
85
+ if 'positive_exemplars' in example and 'negative_exemplars' in example and example[
86
+ 'positive_exemplars'] is not None and example['negative_exemplars'] is not None:
87
+ pos_exemplars = example['positive_exemplars']
88
+ neg_exemplars = example['negative_exemplars']
89
+ img_width, img_height = pil_image.size  # PIL .size returns (width, height)
90
+ norm_pos_exemplars = []
91
+ norm_neg_exemplars = []
92
+ exemplar_valid = True
93
+ for exemplars in pos_exemplars:
94
+ tly, tlx, bry, brx = exemplars
95
+ tlx = tlx / img_width
96
+ tly = tly / img_height
97
+ brx = brx / img_width
98
+ bry = bry / img_height
99
+ if tlx < 0 or tly < 0 or tlx > 1.0 or tly > 1.0:
100
+ exemplar_valid = False
101
+ if brx < 0 or bry < 0 or brx > 1.0 or bry > 1.0:
102
+ exemplar_valid = False
103
+ if tlx >= brx or tly >= bry:
104
+ exemplar_valid = False
105
+ tlx = max(tlx, 0)
106
+ tly = max(tly, 0)
107
+ tly = min(tly, 1 - 1e-4)
108
+ tlx = min(tlx, 1 - 1e-4)
109
+ brx = min(brx, 1)
110
+ bry = min(bry, 1)
111
+ brx = max(brx, tlx)
112
+ bry = max(bry, tly)
113
+ assert tlx >= 0 and tly >= 0 and brx <= 1 and bry <= 1 and tlx <= brx and tly <= bry, f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
114
+ norm_pos_exemplars.append([tlx, tly, brx, bry])
115
+ for exemplars in neg_exemplars:
116
+ tly, tlx, bry, brx = exemplars
117
+ tlx = tlx / img_width
118
+ tly = tly / img_height
119
+ brx = brx / img_width
120
+ bry = bry / img_height
121
+ if tlx < 0 or tly < 0 or tlx > 1.0 or tly > 1.0:
122
+ exemplar_valid = False
123
+ if brx < 0 or bry < 0 or brx > 1.0 or bry > 1.0:
124
+ exemplar_valid = False
125
+ if tlx >= brx or tly >= bry:
126
+ exemplar_valid = False
127
+ tlx = max(tlx, 0)
128
+ tly = max(tly, 0)
129
+ tly = min(tly, 1 - 1e-4)
130
+ tlx = min(tlx, 1 - 1e-4)
131
+ brx = min(brx, 1)
132
+ bry = min(bry, 1)
133
+ brx = max(brx, tlx)
134
+ bry = max(bry, tly)
135
+ assert tlx >= 0 and tly >= 0 and brx <= 1 and bry <= 1 and tlx <= brx and tly <= bry, f"tlx: {tlx}, tly: {tly}, brx: {brx}, bry: {bry}"
136
+ norm_neg_exemplars.append([tlx, tly, brx, bry])
137
+
138
+ if exemplar_valid:
139
+ pos_exemplars = [torch.from_numpy(np.array(exemplars)).float() for exemplars in norm_pos_exemplars]
140
+ neg_exemplars = [torch.from_numpy(np.array(exemplars)).float() for exemplars in norm_neg_exemplars]
141
+ pos_exemplars = torch.stack(pos_exemplars)
142
+ neg_exemplars = torch.stack(neg_exemplars)
143
+ batch_dict = {
144
+ 'pos_llm_det_inputs': pos_llm_det_inputs,
145
+ 'neg_llm_det_inputs': neg_llm_det_inputs,
146
+ 'pos_caption': pos_caption,
147
+ 'neg_caption': neg_caption,
148
+ 'shapes': shapes,
149
+ 'pos_points': pos_points,
150
+ 'neg_points': neg_points,
151
+ 'pos_count': pos_count,
152
+ 'neg_count': neg_count,
153
+ 'annotated_pos_count': annotated_pos_count,
154
+ 'annotated_neg_count': annotated_neg_count,
155
+ 'image': pil_image,
156
+ 'category': category,
157
+ 'type': sample_type,
158
+ 'pos_exemplars': pos_exemplars,
159
+ 'neg_exemplars': neg_exemplars,
160
+ 'image_name': image_name,
161
+ }
162
+ else:
163
+ batch_dict = {
164
+ 'pos_llm_det_inputs': pos_llm_det_inputs,
165
+ 'neg_llm_det_inputs': neg_llm_det_inputs,
166
+ 'pos_caption': pos_caption,
167
+ 'neg_caption': neg_caption,
168
+ 'shapes': shapes,
169
+ 'pos_points': pos_points,
170
+ 'neg_points': neg_points,
171
+ 'pos_count': pos_count,
172
+ 'neg_count': neg_count,
173
+ 'annotated_pos_count': annotated_pos_count,
174
+ 'annotated_neg_count': annotated_neg_count,
175
+ 'image': pil_image,
176
+ 'category': category,
177
+ 'type': sample_type,
178
+ 'image_name': image_name,
179
+ }
180
+ else:
181
+ batch_dict = {
182
+ 'pos_llm_det_inputs': pos_llm_det_inputs,
183
+ 'neg_llm_det_inputs': neg_llm_det_inputs,
184
+ 'pos_caption': pos_caption,
185
+ 'neg_caption': neg_caption,
186
+ 'shapes': shapes,
187
+ 'pos_points': pos_points,
188
+ 'neg_points': neg_points,
189
+ 'pos_count': pos_count,
190
+ 'neg_count': neg_count,
191
+ 'annotated_pos_count': annotated_pos_count,
192
+ 'annotated_neg_count': annotated_neg_count,
193
+ 'image': pil_image,
194
+ 'category': category,
195
+ 'type': sample_type,
196
+ 'image_name': image_name,
197
+ }
198
+
199
+ return batch_dict
200
+
201
+
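A small wiring sketch for the collator; every field in the fake sample below is invented, and only the keys the collator actually reads are provided (exemplars are omitted, so the final else-branch is taken):

# Hypothetical single sample; the processor is downloaded inside collator.__init__.
from PIL import Image

fake_sample = {
    "image": Image.new("RGB", (64, 48)),
    "pos_caption": "red apple.", "neg_caption": "green apple.",
    "pos_points": [[10.0, 12.0]], "neg_points": [[30.0, 20.0]],
    "pos_count": 1, "neg_count": 1,
    "annotated_pos_count": 1, "annotated_neg_count": 1,
    "category": "FOO",
}
batch = collator()([fake_sample])
print(batch["image_name"], batch["shapes"])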
202
+ import torch.distributed as dist
203
+
204
+
205
+ def rank0_print(*args):
206
+ if dist.is_initialized():
207
+ if dist.get_rank() == 0:
208
+ print(f"Rank {dist.get_rank()}: ", *args)
209
+ else:
210
+ print(*args)
211
+
212
+
213
+ def build_dataset(data_args):
214
+ from datasets import load_from_disk, concatenate_datasets
215
+
216
+ categories = ["FOO", "FUN", "OFF", "OTR", "HOU"]
217
+
218
+ if data_args.data_split not in categories:
219
+ rank0_print(f"Warning: Invalid data_split '{data_args.data_split}'. Switching to 'all' mode.")
220
+ data_args.data_split = "all"
221
+
222
+ if data_args.data_split == "all":
223
+ train_dataset = load_from_disk(data_args.train_data_path)
224
+ train_dataset = concatenate_datasets(
225
+ [train_dataset["FOO"], train_dataset["FUN"], train_dataset["OFF"], train_dataset["OTR"],
226
+ train_dataset["HOU"]])
227
+
228
+ val_dataset = load_from_disk(data_args.val_data_path)
229
+ val_dataset = concatenate_datasets(
230
+ [val_dataset["FOO"], val_dataset["FUN"], val_dataset["OFF"], val_dataset["OTR"], val_dataset["HOU"]])
231
+
232
+ test_dataset = load_from_disk(data_args.test_data_path)
233
+ test_dataset = concatenate_datasets(
234
+ [test_dataset["FOO"], test_dataset["FUN"], test_dataset["OFF"], test_dataset["OTR"], test_dataset["HOU"]])
235
+
236
+ weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
237
+ weakly_supervised_data = concatenate_datasets(
238
+ [weakly_supervised_data["FOO"], weakly_supervised_data["FUN"], weakly_supervised_data["OFF"],
239
+ weakly_supervised_data["OTR"], weakly_supervised_data["HOU"]])
240
+
241
+ rank0_print("Using 'all' mode: all categories for train/val/test")
242
+
243
+ else:
244
+ test_category = data_args.data_split
245
+ train_categories = [cat for cat in categories if cat != test_category]
246
+ train_dataset = load_from_disk(data_args.train_data_path)
247
+ print(train_categories, train_dataset.keys())
248
+ train_datasets = [train_dataset[cat] for cat in train_categories]
249
+ train_dataset = concatenate_datasets(train_datasets)
250
+
251
+ weakly_supervised_data = load_from_disk(data_args.weakly_supervised_data_path)
252
+ weakly_supervised_data = [weakly_supervised_data[cat] for cat in train_categories]
253
+ weakly_supervised_data = concatenate_datasets(weakly_supervised_data)
254
+
255
+ val_dataset = load_from_disk(data_args.val_data_path)
256
+ val_dataset = val_dataset[test_category]
257
+
258
+ test_dataset = load_from_disk(data_args.test_data_path)
259
+ test_dataset = test_dataset[test_category]
260
+
261
+ rank0_print(f"Cross-validation mode: using {train_categories} for train, {test_category} for val/test")
262
+
263
+ rank0_print('train_dataset: ', len(train_dataset))
264
+ rank0_print('val_dataset: ', len(val_dataset))
265
+ rank0_print('test_dataset: ', len(test_dataset))
266
+ rank0_print('weakly_supervised_data: ', len(weakly_supervised_data))
267
+
268
+ return train_dataset, val_dataset, test_dataset, weakly_supervised_data
269
+
270
+
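For context, build_dataset expects a data_args object exposing the attributes used above; a hypothetical container might look like the sketch below (all paths are placeholders and would have to point to datasets saved with save_to_disk):

# Placeholder argument object; choosing data_split="FOO" triggers the cross-validation branch.
from types import SimpleNamespace

data_args = SimpleNamespace(
    data_split="FOO",
    train_data_path="data/train_hf_dataset",
    val_data_path="data/val_hf_dataset",
    test_data_path="data/test_hf_dataset",
    weakly_supervised_data_path="data/weak_hf_dataset",
)
train_ds, val_ds, test_ds, weak_ds = build_dataset(data_args)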
271
+ def generate_pseudo_density_map(points_norm: torch.Tensor,
272
+ output_size: tuple[int, int],
273
+ sigma: float = 4.0,
274
+ normalize: bool = True) -> torch.Tensor:
275
+ device = points_norm.device
276
+ H, W = output_size
277
+ N = points_norm.shape[0]
278
+
279
+ ys = torch.arange(H, device=device).float()
280
+ xs = torch.arange(W, device=device).float()
281
+ grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij') # (H, W)
282
+
283
+ pts_px = points_norm.clone()
284
+ pts_px[:, 0] *= (W - 1) # x
285
+ pts_px[:, 1] *= (H - 1) # y
286
+
287
+ dx = grid_x.unsqueeze(0) - pts_px[:, 0].view(-1, 1, 1) # (N, H, W)
288
+ dy = grid_y.unsqueeze(0) - pts_px[:, 1].view(-1, 1, 1) # (N, H, W)
289
+ dist2 = dx ** 2 + dy ** 2
290
+ gaussians = torch.exp(-dist2 / (2 * sigma ** 2)) # (N, H, W)
291
+ density_map = gaussians.sum(dim=0, keepdim=True) # (1, H, W)
292
+
293
+ if normalize and N > 0:
294
+ density_map = density_map * (N / density_map.sum())
295
+
296
+ return density_map.unsqueeze(0)
297
+
298
+
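A quick check of generate_pseudo_density_map: with normalize=True the map should sum to the number of points (the three normalized coordinates below are arbitrary):

# Three arbitrary normalized (x, y) points rendered on a 64x64 grid.
import torch

pts = torch.tensor([[0.25, 0.25], [0.50, 0.50], [0.75, 0.10]])
dm = generate_pseudo_density_map(pts, output_size=(64, 64), sigma=4.0, normalize=True)
print(dm.shape, float(dm.sum()))   # torch.Size([1, 1, 64, 64]), ~3.0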
299
+ def show_density_map(density_map: torch.Tensor,
300
+ points_norm: torch.Tensor | None = None,
301
+ figsize: tuple[int, int] = (6, 8),
302
+ cmap: str = "jet") -> None:
303
+ dm = density_map.squeeze().detach().cpu().numpy() # (H, W)
304
+ H, W = dm.shape
305
+
306
+ plt.figure(figsize=figsize)
307
+ plt.imshow(dm, cmap=cmap, origin="upper")
308
+ plt.colorbar(label="Density")
309
+
310
+ if points_norm is not None and points_norm.numel() > 0:
311
+ pts = points_norm.detach().cpu().numpy()
312
+ xs = pts[:, 0] * (W - 1)
313
+ ys = pts[:, 1] * (H - 1)
314
+ plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
315
+
316
+ plt.title(f"Density map (sum = {dm.sum():.2f})")
317
+ plt.axis("off")
318
+ plt.tight_layout()
319
+ plt.show()
320
+
321
+
322
+ def show_image_with_density(pil_img: Image.Image,
323
+ density_map: torch.Tensor,
324
+ points_norm: torch.Tensor | None = None,
325
+ cmap: str = "jet",
326
+ alpha: float = 0.45,
327
+ figsize: tuple[int, int] = (6, 8)) -> None:
328
+ dm = density_map.squeeze().detach().cpu().numpy() # (H, W)
329
+ H, W = dm.shape
330
+
331
+ img_resized = pil_img.resize((W, H), Image.BILINEAR) # or LANCZOS
332
+ img_np = np.asarray(img_resized)
333
+
334
+ plt.figure(figsize=figsize)
335
+ plt.imshow(img_np, origin="upper")
336
+ plt.imshow(dm, cmap=cmap, alpha=alpha, origin="upper")
337
+
338
+ if points_norm is not None and points_norm.numel() > 0:
339
+ pts = points_norm.detach().cpu().numpy()
340
+ xs = pts[:, 0] * (W - 1)
341
+ ys = pts[:, 1] * (H - 1)
342
+ plt.scatter(xs, ys, c="white", s=12, edgecolors="black", linewidths=0.5)
343
+
344
+ plt.title(f"Overlay (density sum = {dm.sum():.2f})")
345
+ plt.axis("off")
346
+ plt.tight_layout()
347
+ plt.show()
348
+
349
+
350
+ def build_point_count_map(feat_maps: torch.Tensor,
351
+ pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
352
+ assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
353
+ B, H, W, _ = feat_maps.shape
354
+ device = feat_maps.device
355
+
356
+ count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)
357
+
358
+ for b in range(B):
359
+ pts = pts_norm_list[b].to(device).float() # (Ni, 2)
360
+ if pts.numel() == 0:
361
+ continue
362
+
363
+ idx_xy = (pts * torch.tensor([W, H], device=device)).long()
364
+ idx_xy[..., 0].clamp_(0, W - 1) # x
365
+ idx_xy[..., 1].clamp_(0, H - 1) # y
366
+
367
+ lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0] # (Ni,)
368
+ one = torch.ones_like(lin_idx, dtype=torch.float32)
369
+
370
+ flat = torch.zeros(H * W, dtype=torch.float32, device=device)
371
+ flat.scatter_add_(0, lin_idx, one)
372
+ count_map[b] = flat.view(H, W)
373
+
374
+ return count_map
375
+
376
+
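A minimal check of build_point_count_map on dummy NHWC features; two of the three made-up points fall into the same cell, so that cell gets a count of 2 while the total still equals 3:

# Dummy (B, H, W, D) features and one list of normalized (x, y) points per image.
import torch

feats = torch.zeros(1, 8, 8, 16)
pts = [torch.tensor([[0.1, 0.1], [0.9, 0.9], [0.9, 0.9]])]
cmap = build_point_count_map(feats, pts)
print(cmap.shape, cmap.sum().item(), cmap.max().item())   # torch.Size([1, 8, 8]), 3.0, 2.0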
377
+ import torch
378
+ import torch.nn.functional as F
379
+
380
+
381
+ def extract_pos_tokens_single(feat_maps: torch.Tensor,
382
+ count_map: torch.Tensor):
383
+ assert feat_maps.dim() == 4 and count_map.dim() == 3, "expected shapes (B,H,W,D) / (B,H,W)"
384
+ B, H, W, D = feat_maps.shape
385
+ assert B == 1, "this function assumes batch_size == 1"
386
+ feat = feat_maps[0] # (H,W,D)
387
+ cnt = count_map[0] # (H,W)
388
+ pos_mask = cnt > 0 # Bool (H,W)
389
+ if pos_mask.sum() == 0:
390
+ empty = torch.empty(0, device=feat.device)
391
+ return empty.reshape(0, D), empty.long()
392
+ pos_tokens = feat[pos_mask] # (N_pos, D)
393
+ y_idx, x_idx = torch.nonzero(pos_mask, as_tuple=True)
394
+ lin_index = y_idx * W + x_idx # (N_pos,)
395
+ return pos_tokens, lin_index
396
+
397
+
398
+ def filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg):
399
+ pos_only_mask = ~torch.isin(lin_pos, lin_neg)
400
+ neg_only_mask = ~torch.isin(lin_neg, lin_pos)
401
+ return pos_tok[pos_only_mask], neg_tok[neg_only_mask]
402
+
403
+
404
+ # ------------------------------------------------------------
405
+ # 2) supervised contrastive loss
406
+ # ------------------------------------------------------------
407
+ def supcon_pos_neg(pos_tokens, neg_tokens, temperature=0.07):
408
+ """
409
+ pos_tokens : (Np, D) Pos token
410
+ neg_tokens : (Nn, D) Neg token
411
+ """
412
+ if pos_tokens.numel() == 0 or neg_tokens.numel() == 0:
413
+ return torch.tensor(0., device=pos_tokens.device, requires_grad=True)
414
+ pos_tokens = F.normalize(pos_tokens, dim=-1)
415
+ neg_tokens = F.normalize(neg_tokens, dim=-1)
416
+ feats = torch.cat([pos_tokens, neg_tokens], dim=0) # (N, D)
417
+ labels = torch.cat([torch.zeros(len(pos_tokens), device=feats.device, dtype=torch.long),
418
+ torch.ones(len(neg_tokens), device=feats.device, dtype=torch.long)], dim=0) # (N,)
419
+ logits = feats @ feats.T / temperature # (N, N)
420
+ logits.fill_diagonal_(-1e4)
421
+ mask_pos = labels.unsqueeze(0) == labels.unsqueeze(1) # (N, N)
422
+ mask_pos.fill_diagonal_(False)
423
+ exp_logits = logits.exp()
424
+ denom = exp_logits.sum(dim=1, keepdim=True) # Σ_{a≠i} exp
425
+ log_prob = logits - denom.log() # log softmax
426
+ loss_i = -(mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp_min(1)
427
+ loss = loss_i.mean()
428
+ return loss
429
+
430
+
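Putting the helpers above together on synthetic data, a hedged sketch of how the token-level contrastive term might be computed; the shapes, point locations, and overlap are all invented:

# Synthetic single-image example: (1, H, W, D) features plus positive/negative point maps.
import torch

feat_maps = torch.randn(1, 8, 8, 32)
pos_pts = [torch.tensor([[0.2, 0.2], [0.8, 0.3], [0.4, 0.6]])]
neg_pts = [torch.tensor([[0.8, 0.3], [0.5, 0.9], [0.1, 0.8]])]  # first point overlaps a positive

pos_cnt = build_point_count_map(feat_maps, pos_pts)
neg_cnt = build_point_count_map(feat_maps, neg_pts)

pos_tok, lin_pos = extract_pos_tokens_single(feat_maps, pos_cnt)
neg_tok, lin_neg = extract_pos_tokens_single(feat_maps, neg_cnt)
pos_tok, neg_tok = filter_overlap(pos_tok, lin_pos, neg_tok, lin_neg)  # drop the shared cell

loss = supcon_pos_neg(pos_tok, neg_tok, temperature=0.07)
print(loss.item())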
431
+ def build_point_count_map(feat_maps: torch.Tensor,
432
+ pts_norm_list: list[torch.Tensor]) -> torch.Tensor:
433
+ assert feat_maps.dim() == 4, "expect NHWC: (B,H,W,D)"
434
+ B, H, W, _ = feat_maps.shape
435
+ device = feat_maps.device
436
+
437
+ count_map = torch.zeros((B, H, W), dtype=torch.float32, device=device)
438
+
439
+ for b in range(B):
440
+ pts = pts_norm_list[b].to(device).float() # (Ni, 2)
441
+ if pts.numel() == 0:
442
+ continue
443
+
444
+ idx_xy = (pts * torch.tensor([W, H], device=device)).long()
445
+ idx_xy[..., 0].clamp_(0, W - 1) # x
446
+ idx_xy[..., 1].clamp_(0, H - 1) # y
447
+
448
+ lin_idx = idx_xy[:, 1] * W + idx_xy[:, 0] # (Ni,)
449
+ one = torch.ones_like(lin_idx, dtype=torch.float32)
450
+
451
+ flat = torch.zeros(H * W, dtype=torch.float32, device=device)
452
+ flat.scatter_add_(0, lin_idx, one)
453
+ count_map[b] = flat.view(H, W)
454
+
455
+ return count_map