image_embedding / app.py
liyb1995's picture
Update app.py
4b9239a verified
import torch
import gradio as gr
from PIL import Image
import requests
import io
from transformers import AutoImageProcessor, AutoModel
# ======== Backbone: DINOv2-Large (recommended) ========
# One Hub checkpoint id drives both the preprocessor and the weights,
# so the two always come from the same checkpoint.
model_name = "facebook/dinov2-large"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessor associated with the checkpoint above.
processor = AutoImageProcessor.from_pretrained(model_name)

# Load the backbone, then move its parameters onto the chosen device.
model = AutoModel.from_pretrained(model_name)
model = model.to(device)
def image_to_embedding(url):
    """Download the image at ``url`` and return its L2-normalised DINOv2 embedding.

    Args:
        url: Address of the image to fetch over HTTP(S).

    Returns:
        On success, a list of floats: the CLS-token vector from the model's
        ``last_hidden_state``, scaled to unit L2 norm.
        On failure, a dict ``{"error": "<message>"}`` describing whether the
        download or the image decoding failed.
    """
    # Some image hosts reject requests that lack a browser-like User-Agent.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    try:
        # With stream=True the body is only read when .content is accessed;
        # the context manager closes the response (releasing the connection)
        # even when raise_for_status() fails before the body is read.
        with requests.get(url, timeout=10, headers=headers, stream=True) as response:
            response.raise_for_status()
            data = response.content
    except requests.RequestException as e:
        return {"error": f"Failed to download image: {e}"}

    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Non-image payloads, truncated files and oversized images are
        # reported through the same error-dict channel as download failures
        # instead of escaping as an exception.
        return {"error": f"Failed to decode image: {e}"}

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Dinov2Model's last_hidden_state is [1, num_patches + 1, 1024]; the CLS
    # token sits at position 0 along the token axis.
    emb = outputs.last_hidden_state[:, 0, :]
    emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
    return emb[0].cpu().numpy().tolist()
def process(image):
    """Embed the image referenced by ``image`` and report the embedding size.

    Args:
        image: Image URL, forwarded unchanged to ``image_to_embedding``.

    Returns:
        ``(embedding, dimension)`` on success. ``(None, message)`` when no
        input was provided or when ``image_to_embedding`` reported an error.
    """
    # An empty text box may arrive as "" rather than None; treat both as missing.
    if image is None or (isinstance(image, str) and not image.strip()):
        return None, "No image uploaded"
    emb = image_to_embedding(image)
    # image_to_embedding reports failures as {"error": ...}; surface the message
    # instead of passing the dict off as an embedding of length 1.
    if isinstance(emb, dict):
        return None, emb.get("error", "Failed to compute embedding")
    return emb, len(emb)
with gr.Blocks(title="DINOv2 Image Embedding") as demo:
    gr.Markdown("# DINOv2 Image Embedding Service")
    # The only input below is a URL text box, so the description asks for a URL.
    gr.Markdown("Enter an image URL to extract its visual embedding")
    # URL -> embedding interface; api_name names the endpoint for API clients.
    gr.Interface(
        fn=image_to_embedding,
        inputs="text",  # image URL
        outputs="json",
        api_name="imageEmbedding",
    )

# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch()