"""Gradio demo for CountEx: count objects selected by positive/negative text prompts."""

import os
from typing import Optional

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import GroundingDinoProcessor

from hf_model import CountEX
from utils import post_process_grounded_object_detection

# Model / processor checkpoints - change MODEL_ID to your own HF model repo.
MODEL_ID = "BBVisual/CountEX-KC"
PROCESSOR_ID = "fushh7/llmdet_swin_tiny_hf"
# Weights and pixel values are both cast to this dtype so they stay compatible.
MODEL_DTYPE = torch.bfloat16

# UI defaults, shared by the widgets and by count_objects' keyword defaults.
DEFAULT_BOX_THRESHOLD = 0.42
DEFAULT_POINT_RADIUS = 5
DEFAULT_POINT_COLOR = "blue"

# Global variables for model and processor (populated by load_model()).
model = None
processor = None
device = None


def load_model():
    """Load model and processor once at startup.

    Populates the module-level ``model``, ``processor`` and ``device`` globals.

    Returns:
        The ``(model, processor, device)`` triple that was just loaded.
    """
    global model, processor, device

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = CountEX.from_pretrained(MODEL_ID)
    model = model.to(MODEL_DTYPE)
    model = model.to(device)
    model.eval()

    # Load processor
    processor = GroundingDinoProcessor.from_pretrained(PROCESSOR_ID)

    return model, processor, device


def _normalize_caption(caption: Optional[str]) -> str:
    """Return *caption* stripped and terminated by a period ('' when blank)."""
    text = (caption or "").strip()
    if text and not text.endswith("."):
        text += "."
    return text


def _draw_points(image, points, radius, color):
    """Return a copy of *image* with a filled circle at each normalized point."""
    img_w, img_h = image.size
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for px, py in points:
        # Points are scaled by the image size, i.e. treated as fractions of it.
        x = px * img_w
        y = py * img_h
        draw.ellipse(
            [x - radius, y - radius, x + radius, y + radius],
            fill=color,
        )
    return annotated


def count_objects(
    image: Optional[Image.Image],
    pos_caption: Optional[str],
    neg_caption: Optional[str] = "",
    box_threshold: float = DEFAULT_BOX_THRESHOLD,
    point_radius: float = DEFAULT_POINT_RADIUS,
    point_color: str = DEFAULT_POINT_COLOR,
):
    """Main inference function for counting objects.

    Args:
        image: Input PIL image.
        pos_caption: Positive prompt (objects to count).
        neg_caption: Negative prompt (objects to exclude); blank or "."
            disables the negative branch.
        box_threshold: Detection confidence threshold; a value <= 0 falls
            back to ``model.box_threshold``.
        point_radius: Radius of visualization points, in pixels.
        point_color: Color of visualization points.

    Returns:
        A ``(annotated_image, "Count: N")`` tuple.

    Raises:
        gr.Error: If no image was provided.
    """
    global model, processor, device

    if image is None:
        # Previously this surfaced as an opaque AttributeError on image.mode.
        raise gr.Error("Please upload an image before counting.")

    if model is None:
        load_model()

    # Ensure image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Ensure captions are trimmed and end with a period. A blank positive
    # prompt keeps the historical behavior of being sent to the model as ".".
    pos_caption = _normalize_caption(pos_caption) or "."
    neg_caption = _normalize_caption(neg_caption)

    # Process positive caption
    pos_inputs = processor(
        images=image,
        text=pos_caption,
        return_tensors="pt",
        padding=True,
    )
    pos_inputs = pos_inputs.to(device)
    pos_inputs["pixel_values"] = pos_inputs["pixel_values"].to(MODEL_DTYPE)

    # Process negative caption if provided (a lone "." counts as "no prompt").
    use_neg = bool(neg_caption) and neg_caption != "."

    if use_neg:
        neg_inputs = processor(
            images=image,
            text=neg_caption,
            return_tensors="pt",
            padding=True,
        )
        neg_inputs = neg_inputs.to(device)
        neg_inputs["pixel_values"] = neg_inputs["pixel_values"].to(MODEL_DTYPE)

        # Add negative inputs to positive inputs dict under "neg_"-prefixed keys.
        for key in (
            "token_type_ids",
            "attention_mask",
            "pixel_mask",
            "pixel_values",
            "input_ids",
        ):
            pos_inputs[f"neg_{key}"] = neg_inputs[key]
    pos_inputs["use_neg"] = use_neg

    # Run inference
    with torch.no_grad():
        outputs = model(**pos_inputs)

    # Post-process outputs: the first two box coordinates are used as the point.
    outputs["pred_points"] = outputs["pred_boxes"][:, :, :2]
    outputs["pred_logits"] = outputs["logits"]

    # Use custom threshold if provided, otherwise use model default
    threshold = box_threshold if box_threshold > 0 else model.box_threshold
    results = post_process_grounded_object_detection(
        outputs, box_threshold=threshold
    )[0]

    # Extract points
    points = [box.tolist()[:2] for box in results["boxes"]]

    # Visualize results
    img_draw = _draw_points(image, points, point_radius, point_color)

    return img_draw, f"Count: {len(points)}"


# Create Gradio interface
def create_demo():
    """Build and return the Gradio Blocks UI wired to ``count_objects``."""
    with gr.Blocks(title="CountEx: Discriminative Visual Counting") as demo:
        gr.Markdown("""
        # CountEx: Fine-Grained Counting via Exemplars and Exclusion
        Count specific objects in images using positive and negative text prompts.

        **Important Note: Both the Positive and Negative prompts must end with a period (.) for the model to correctly interpret the instruction.**
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")

                pos_caption = gr.Textbox(
                    label="Positive Prompt",
                    placeholder="e.g., Green Apple",
                    value="Pos Caption Here.",
                )

                neg_caption = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., Red Apple",
                    value="None.",
                )

                # NOTE(review): count_objects falls back to model.box_threshold
                # only when this slider is 0; confirm the label wording matches
                # the intended behavior.
                box_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_BOX_THRESHOLD,
                    step=0.01,
                    label="Detection Threshold (0.42 = use model default)",
                )

                point_radius = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=DEFAULT_POINT_RADIUS,
                    step=1,
                    label="Point Radius",
                )

                point_color = gr.Dropdown(
                    choices=["blue", "red", "green", "yellow", "cyan", "magenta", "white"],
                    value=DEFAULT_POINT_COLOR,
                    label="Point Color",
                )

                submit_btn = gr.Button("Count Objects", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Result")
                count_output = gr.Textbox(label="Count Result")

        # Example images. Only three inputs are supplied; count_objects'
        # keyword defaults cover threshold, radius and color.
        gr.Examples(
            examples=[
                ["examples/apples.png", "Green Apple.", "Red Apple."],
            ],
            inputs=[input_image, pos_caption, neg_caption],
            outputs=[output_image, count_output],
            fn=count_objects,
            cache_examples=False,
        )

        submit_btn.click(
            fn=count_objects,
            inputs=[
                input_image,
                pos_caption,
                neg_caption,
                box_threshold,
                point_radius,
                point_color,
            ],
            outputs=[output_image, count_output],
        )

    return demo


if __name__ == "__main__":
    # Load model at startup so the first request does not pay the load cost.
    print("Loading model...")
    load_model()
    print("Model loaded!")

    # Create and launch demo
    demo = create_demo()
    demo.launch()