import os
import re
import requests
import logging
from typing import Tuple, List, Dict

logger = logging.getLogger(__name__)


class SafetyGuard:
    """
    Wrapper around NVIDIA Llama Guard (meta/llama-guard-4-12b) hosted at
    https://integrate.api.nvidia.com/v1/chat/completions

    Exposes helpers to validate:
      - user input safety
      - model output safety (in context of the user question)
    """

    def __init__(self):
        self.api_key = os.getenv("NVIDIA_URI")
        if not self.api_key:
            raise ValueError("NVIDIA_URI environment variable not set for SafetyGuard")
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.model = "meta/llama-guard-4-12b"
        self.timeout_s = 30

    def _chunk_text(self, text: str, chunk_size: int = 2800, overlap: int = 200) -> List[str]:
        """Chunk long text to keep request payloads small enough for the guard.

        Uses a character-based approximation with a small overlap between chunks.
        """
        if not text:
            return [""]
        n = len(text)
        if n <= chunk_size:
            return [text]
        chunks: List[str] = []
        start = 0
        while start < n:
            end = min(start + chunk_size, n)
            chunks.append(text[start:end])
            if end == n:
                break
            start = max(0, end - overlap)
        return chunks
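
    # Example (illustrative only): for a 6,000-character string with the defaults
    # (chunk_size=2800, overlap=200), this yields chunks covering [0:2800],
    # [2600:5400], and [5200:6000]; adjacent chunks share 200 characters of overlap.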

    def _call_guard(self, messages: List[Dict], max_tokens: int = 512) -> str:
        # Enhance messages with medical context if detected
        enhanced_messages = self._enhance_messages_with_context(messages)
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # Try the OpenAI-compatible schema first
        payload_chat = {
            "model": self.model,
            "messages": enhanced_messages,
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": max_tokens,
            "stream": False,
        }
        # Alternative schema (some NVIDIA deployments require message content objects)
        alt_messages = []
        for m in enhanced_messages:
            content = m.get("content", "")
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
            alt_messages.append({"role": m.get("role", "user"), "content": content})
        payload_alt = {
            "model": self.model,
            "messages": alt_messages,
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": max_tokens,
            "stream": False,
        }
        # Attempt the primary payload, then fall back to the alternative
        for payload in (payload_chat, payload_alt):
            try:
                resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
                if resp.status_code >= 400:
                    # Log the server message for debugging payload issues
                    try:
                        logger.error(f"[SafetyGuard] HTTP {resp.status_code}: {resp.text[:400]}")
                    except Exception:
                        pass
                resp.raise_for_status()
                data = resp.json()
                content = (
                    data.get("choices", [{}])[0]
                    .get("message", {})
                    .get("content", "")
                    .strip()
                )
                if content:
                    return content
            except Exception as e:
                # Try the next payload shape
                logger.error(f"[SafetyGuard] Guard API call failed: {e}")
                continue
        # All attempts failed
        return ""

    def _parse_guard_reply(self, text: str) -> Tuple[bool, str]:
        """Parse the guard reply; expect 'SAFE' or 'UNSAFE: <reason>' (case-insensitive)."""
        if not text:
            # Fail open: treat as SAFE if the guard is unavailable, to avoid false blocks
            return True, "guard_unavailable"
        t = text.strip()
        upper = t.upper()
        if upper.startswith("SAFE") and not upper.startswith("SAFEGUARD"):
            return True, ""
        if upper.startswith("UNSAFE"):
            # Extract the reason after the first colon, if present
            parts = t.split(":", 1)
            reason = parts[1].strip() if len(parts) > 1 else "policy violation"
            return False, reason
        # Fallback: treat an unknown response as unsafe
        return False, t[:180]
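
    # Examples (illustrative inputs, not actual API responses): "safe" -> (True, ""),
    # "UNSAFE: promotes self-harm" -> (False, "promotes self-harm"), and an empty reply
    # fails open as (True, "guard_unavailable").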

    def _is_medical_query(self, query: str) -> bool:
        """Check if the query is clearly medical in nature using comprehensive patterns."""
        if not query:
            return False
        query_lower = query.lower()
        # Medical keyword categories
        medical_categories = {
            'symptoms': [
                'symptom', 'pain', 'ache', 'hurt', 'sore', 'tender', 'stiff', 'numb',
                'headache', 'migraine', 'fever', 'cough', 'cold', 'flu', 'sneeze',
                'nausea', 'vomit', 'diarrhea', 'constipation', 'bloating', 'gas',
                'dizziness', 'vertigo', 'fatigue', 'weakness', 'tired', 'exhausted',
                'shortness of breath', 'wheezing', 'chest pain', 'heart palpitations',
                'joint pain', 'muscle pain', 'back pain', 'neck pain', 'stomach pain',
                'abdominal pain', 'pelvic pain', 'menstrual pain', 'cramps'
            ],
            'conditions': [
                'disease', 'condition', 'disorder', 'syndrome', 'illness', 'sickness',
                'infection', 'inflammation', 'allergy', 'asthma', 'diabetes', 'hypertension',
                'depression', 'anxiety', 'stress', 'panic', 'phobia', 'ocd', 'ptsd',
                'adhd', 'autism', 'dementia', 'alzheimer', 'parkinson', 'epilepsy',
                'cancer', 'tumor', 'cancerous', 'malignant', 'benign', 'metastasis',
                'heart disease', 'stroke', 'heart attack', 'coronary', 'arrhythmia',
                'pneumonia', 'bronchitis', 'copd', 'emphysema', 'tuberculosis',
                'migraine', 'headache', 'chronic migraine', 'cluster headache',
                'tension headache', 'sinus headache', 'cure', 'treat', 'treatment'
            ],
            'treatments': [
                'treatment', 'therapy', 'medication', 'medicine', 'drug', 'pill', 'tablet',
                'injection', 'vaccine', 'immunization', 'surgery', 'operation', 'procedure',
                'chemotherapy', 'radiation', 'physical therapy', 'occupational therapy',
                'psychotherapy', 'counseling', 'rehabilitation', 'recovery', 'healing',
                'prescription', 'dosage', 'side effects', 'contraindications'
            ],
            'body_parts': [
                'head', 'brain', 'eye', 'ear', 'nose', 'mouth', 'throat', 'neck',
                'chest', 'heart', 'lung', 'liver', 'kidney', 'stomach', 'intestine',
                'back', 'spine', 'joint', 'muscle', 'bone', 'skin', 'hair', 'nail',
                'arm', 'leg', 'hand', 'foot', 'finger', 'toe', 'pelvis', 'genital'
            ],
            'medical_context': [
                'doctor', 'physician', 'nurse', 'specialist', 'surgeon', 'dentist',
                'medical', 'health', 'healthcare', 'hospital', 'clinic', 'emergency',
                'ambulance', 'paramedic', 'pharmacy', 'pharmacist', 'lab', 'test',
                'diagnosis', 'prognosis', 'examination', 'checkup', 'screening',
                'patient', 'case', 'history', 'medical history', 'family history'
            ],
            'life_stages': [
                'pregnancy', 'pregnant', 'baby', 'infant', 'newborn', 'child', 'pediatric',
                'teenager', 'adolescent', 'adult', 'elderly', 'senior', 'geriatric',
                'menopause', 'puberty', 'aging', 'birth', 'delivery', 'miscarriage'
            ],
            'vital_signs': [
                'blood pressure', 'heart rate', 'pulse', 'temperature', 'fever',
                'respiratory rate', 'oxygen saturation', 'weight', 'height', 'bmi',
                'blood sugar', 'glucose', 'cholesterol', 'hemoglobin', 'white blood cell'
            ]
        }
        # Check for medical keywords
        for category, keywords in medical_categories.items():
            if any(keyword in query_lower for keyword in keywords):
                return True
        # Check for medical question patterns
        medical_patterns = [
            r'\b(what|how|why|when|where)\s+(causes?|treats?|prevents?|symptoms?|signs?)\b',
            r'\b(is|are)\s+(.*?)\s+(dangerous|serious|harmful|safe|normal)\b',
            r'\b(should|can|may|might)\s+(i|you|we)\s+(take|use|do|avoid)\b',
            r'\b(diagnosis|diagnosed|symptoms|treatment|medicine|drug)\b',
            r'\b(medical|health|doctor|physician|hospital|clinic)\b',
            r'\b(pain|hurt|ache|sore|fever|cough|headache)\b',
            r'\b(which\s+medication|best\s+medication|how\s+to\s+cure|without\s+medications)\b',
            r'\b(chronic\s+migraine|migraine\s+treatment|migraine\s+cure)\b',
            r'\b(cure|treat|heal|relief|remedy|solution)\b'
        ]
        for pattern in medical_patterns:
            if re.search(pattern, query_lower):
                return True
        return False
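
    # Example (illustrative): _is_medical_query("how to treat a migraine") returns True
    # (keywords 'treat'/'migraine'); a query containing none of the keywords above and
    # matching none of the regex patterns returns False.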

    def check_user_query(self, user_query: str) -> Tuple[bool, str]:
        """Validate the user query is safe to process, with medical context awareness."""
        text = user_query or ""
        # For medical queries, be more permissive
        if self._is_medical_query(text):
            logger.info("[SafetyGuard] Medical query detected, skipping strict validation")
            return True, "medical_query"
        # If too long, validate each chunk; any UNSAFE chunk makes the overall result UNSAFE
        for part in self._chunk_text(text):
            messages = [{"role": "user", "content": part}]
            reply = self._call_guard(messages, max_tokens=64)
            ok, reason = self._parse_guard_reply(reply)
            if not ok:
                return False, reason
        return True, ""

    def _detect_harmful_content(self, text: str) -> Tuple[bool, str]:
        """Detect harmful content using pattern matching with medical context awareness."""
        if not text:
            return True, ""
        text_lower = text.lower()
        # First check if this is clearly medical content - be more permissive
        if self._is_medical_query(text):
            # For medical content, only check for truly dangerous patterns
            dangerous_medical_patterns = {
                'suicide_self_harm': [
                    r'\b(kill\s+yourself|suicide|end\s+your\s+life|take\s+your\s+life)\b',
                    r'\b(self\s*harm|self\s*injury|cut\s+yourself|hurt\s+yourself)\b',
                    r'\b(overdose|poison\s+yourself|hang\s+yourself)\b'
                ],
                'dangerous_medical_advice': [
                    r'\b(overdose\s+on|take\s+too\s+much|excessive\s+dosage)\b',
                    r'\b(mix\s+drugs|combine\s+medications|contraindicated)\b',
                    r'\b(stop\s+taking\s+prescribed|discontinue\s+medication)\b',
                    r'\b(don\'t\s+call\s+911|avoid\s+emergency\s+room|ignore\s+severe\s+symptoms)\b'
                ],
                'illegal_drugs': [
                    r'\b(cocaine|heroin|methamphetamine|ecstasy|lsd|marijuana\s+abuse)\b',
                    r'\b(prescription\s+abuse|drug\s+dealing|illegal\s+substances)\b'
                ]
            }
            # Check only the dangerous medical patterns
            for category, patterns in dangerous_medical_patterns.items():
                for pattern in patterns:
                    if re.search(pattern, text_lower):
                        return False, f"contains {category}: {pattern}"
            # Allow medical content through
            return True, "medical_content"
        # For non-medical content, use stricter patterns
        critical_patterns = {
            'suicide_self_harm': [
                r'\b(kill\s+yourself|suicide|end\s+your\s+life|take\s+your\s+life)\b',
                r'\b(self\s*harm|self\s*injury|cut\s+yourself|hurt\s+yourself)\b',
                r'\b(overdose|poison\s+yourself|hang\s+yourself)\b'
            ],
            'violence': [
                r'\b(kill\s+someone|murder|assassinate|violence|harm\s+others)\b',
                r'\b(bomb|explosive|weapon|gun|knife)\b',
                r'\b(attack\s+(someone|people|others|innocent))\b'  # More specific attack pattern
            ],
            'illegal_drugs': [
                r'\b(cocaine|heroin|methamphetamine|ecstasy|lsd|marijuana\s+abuse)\b',
                r'\b(prescription\s+abuse|drug\s+dealing|illegal\s+substances)\b'
            ],
            'dangerous_medical': [
                r'\b(overdose\s+on|take\s+too\s+much|excessive\s+dosage)\b',
                r'\b(mix\s+drugs|combine\s+medications|contraindicated)\b',
                r'\b(stop\s+taking\s+prescribed|discontinue\s+medication)\b'
            ]
        }
        # Check critical patterns
        for category, patterns in critical_patterns.items():
            for pattern in patterns:
                if re.search(pattern, text_lower):
                    return False, f"contains {category}: {pattern}"
        # Context-aware medical safety checks
        medical_safety_patterns = {
            'dosage_warnings': [
                r'\b(take\s+more\s+than\s+prescribed|exceed\s+recommended\s+dose)\b',
                r'\b(double\s+up\s+on\s+medication|take\s+extra\s+pills)\b'
            ],
            'emergency_advice': [
                r'\b(don\'t\s+call\s+911|avoid\s+emergency\s+room|ignore\s+severe\s+symptoms)\b',
                r'\b(self\s*treat\s+emergency|handle\s+at\s+home\s+when\s+critical)\b'
            ]
        }
        # Check medical safety patterns
        for category, patterns in medical_safety_patterns.items():
            for pattern in patterns:
                if re.search(pattern, text_lower):
                    return False, f"contains dangerous medical advice: {pattern}"
        return True, ""

    def _enhance_messages_with_context(self, messages: List[Dict]) -> List[Dict]:
        """Enhance messages with medical context awareness for better guard performance."""
        if not messages:
            return messages
        # Check if any message contains medical content
        has_medical_content = any(
            self._is_medical_query(msg.get("content", "")) for msg in messages
        )
        if has_medical_content:
            # Prepend medical context to the first user message
            enhanced_messages = []
            for i, msg in enumerate(messages):
                if i == 0 and msg.get("role") == "user":
                    original_content = msg.get("content", "")
                    enhanced_content = f"""MEDICAL CONTEXT: This is a medical consultation. The user is asking about health-related topics.
ORIGINAL MESSAGE: {original_content}
Please evaluate this message for safety, keeping in mind that:
- Medical questions and health concerns are legitimate
- General medical information sharing is acceptable
- Only block content that is clearly harmful, dangerous, or illegal
- Medical advice should be educational, not prescriptive"""
                    enhanced_messages.append({
                        "role": msg["role"],
                        "content": enhanced_content
                    })
                else:
                    enhanced_messages.append(msg)
            return enhanced_messages
        return messages

    def _assess_risk_level(self, text: str) -> Tuple[str, float]:
        """Assess the risk level of content using multiple indicators."""
        if not text:
            return "low", 0.0
        text_lower = text.lower()
        # If this is medical content, be more lenient
        if self._is_medical_query(text):
            # For medical content, only flag truly dangerous patterns
            dangerous_medical_indicators = {
                'high': [
                    'suicide', 'kill yourself', 'end your life', 'self harm',
                    'overdose', 'poison yourself', 'illegal drugs', 'violence'
                ],
                'medium': [
                    'prescription abuse', 'excessive dosage', 'mix drugs',
                    'stop taking prescribed', 'ignore severe symptoms'
                ]
            }
            risk_score = 0.0
            for level, indicators in dangerous_medical_indicators.items():
                for indicator in indicators:
                    if indicator in text_lower:
                        if level == 'high':
                            risk_score += 3.0
                        elif level == 'medium':
                            risk_score += 1.5
            # Normalize the score for medical content (more lenient)
            risk_score = min(risk_score / 15.0, 1.0)
            if risk_score >= 0.6:
                return "high", risk_score
            elif risk_score >= 0.2:
                return "medium", risk_score
            else:
                return "low", risk_score
        # For non-medical content, use the original risk assessment
        risk_indicators = {
            'high': [
                'suicide', 'kill yourself', 'end your life', 'self harm',
                'overdose', 'poison', 'illegal drugs', 'violence', 'harm others'
            ],
            'medium': [
                'prescription abuse', 'excessive dosage', 'mix drugs',
                'stop medication', 'ignore symptoms', 'avoid doctor'
            ],
            'low': [
                'pain', 'headache', 'fever', 'cough', 'treatment',
                'medicine', 'doctor', 'hospital', 'symptoms'
            ]
        }
        risk_score = 0.0
        for level, indicators in risk_indicators.items():
            for indicator in indicators:
                if indicator in text_lower:
                    if level == 'high':
                        risk_score += 3.0
                    elif level == 'medium':
                        risk_score += 1.5
                    else:
                        risk_score += 0.5
        # Normalize the score
        risk_score = min(risk_score / 10.0, 1.0)
        if risk_score >= 0.7:
            return "high", risk_score
        elif risk_score >= 0.3:
            return "medium", risk_score
        else:
            return "low", risk_score

    def check_model_answer(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
        """Validate that the model's answer is safe, with medical context awareness."""
        uq = user_query or ""
        ans = model_answer or ""
        # Assess the risk level first
        risk_level, risk_score = self._assess_risk_level(ans)
        logger.info(f"[SafetyGuard] Risk assessment: {risk_level} (score: {risk_score:.2f})")
        # Always check for harmful content first
        is_safe, reason = self._detect_harmful_content(ans)
        if not is_safe:
            return False, reason
        # For high-risk content, always use strict validation
        if risk_level == "high":
            logger.warning("[SafetyGuard] High-risk content detected, using strict validation")
            user_parts = self._chunk_text(uq, chunk_size=2000)
            user_context = user_parts[0] if user_parts else ""
            for ans_part in self._chunk_text(ans):
                messages = [
                    {"role": "user", "content": user_context},
                    {"role": "assistant", "content": ans_part},
                ]
                reply = self._call_guard(messages, max_tokens=96)
                ok, reason = self._parse_guard_reply(reply)
                if not ok:
                    return False, reason
            return True, "high_risk_validated"
        # For medical queries and answers, use relaxed validation
        if self._is_medical_query(uq) or self._is_medical_query(ans):
            logger.info("[SafetyGuard] Medical content detected, using relaxed validation")
            return True, "medical_content"
        # For medium-risk non-medical content, use guard validation
        if risk_level == "medium":
            logger.info("[SafetyGuard] Medium-risk content detected, using guard validation")
            user_parts = self._chunk_text(uq, chunk_size=2000)
            user_context = user_parts[0] if user_parts else ""
            for ans_part in self._chunk_text(ans):
                messages = [
                    {"role": "user", "content": user_context},
                    {"role": "assistant", "content": ans_part},
                ]
                reply = self._call_guard(messages, max_tokens=96)
                ok, reason = self._parse_guard_reply(reply)
                if not ok:
                    return False, reason
            return True, "medium_risk_validated"
        # For low-risk content, allow through
        logger.info("[SafetyGuard] Low-risk content detected, allowing through")
        return True, "low_risk"


# Global instance (optional convenience)
safety_guard = SafetyGuard()
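

# Minimal usage sketch (illustrative addition, not part of the original module). It assumes
# the NVIDIA_URI environment variable holds a valid API key (required for the module-level
# instantiation above); the question and draft answer below are made-up examples. With these
# particular strings the medical heuristics short-circuit, so no guard API call is made.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # 1) Screen an incoming user question before passing it to the main model.
    question = "What can I do to relieve a tension headache?"
    ok, reason = safety_guard.check_user_query(question)
    print(f"user query: safe={ok}, reason={reason!r}")

    # 2) Screen the model's draft answer in the context of that question.
    draft_answer = "Rest, hydration, and over-the-counter pain relief are common options."
    ok, reason = safety_guard.check_model_answer(question, draft_answer)
    print(f"model answer: safe={ok}, reason={reason!r}")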