Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from transformers import AutoImageProcessor | |
| from transformers import AutoModelForObjectDetection | |
| from PIL import Image | |
# Hugging Face Hub repo id of the fine-tuned DETR box detector to load.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_data_only"

# Choose the compute device first so the model can be moved there as soon as it loads.
device = "cuda" if torch.cuda.is_available() else "cpu"

image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Class-id -> class-name mapping stored in the model's config.
id2label = model.config.id2label

# Box outline colour for each class name when drawing predictions.
color_dict = dict(
    not_trash="red",
    bin="green",
    trash="blue",
    hand="purple",
)
def predict_on_image(image, conf_threshold=0.25):
    """Run the Trashify detector on an image and return it with boxes drawn on.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        conf_threshold: Minimum score a detection needs to be kept. It is passed
            as ``threshold`` to ``image_processor.post_process_object_detection``.

    Returns:
        An RGB copy of ``image`` with one labelled rectangle per kept detection,
        or ``None`` if no image was given. The caller's image is not modified.
    """
    # Gradio passes None for an empty image input; there is nothing to annotate.
    if image is None:
        return None

    # Work on an RGB copy. The processor gets 3-channel input, named box colours
    # render as colours, and drawing never mutates the caller's image.
    image = image.convert("RGB")

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        # PIL's size is (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs,
            threshold=conf_threshold,
            target_sizes=target_sizes,
        )[0]

    # Move every result tensor (boxes, scores, labels) to the CPU for drawing.
    results = {key: value.cpu() for key, value in results.items()}

    _draw_detections(image, results)
    return image


def _draw_detections(image, results):
    """Draw a coloured box and a "label (score)" caption for each detection, in place."""
    draw = ImageDraw.Draw(image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        # Unknown class names get a fallback colour instead of raising KeyError
        # partway through drawing.
        box_color = color_dict.get(label_name, "yellow")
        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
        )
# Gradio UI: an image and a confidence slider go in, and the annotated image comes out.
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        # Forwarded to predict_on_image as its conf_threshold argument.
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="pil"),
    title="🚮 Trashify Object Detection Demo",
    description="Upload an image to detect whether there's a bin, a hand or trash in it.",
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run directly.
    demo.launch()