{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import numpy as np\n", "from PIL import Image, ImageDraw, ImageFont\n", "import matplotlib.pyplot as plt\n", "import cv2\n", "from segment_anything import sam_model_registry\n", "from segment_anything.predictor_sammed import SammedPredictor\n", "from argparse import Namespace\n", "import torch\n", "import torchvision\n", "import os, sys\n", "import random\n", "import warnings\n", "from scipy import ndimage\n", "import functools\n", "\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "args = Namespace()\n", "args.device = device\n", "args.image_size = 256\n", "args.encoder_adapter = True\n", "args.sam_checkpoint = \"pretrain_model/sam-med2d_b.pth\" #sam_vit_b.pth sam-med2d_b.pth" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def load_model(args):\n", " model = sam_model_registry[\"vit_b\"](args).to(args.device)\n", " model.eval()\n", " predictor = SammedPredictor(model)\n", " return predictor\n", "\n", "\n", "predictor_with_adapter = load_model(args)\n", "args.encoder_adapter = False\n", "predictor_without_adapter = load_model(args)\n", "\n", "def run_sammed(input_image, selected_points, last_mask, adapter_type):\n", " if adapter_type == \"SAM-Med2D-B\":\n", " predictor = predictor_with_adapter\n", " else:\n", " predictor = predictor_without_adapter\n", " \n", " image_pil = Image.fromarray(input_image) #.convert(\"RGB\")\n", " image = input_image\n", " H,W,_ = image.shape\n", " predictor.set_image(image)\n", " centers = np.array([a for a,b in selected_points ])\n", " point_coords = centers\n", " point_labels = np.array([b for a,b in selected_points ])\n", "\n", " masks, _, logits = predictor.predict(\n", " point_coords=point_coords,\n", " point_labels=point_labels,\n", " mask_input = last_mask,\n", " multimask_output=True \n", " ) \n", "\n", " mask_image = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))\n", " mask_draw = ImageDraw.Draw(mask_image)\n", " for mask in masks:\n", " draw_mask(mask, mask_draw, random_color=False)\n", " image_draw = ImageDraw.Draw(image_pil)\n", "\n", " draw_point(selected_points, image_draw)\n", "\n", " image_pil = image_pil.convert('RGBA')\n", " image_pil.alpha_composite(mask_image)\n", " last_mask = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float, device=device))\n", " return [(image_pil, mask_image), last_mask]\n", "\n", "\n", "def draw_mask(mask, draw, random_color=False):\n", " if random_color:\n", " color = (random.randint(0, 255), random.randint(\n", " 0, 255), random.randint(0, 255), 153)\n", " else:\n", " color = (30, 144, 255, 153)\n", "\n", " nonzero_coords = np.transpose(np.nonzero(mask))\n", "\n", " for coord in nonzero_coords:\n", " draw.point(coord[::-1], fill=color)\n", "\n", "def draw_point(point, draw, r=5):\n", " show_point = []\n", " for point, label in point:\n", " x,y = point\n", " if label == 1:\n", " draw.ellipse((x-r, y-r, x+r, y+r), fill='green')\n", " elif label == 0:\n", " draw.ellipse((x-r, y-r, x+r, y+r), fill='red')\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Keyboard interruption in main thread... closing server.\n" ] }, { "data": { "text/plain": [] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "colors = [(255, 0, 0), (0, 255, 0)]\n", "markers = [1, 5]\n", "block = gr.Blocks()\n", "with block:\n", " with gr.Row():\n", " gr.Markdown(\n", " '''# SAM-Med2D!🚀\n", " SAM-Med2D is an interactive segmentation model based on the SAM model for medical scenarios, supporting multi-point interactive segmentation and box interaction. \n", " Currently, only multi-point interaction is supported in this application. More information can be found on [**GitHub**](https://github.com/uni-medical/SAM-Med2D/tree/main).\n", " '''\n", " )\n", " with gr.Row():\n", " # select model\n", " adapter_type = gr.Dropdown([\"SAM-Med2D-B\", \"SAM-Med2D-B_w/o_adapter\"], value='SAM-Med2D-B', label=\"Select Adapter\")\n", " # adapter_type.change(fn = update_model, inputs=[adapter_type])\n", " \n", " with gr.Tab(label='Image'):\n", " with gr.Row().style(equal_height=True):\n", " with gr.Column():\n", " # input image\n", " original_image = gr.State(value=None) # store original image without points, default None\n", " input_image = gr.Image(type=\"numpy\")\n", " # point prompt\n", " with gr.Column():\n", " selected_points = gr.State([]) # store points\n", " last_mask = gr.State(None) \n", " with gr.Row():\n", " gr.Markdown('You can click on the image to select points prompt. Default: foreground_point.')\n", " undo_button = gr.Button('Undo point')\n", " radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')\n", " button = gr.Button(\"Run!\")\n", " \n", " gallery_sammed = gr.Gallery(\n", " label=\"Generated images\", show_label=False, elem_id=\"gallery\").style(preview=True, grid=2,object_fit=\"scale-down\")\n", " \n", " def process_example(img):\n", " return img, [], None \n", " \n", " def store_img(img):\n", " return img, [], None # when new image is uploaded, `selected_points` should be empty\n", " input_image.upload(\n", " store_img,\n", " [input_image],\n", " [original_image, selected_points, last_mask]\n", " )\n", " # user click the image to get points, and show the points on the image\n", " def get_point(img, sel_pix, point_type, evt: gr.SelectData):\n", " if point_type == 'foreground_point':\n", " sel_pix.append((evt.index, 1)) # append the foreground_point\n", " elif point_type == 'background_point':\n", " sel_pix.append((evt.index, 0)) # append the background_point\n", " else:\n", " sel_pix.append((evt.index, 1)) # default foreground_point\n", " # draw points\n", " for point, label in sel_pix:\n", " cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)\n", " # if img[..., 0][0, 0] == img[..., 2][0, 0]: # BGR to RGB\n", " # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", " return img if isinstance(img, np.ndarray) else np.array(img)\n", " \n", " input_image.select(\n", " get_point,\n", " [input_image, selected_points, radio],\n", " [input_image],\n", " )\n", "\n", " # undo the selected point\n", " def undo_points(orig_img, sel_pix):\n", " if isinstance(orig_img, int): # if orig_img is int, the image if select from examples\n", " temp = cv2.imread(image_examples[orig_img][0])\n", " temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)\n", " else:\n", " temp = orig_img.copy()\n", " # draw points\n", " if len(sel_pix) != 0:\n", " sel_pix.pop()\n", " for point, label in sel_pix:\n", " cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)\n", " if temp[..., 0][0, 0] == temp[..., 2][0, 0]: # BGR to RGB\n", " temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)\n", " return temp, None if isinstance(temp, np.ndarray) else np.array(temp), None\n", " \n", " undo_button.click(\n", " undo_points,\n", " [original_image, selected_points],\n", " [input_image, last_mask]\n", " )\n", "\n", " with gr.Row():\n", " with gr.Column():\n", " gr.Examples([\"data_demo/images/amos_0507_31.png\", \"data_demo/images/s0114_111.png\" ], inputs=[input_image], outputs=[original_image, selected_points,last_mask], fn=process_example, run_on_click=True)\n", "\n", " button.click(fn=run_sammed, inputs=[original_image, selected_points, last_mask, adapter_type], outputs=[gallery_sammed, last_mask])\n", "\n", "block.launch(debug=True, share=True, show_error=True)\n" ] } ], "metadata": { "kernelspec": { "display_name": "MMseg", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }