# ============================================================
# Colposcopy Inference Backend
# Production-ready | VS Code | Hugging Face compatible
# ============================================================
import os

import cv2
import joblib
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from ultralytics import YOLO

# ------------------------------------------------------------
# DEVICE
# ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------------------------------------------
# PATHS (RELATIVE — REQUIRED FOR DEPLOYMENT)
# ------------------------------------------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

SEG_MODEL_PATH = os.path.join(MODEL_DIR, "seg_yolov8n_best.pt")
FUSION_MODEL_PATH = os.path.join(MODEL_DIR, "fusion_model.pth")
CLF_PATH = os.path.join(MODEL_DIR, "logreg_classifier.joblib")

# ------------------------------------------------------------
# LOAD MODELS (ONCE, AT IMPORT TIME)
# ------------------------------------------------------------
seg_model = YOLO(SEG_MODEL_PATH)
clf = joblib.load(CLF_PATH)


# ------------------------------------------------------------
# FUSION MODEL DEFINITION
# ------------------------------------------------------------
class ImageEncoder(nn.Module):
    """ResNet-18 backbone (classifier head removed) projected to a 512-d embedding."""

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=False` is deprecated in recent torchvision
        # (use `weights=None`); kept as-is because the real weights are loaded
        # afterwards from FUSION_MODEL_PATH, so random init here is harmless.
        base = models.resnet18(pretrained=False)
        # Drop the final fc layer; keep conv stack + global average pool.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.fc = nn.Linear(512, 512)

    def forward(self, x):
        x = self.backbone(x)
        return self.fc(x.view(x.size(0), -1))


class FeatureEncoder(nn.Module):
    """Encodes the 7 hand-crafted geometry features into a 64-d embedding."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(7, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )

    def forward(self, x):
        return self.net(x)


class FusionModel(nn.Module):
    """Concatenates image (512-d) and feature (64-d) embeddings -> 576-d, batch-normed.

    BatchNorm1d is safe at batch size 1 here because the model is only used
    in eval() mode (running statistics, not batch statistics).
    """

    def __init__(self):
        super().__init__()
        self.img_enc = ImageEncoder()
        self.feat_enc = FeatureEncoder()
        self.norm = nn.BatchNorm1d(576)

    def forward(self, img, feat):
        img_emb = self.img_enc(img)
        feat_emb = self.feat_enc(feat)
        return self.norm(torch.cat([img_emb, feat_emb], dim=1))


fusion_model = FusionModel().to(device)
fusion_model.load_state_dict(torch.load(FUSION_MODEL_PATH, map_location=device))
fusion_model.eval()

# ------------------------------------------------------------
# IMAGE TRANSFORM
# ------------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# ------------------------------------------------------------
# CONSTANTS
# ------------------------------------------------------------
CERVIX_ID = 0  # YOLO class id: cervix region
SCJ_ID = 1     # YOLO class id: squamocolumnar junction
ACET_ID = 3    # YOLO class id: acetowhite lesion
MIN_ACET_RATIO = 0.01  # lesion must cover >= 1% of the cervix to count as present


# ------------------------------------------------------------
# GEOMETRY UTILITIES
# ------------------------------------------------------------
def polygon_to_mask(polygon, H, W):
    """Rasterize a normalized (x, y) polygon into a binary uint8 mask of shape (H, W)."""
    pts = np.array([[int(x * W), int(y * H)] for x, y in polygon], np.int32)
    mask = np.zeros((H, W), dtype=np.uint8)
    cv2.fillPoly(mask, [pts], 1)
    return mask


def mask_area(mask):
    """Return the fraction of mask pixels that are set (area normalized by image size)."""
    return mask.sum() / mask.size


def centroid_distance(mask1, mask2):
    """Distance between mask centroids, normalized by the larger image dimension.

    Returns 1.0 (maximal) when mask2 is absent or either mask is empty.
    """
    if mask2 is None:
        return 1.0
    ys1, xs1 = np.where(mask1 == 1)
    ys2, xs2 = np.where(mask2 == 1)
    if len(xs1) == 0 or len(xs2) == 0:
        return 1.0
    c1 = np.array([xs1.mean(), ys1.mean()])
    c2 = np.array([xs2.mean(), ys2.mean()])
    return np.linalg.norm(c1 - c2) / max(mask1.shape)


def overlap_ratio(mask1, mask2):
    """Fraction of mask1 covered by mask2; 0.0 if mask2 is absent or mask1 is empty."""
    if mask2 is None:
        return 0.0
    denom = mask1.sum()  # hoisted: was computed twice in the original
    if denom == 0:
        return 0.0
    return np.logical_and(mask1, mask2).sum() / denom


# ------------------------------------------------------------
# LOAD YOLO POLYGONS
# ------------------------------------------------------------
def load_yolo_segmentation(label_path):
    """Parse a YOLO-seg label file into a list of {"cls": int, "polygon": [(x, y), ...]}.

    Coordinates are normalized to [0, 1]. Returns [] if the file is missing.
    Blank lines are skipped and a dangling odd coordinate is dropped
    (either previously raised IndexError).
    """
    objects = []
    if not os.path.exists(label_path):
        return objects
    with open(label_path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank / trailing lines
            values = list(map(float, parts))
            cls = int(values[0])
            coords = values[1:]
            # Pair up (x, y); stop before a dangling odd coordinate.
            polygon = [
                (coords[i], coords[i + 1])
                for i in range(0, len(coords) - 1, 2)
            ]
            objects.append({"cls": cls, "polygon": polygon})
    return objects


# ------------------------------------------------------------
# FEATURE EXTRACTION
# ------------------------------------------------------------
def extract_features_from_label(label_path, H, W):
    """Build the 7-d geometry feature vector for the fusion model.

    Features (fixed order — the downstream classifier was trained on these):
      0: acet_present flag
      1: duplicate of acet_present (kept deliberately: FeatureEncoder expects 7 inputs)
      2: acetowhite area (image fraction), 0 if absent
      3: acetowhite area / cervix area, 0 if absent
      4: normalized centroid distance acetowhite -> SCJ
      5: normalized centroid distance acetowhite -> cervix center
      6: fraction of acetowhite overlapped by SCJ
    """
    objects = load_yolo_segmentation(label_path)

    cervix_masks, scj_masks, acet_masks = [], [], []
    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix_masks.append(m)
        elif obj["cls"] == SCJ_ID:
            scj_masks.append(m)
        elif obj["cls"] == ACET_ID:
            acet_masks.append(m)

    # Largest cervix detection wins; fallback is an empty uint8 mask
    # (uint8 keeps acet_union integer-typed for cv2 morphology below).
    cervix = (
        max(cervix_masks, key=lambda m: m.sum())
        if cervix_masks
        else np.zeros((H, W), dtype=np.uint8)
    )
    scj = max(scj_masks, key=lambda m: m.sum()) if scj_masks else None
    cervix_area = mask_area(cervix)

    # Union of all acetowhite detections, restricted to the cervix region.
    acet_union = np.zeros((H, W), dtype=np.uint8)
    for m in acet_masks:
        acet_union = np.maximum(acet_union, m)
    acet_union = acet_union * cervix

    if acet_union.sum() > 0:
        # Close small holes/gaps so area and centroid are stable.
        acet_union = cv2.morphologyEx(
            acet_union, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)
        )

    acet_area = mask_area(acet_union)
    # acet_present implies cervix_area > 0, so the ratios below never divide by zero.
    acet_present = int(cervix_area > 0 and acet_area / cervix_area >= MIN_ACET_RATIO)

    if acet_present:
        dist_acet_scj = centroid_distance(acet_union, scj)
        lesion_center_dist = centroid_distance(acet_union, cervix)
        overlap_lesion_scj = overlap_ratio(acet_union, scj)
    else:
        dist_acet_scj = lesion_center_dist = 1.0
        overlap_lesion_scj = 0.0

    return torch.tensor([
        acet_present,
        1 if acet_present else 0,
        acet_area if acet_present else 0.0,
        acet_area / cervix_area if acet_present else 0.0,
        dist_acet_scj,
        lesion_center_dist,
        overlap_lesion_scj,
    ], dtype=torch.float32)


# ------------------------------------------------------------
# SAVE VISUALIZATION FOR UI
# ------------------------------------------------------------
def save_overlay(image_path, label_path, out_path):
    """Render the segmentation masks on top of the image and save as PNG.

    Colors (RGB): cervix = blue, SCJ = green, acetowhite = red,
    blended 70% image / 30% color.
    """
    image = np.array(Image.open(image_path).convert("RGB"))
    H, W, _ = image.shape
    objects = load_yolo_segmentation(label_path)

    cervix = np.zeros((H, W), dtype=np.uint8)
    scj = np.zeros((H, W), dtype=np.uint8)
    acet = np.zeros((H, W), dtype=np.uint8)

    for obj in objects:
        m = polygon_to_mask(obj["polygon"], H, W)
        if obj["cls"] == CERVIX_ID:
            cervix = np.maximum(cervix, m)
        elif obj["cls"] == SCJ_ID:
            scj = np.maximum(scj, m)
        elif obj["cls"] == ACET_ID:
            acet = np.maximum(acet, m)

    overlay = image.copy()
    # Float blend is implicitly cast back into the uint8 overlay on assignment.
    overlay[cervix == 1] = 0.7 * overlay[cervix == 1] + 0.3 * np.array([0, 0, 255])
    overlay[scj == 1] = 0.7 * overlay[scj == 1] + 0.3 * np.array([0, 255, 0])
    overlay[acet == 1] = 0.7 * overlay[acet == 1] + 0.3 * np.array([255, 0, 0])

    Image.fromarray(overlay.astype(np.uint8)).save(out_path)


# ------------------------------------------------------------
# PUBLIC API — UI CALLS THIS
# ------------------------------------------------------------
def run_inference(image_path: str) -> dict:
    """Run segmentation + fusion classification on one image.

    Returns a dict with keys: decision, probability_abnormal, acet_present,
    overlay_image — or {"decision": "Segmentation failed"} if YOLO produced
    no label file.

    NOTE(review): save_txt=True writes labels into a fresh Ultralytics run
    directory on every call; these accumulate on disk — consider periodic
    cleanup in deployment.
    """
    results = seg_model(image_path, conf=0.15, save_txt=True, save=False)
    save_dir = results[0].save_dir
    name = os.path.splitext(os.path.basename(image_path))[0]
    label_path = os.path.join(save_dir, "labels", f"{name}.txt")

    if not os.path.exists(label_path):
        return {"decision": "Segmentation failed"}

    image = Image.open(image_path).convert("RGB")
    W, H = image.size  # PIL reports (width, height)
    img_tensor = transform(image).unsqueeze(0).to(device)

    feat = extract_features_from_label(label_path, H, W)
    feat_tensor = feat.unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = fusion_model(img_tensor, feat_tensor)

    prob = clf.predict_proba(embedding.cpu().numpy())[0, 1]
    acet_present = int(feat[0].item())

    # Decision policy: without a detected acetowhite region the probability is
    # less trustworthy, so the wording is hedged.
    if acet_present == 0:
        if prob < 0.2:
            decision = "Low-confidence normal (no acet detected)"
        else:
            decision = "Uncertain – lesion may be subtle"
    elif prob < 0.2:
        decision = "Likely Normal"
    elif prob < 0.5:
        decision = "Borderline – Review"
    else:
        decision = "Likely Abnormal"

    overlay_path = os.path.join(OUTPUT_DIR, f"{name}_overlay.png")
    save_overlay(image_path, label_path, overlay_path)

    return {
        "decision": decision,
        "probability_abnormal": float(prob),
        "acet_present": acet_present,
        "overlay_image": overlay_path,
    }