"""Safety checks for user queries and model answers, backed by NVIDIA Llama Guard."""

import logging
import os
import re
from typing import Dict, List, Optional, Tuple

import requests

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Static vocabularies and patterns (built once at import time)
# ---------------------------------------------------------------------------

# Keyword vocabulary used by SafetyGuard._is_medical_query, grouped by theme.
_MEDICAL_KEYWORDS: Dict[str, Tuple[str, ...]] = {
    'symptoms': (
        'symptom', 'pain', 'ache', 'hurt', 'sore', 'tender', 'stiff', 'numb',
        'headache', 'migraine', 'fever', 'cough', 'cold', 'flu', 'sneeze',
        'nausea', 'vomit', 'diarrhea', 'constipation', 'bloating', 'gas',
        'dizziness', 'vertigo', 'fatigue', 'weakness', 'tired', 'exhausted',
        'shortness of breath', 'wheezing', 'chest pain', 'heart palpitations',
        'joint pain', 'muscle pain', 'back pain', 'neck pain', 'stomach pain',
        'abdominal pain', 'pelvic pain', 'menstrual pain', 'cramps',
    ),
    'conditions': (
        'disease', 'condition', 'disorder', 'syndrome', 'illness', 'sickness',
        'infection', 'inflammation', 'allergy', 'asthma', 'diabetes',
        'hypertension', 'depression', 'anxiety', 'stress', 'panic', 'phobia',
        'ocd', 'ptsd', 'adhd', 'autism', 'dementia', 'alzheimer', 'parkinson',
        'epilepsy', 'cancer', 'tumor', 'cancerous', 'malignant', 'benign',
        'metastasis', 'heart disease', 'stroke', 'heart attack', 'coronary',
        'arrhythmia', 'pneumonia', 'bronchitis', 'copd', 'emphysema',
        'tuberculosis', 'migraine', 'headache', 'chronic migraine',
        'cluster headache', 'tension headache', 'sinus headache', 'cure',
        'treat', 'treatment',
    ),
    'treatments': (
        'treatment', 'therapy', 'medication', 'medicine', 'drug', 'pill',
        'tablet', 'injection', 'vaccine', 'immunization', 'surgery',
        'operation', 'procedure', 'chemotherapy', 'radiation',
        'physical therapy', 'occupational therapy', 'psychotherapy',
        'counseling', 'rehabilitation', 'recovery', 'healing', 'prescription',
        'dosage', 'side effects', 'contraindications',
    ),
    'body_parts': (
        'head', 'brain', 'eye', 'ear', 'nose', 'mouth', 'throat', 'neck',
        'chest', 'heart', 'lung', 'liver', 'kidney', 'stomach', 'intestine',
        'back', 'spine', 'joint', 'muscle', 'bone', 'skin', 'hair', 'nail',
        'arm', 'leg', 'hand', 'foot', 'finger', 'toe', 'pelvis', 'genital',
    ),
    'medical_context': (
        'doctor', 'physician', 'nurse', 'specialist', 'surgeon', 'dentist',
        'medical', 'health', 'healthcare', 'hospital', 'clinic', 'emergency',
        'ambulance', 'paramedic', 'pharmacy', 'pharmacist', 'lab', 'test',
        'diagnosis', 'prognosis', 'examination', 'checkup', 'screening',
        'patient', 'case', 'history', 'medical history', 'family history',
    ),
    'life_stages': (
        'pregnancy', 'pregnant', 'baby', 'infant', 'newborn', 'child',
        'pediatric', 'teenager', 'adolescent', 'adult', 'elderly', 'senior',
        'geriatric', 'menopause', 'puberty', 'aging', 'birth', 'delivery',
        'miscarriage',
    ),
    'vital_signs': (
        'blood pressure', 'heart rate', 'pulse', 'temperature', 'fever',
        'respiratory rate', 'oxygen saturation', 'weight', 'height', 'bmi',
        'blood sugar', 'glucose', 'cholesterol', 'hemoglobin',
        'white blood cell',
    ),
}


def _compile_keyword_regex(groups: Dict[str, Tuple[str, ...]]) -> "re.Pattern":
    """Build one whole-word regex matching any keyword in *groups*.

    Keywords must start and end on a word boundary, optionally followed by a
    simple inflection (``s``, ``es``, ``d``, ``ed``, ``ing``). Plain substring
    matching would let 'arm' match "harm" and 'leg' match "illegal", which
    classifies almost any text as medical and silently relaxes every check
    that depends on that classification.
    """
    unique = {keyword for keywords in groups.values() for keyword in keywords}
    # Longest first (ties broken alphabetically) keeps the pattern deterministic.
    ordered = sorted(unique, key=lambda keyword: (-len(keyword), keyword))
    alternation = "|".join(re.escape(keyword) for keyword in ordered)
    return re.compile(r"\b(?:" + alternation + r")(?:s|es|d|ed|ing)?\b")


def _compile_groups(groups: Dict[str, Tuple[str, ...]]) -> Dict[str, tuple]:
    """Compile every raw pattern in *groups*, preserving category order."""
    return {
        category: tuple(re.compile(raw) for raw in raw_patterns)
        for category, raw_patterns in groups.items()
    }


def _first_match(groups: Dict[str, tuple], text: str) -> Optional[Tuple[str, str]]:
    """Return ``(category, raw_pattern)`` for the first pattern found in *text*.

    Categories and patterns are tried in definition order; returns ``None``
    when nothing matches.
    """
    for category, patterns in groups.items():
        for pattern in patterns:
            if pattern.search(text):
                return category, pattern.pattern
    return None


_MEDICAL_KEYWORD_RE = _compile_keyword_regex(_MEDICAL_KEYWORDS)

# Question shapes that mark a query as medical even without a keyword hit.
# All patterns are applied to lower-cased text.
_MEDICAL_QUESTION_PATTERNS = tuple(
    re.compile(raw)
    for raw in (
        r'\b(what|how|why|when|where)\s+(causes?|treats?|prevents?|symptoms?|signs?)\b',
        r'\b(is|are)\s+(.*?)\s+(dangerous|serious|harmful|safe|normal)\b',
        r'\b(should|can|may|might)\s+(i|you|we)\s+(take|use|do|avoid)\b',
        r'\b(diagnosis|diagnosed|symptoms|treatment|medicine|drug)\b',
        r'\b(medical|health|doctor|physician|hospital|clinic)\b',
        r'\b(pain|hurt|ache|sore|fever|cough|headache)\b',
        r'\b(which\s+medication|best\s+medication|how\s+to\s+cure|without\s+medications)\b',
        r'\b(chronic\s+migraine|migraine\s+treatment|migraine\s+cure)\b',
        r'\b(cure|treat|heal|relief|remedy|solution)\b',
    )
)

# Raw pattern building blocks shared by the medical and non-medical tables.
_SELF_HARM_PATTERNS = (
    r'\b(kill\s+yourself|suicide|end\s+your\s+life|take\s+your\s+life)\b',
    r'\b(self\s*harm|self\s*injury|cut\s+yourself|hurt\s+yourself)\b',
    r'\b(overdose|poison\s+yourself|hang\s+yourself)\b',
)
_DANGEROUS_MEDICAL_PATTERNS = (
    r'\b(overdose\s+on|take\s+too\s+much|excessive\s+dosage)\b',
    r'\b(mix\s+drugs|combine\s+medications|contraindicated)\b',
    r'\b(stop\s+taking\s+prescribed|discontinue\s+medication)\b',
)
_EMERGENCY_AVOIDANCE_PATTERN = (
    r'\b(don\'t\s+call\s+911|avoid\s+emergency\s+room|ignore\s+severe\s+symptoms)\b'
)
_ILLEGAL_DRUG_PATTERNS = (
    r'\b(cocaine|heroin|methamphetamine|ecstasy|lsd|marijuana\s+abuse)\b',
    r'\b(prescription\s+abuse|drug\s+dealing|illegal\s+substances)\b',
)
_VIOLENCE_PATTERNS = (
    r'\b(kill\s+someone|murder|assassinate|violence|harm\s+others)\b',
    r'\b(bomb|explosive|weapon|gun|knife)\b',
    r'\b(attack\s+(someone|people|others|innocent))\b',  # More specific attack pattern
)

# Patterns applied to text classified as medical (the lenient set).
_MEDICAL_DANGER_PATTERNS = _compile_groups({
    'suicide_self_harm': _SELF_HARM_PATTERNS,
    'dangerous_medical_advice': _DANGEROUS_MEDICAL_PATTERNS + (_EMERGENCY_AVOIDANCE_PATTERN,),
    'illegal_drugs': _ILLEGAL_DRUG_PATTERNS,
})

# Patterns applied to non-medical text (the strict set).
_CRITICAL_PATTERNS = _compile_groups({
    'suicide_self_harm': _SELF_HARM_PATTERNS,
    'violence': _VIOLENCE_PATTERNS,
    'illegal_drugs': _ILLEGAL_DRUG_PATTERNS,
    'dangerous_medical': _DANGEROUS_MEDICAL_PATTERNS,
})

# Extra medical-safety patterns checked after the critical set.
_MEDICAL_SAFETY_PATTERNS = _compile_groups({
    'dosage_warnings': (
        r'\b(take\s+more\s+than\s+prescribed|exceed\s+recommended\s+dose)\b',
        r'\b(double\s+up\s+on\s+medication|take\s+extra\s+pills)\b',
    ),
    'emergency_advice': (
        _EMERGENCY_AVOIDANCE_PATTERN,
        r'\b(self\s*treat\s+emergency|handle\s+at\s+home\s+when\s+critical)\b',
    ),
})

# Score contributed by each matching indicator phrase, by severity level.
_RISK_WEIGHTS = {'high': 3.0, 'medium': 1.5, 'low': 0.5}

_MEDICAL_RISK_INDICATORS = {
    'high': (
        'suicide', 'kill yourself', 'end your life', 'self harm',
        'overdose', 'poison yourself', 'illegal drugs', 'violence',
    ),
    'medium': (
        'prescription abuse', 'excessive dosage', 'mix drugs',
        'stop taking prescribed', 'ignore severe symptoms',
    ),
}

_GENERAL_RISK_INDICATORS = {
    'high': (
        'suicide', 'kill yourself', 'end your life', 'self harm',
        'overdose', 'poison', 'illegal drugs', 'violence', 'harm others',
    ),
    'medium': (
        'prescription abuse', 'excessive dosage', 'mix drugs',
        'stop medication', 'ignore symptoms', 'avoid doctor',
    ),
    'low': (
        'pain', 'headache', 'fever', 'cough', 'treatment', 'medicine',
        'doctor', 'hospital', 'symptoms',
    ),
}

# Text wrapped around the first user message when medical content is detected.
_MEDICAL_CONTEXT_HEADER = (
    "MEDICAL CONTEXT: This is a medical consultation. "
    "The user is asking about health-related topics.\n\n"
    "ORIGINAL MESSAGE: "
)
_MEDICAL_CONTEXT_FOOTER = (
    "\n\n"
    "Please evaluate this message for safety, keeping in mind that:\n"
    "- Medical questions and health concerns are legitimate\n"
    "- General medical information sharing is acceptable\n"
    "- Only block content that is clearly harmful, dangerous, or illegal\n"
    "- Medical advice should be educational, not prescriptive"
)


class SafetyGuard:
    """
    Wrapper around NVIDIA Llama Guard (meta/llama-guard-4-12b) hosted at
    https://integrate.api.nvidia.com/v1/chat/completions
    Exposes helpers to validate:
    - user input safety
    - model output safety (in context of the user question)

    Every check returns ``(is_safe, reason)``; ``reason`` is a short tag or
    explanation and may be empty when the content is safe.
    """

    def __init__(self):
        """Read the API key from the environment and set the request defaults.

        Raises:
            ValueError: if the ``NVIDIA_URI`` environment variable is unset or
                empty.
        """
        # The API key is read from NVIDIA_URI (despite the name, it is used as
        # the bearer token below).
        self.api_key = os.getenv("NVIDIA_URI")
        if not self.api_key:
            raise ValueError("NVIDIA_URI environment variable not set for SafetyGuard")
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.model = "meta/llama-guard-4-12b"
        self.timeout_s = 30

    @staticmethod
    def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> List[str]:
        """Chunk long text to keep request payloads small enough for the guard.

        Uses a character-based approximation with a small overlap between
        consecutive chunks.

        Args:
            text: Text to split; falsy input yields ``[""]``.
            chunk_size: Maximum number of characters per chunk.
            overlap: Characters repeated at the start of the next chunk.
                Negative values are treated as 0 so no text is skipped.

        Returns:
            The chunks in order; a single-element list when *text* fits.

        Raises:
            ValueError: if *text* needs splitting and ``chunk_size`` is not
                positive or ``overlap`` is not smaller than ``chunk_size``
                (either would prevent the window from advancing).
        """
        if not text:
            return [""]
        n = len(text)
        if n <= chunk_size:
            return [text]

        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        overlap = max(overlap, 0)
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")

        step = chunk_size - overlap
        chunks: List[str] = []
        for start in range(0, n, step):
            chunks.append(text[start:start + chunk_size])
            if start + chunk_size >= n:
                # This chunk reached the end of the text; a further window
                # would only repeat the overlap.
                break
        return chunks

    @staticmethod
    def _extract_reply_text(data) -> str:
        """Pull ``choices[0].message.content`` out of a decoded response.

        Returns the stripped text, or ``""`` when any level of the expected
        structure is missing or has an unexpected type.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""
        message = first_choice.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        return content.strip() if isinstance(content, str) else ""

    def _call_guard(self, messages: List[Dict], max_tokens: int = 512) -> str:
        """Send *messages* to the guard model and return its raw text reply.

        Two payload shapes are tried in order: plain string message content,
        then content wrapped as ``[{"type": "text", "text": ...}]``.

        Args:
            messages: Chat messages (``role`` / ``content`` dicts).
            max_tokens: Completion token limit sent with the request.

        Returns:
            The stripped reply text, or ``""`` if every attempt failed or
            produced no content.
        """
        # Enhance messages with medical context if detected
        enhanced_messages = self._enhance_messages_with_context(messages)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        common = {
            "model": self.model,
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": max_tokens,
            "stream": False,
        }

        # Alternative schema (some NVIDIA deployments require message content objects)
        alt_messages = []
        for message in enhanced_messages:
            content = message.get("content", "")
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
            alt_messages.append({"role": message.get("role", "user"), "content": content})

        payloads = (
            {**common, "messages": enhanced_messages},  # OpenAI-compatible schema first
            {**common, "messages": alt_messages},
        )

        # Attempt primary, then fallback
        for payload in payloads:
            try:
                resp = requests.post(
                    self.base_url, headers=headers, json=payload, timeout=self.timeout_s
                )
                if resp.status_code >= 400:
                    # Log server message for debugging payload issues
                    logger.error(
                        "[SafetyGuard] HTTP %s: %s", resp.status_code, resp.text[:400]
                    )
                resp.raise_for_status()
                content = self._extract_reply_text(resp.json())
            except Exception as e:
                # Deliberate best-effort boundary: log, then try the next
                # payload shape instead of propagating.
                logger.error("[SafetyGuard] Guard API call failed: %s", e)
                continue
            if content:
                return content

        # All attempts failed
        return ""

    @staticmethod
    def _parse_guard_reply(text: str) -> Tuple[bool, str]:
        """Parse guard reply; expect 'SAFE' or 'UNSAFE: <reason>' (case-insensitive).

        Returns:
            ``(True, "guard_unavailable")`` for an empty reply,
            ``(True, "")`` for a SAFE verdict,
            ``(False, reason)`` for an UNSAFE verdict, where *reason* is the
            text after the first colon or ``"policy violation"``,
            ``(False, reply[:180])`` for anything unrecognised.
        """
        if not text:
            # Fail-open: treat as SAFE if guard unavailable to avoid false blocks
            return True, "guard_unavailable"

        t = text.strip()
        # Require "safe" as a whole leading word, so replies such as
        # "SAFETY concerns: ..." or "SAFEGUARD ..." are not read as a verdict.
        if re.match(r"safe\b", t, re.IGNORECASE):
            return True, ""
        if t.upper().startswith("UNSAFE"):
            # Extract reason after the first colon if present
            _, colon, tail = t.partition(":")
            return False, (tail.strip() if colon else "policy violation")

        # Fallback: treat unknown response as unsafe
        return False, t[:180]

    def _is_medical_query(self, query: str) -> bool:
        """Check if query is clearly medical in nature.

        A query is medical when it contains a vocabulary keyword as a whole
        word (simple inflections allowed) or matches one of the medical
        question patterns. Matching is case-insensitive.

        Whole-word matching is intentional: a missed medical term only costs
        an extra guard call, whereas a false "medical" verdict relaxes every
        downstream check.

        Returns:
            ``False`` for empty or non-string input, otherwise the verdict.
        """
        if not query or not isinstance(query, str):
            return False

        query_lower = query.lower()
        if _MEDICAL_KEYWORD_RE.search(query_lower):
            return True
        return any(pattern.search(query_lower) for pattern in _MEDICAL_QUESTION_PATTERNS)

    def check_user_query(self, user_query: str) -> Tuple[bool, str]:
        """Validate the user query is safe to process with medical context awareness.

        Returns:
            ``(True, "medical_query")`` when the query is classified as
            medical (the guard is not called), ``(False, reason)`` as soon as
            the guard rejects any chunk, otherwise ``(True, "")``.
        """
        text = user_query or ""

        # For medical queries, be more permissive
        # NOTE(review): this bypasses the remote guard entirely; confirm that
        # is still the intended policy.
        if self._is_medical_query(text):
            logger.info("[SafetyGuard] Medical query detected, skipping strict validation")
            return True, "medical_query"

        # If too long, validate each chunk; any UNSAFE makes overall UNSAFE
        for part in self._chunk_text(text):
            reply = self._call_guard([{"role": "user", "content": part}], max_tokens=64)
            ok, reason = self._parse_guard_reply(reply)
            if not ok:
                return False, reason
        return True, ""

    def _detect_harmful_content(self, text: str) -> Tuple[bool, str]:
        """Detect harmful content with local regex patterns (no network call).

        Medical text is checked only against the dangerous-medical set;
        other text is checked against the stricter critical set and then the
        medical-safety set.

        Returns:
            ``(False, reason)`` naming the category and the matching pattern,
            ``(True, "medical_content")`` for clean medical text, or
            ``(True, "")`` for clean or empty text.
        """
        if not text:
            return True, ""

        text_lower = text.lower()

        # First check if this is clearly medical content - be more permissive
        if self._is_medical_query(text):
            hit = _first_match(_MEDICAL_DANGER_PATTERNS, text_lower)
            if hit is not None:
                category, raw_pattern = hit
                return False, f"contains {category}: {raw_pattern}"
            # Allow medical content through
            return True, "medical_content"

        # For non-medical content, use stricter patterns
        hit = _first_match(_CRITICAL_PATTERNS, text_lower)
        if hit is not None:
            category, raw_pattern = hit
            return False, f"contains {category}: {raw_pattern}"

        # Context-aware medical safety checks
        hit = _first_match(_MEDICAL_SAFETY_PATTERNS, text_lower)
        if hit is not None:
            return False, f"contains dangerous medical advice: {hit[1]}"

        return True, ""

    def _enhance_messages_with_context(self, messages: List[Dict]) -> List[Dict]:
        """Wrap the first user message in a medical-context preamble.

        Applies only when at least one message is classified as medical and
        the first message has role ``"user"``; other messages are passed
        through unchanged. The input list is not mutated.
        """
        if not messages:
            return messages

        # Check if any message contains medical content
        has_medical_content = any(
            self._is_medical_query(msg.get("content", "")) for msg in messages
        )
        if not has_medical_content:
            return messages

        enhanced_messages = []
        for index, msg in enumerate(messages):
            if index == 0 and msg.get("role") == "user":
                # Add medical context to first user message
                original_content = msg.get("content", "")
                enhanced_messages.append({
                    "role": msg["role"],
                    "content": (
                        f"{_MEDICAL_CONTEXT_HEADER}{original_content}"
                        f"{_MEDICAL_CONTEXT_FOOTER}"
                    ),
                })
            else:
                enhanced_messages.append(msg)
        return enhanced_messages

    def _assess_risk_level(self, text: str) -> Tuple[str, float]:
        """Assess the risk level of content from indicator phrases.

        Each indicator phrase found in the lower-cased text adds its level
        weight (high 3.0, medium 1.5, low 0.5). Medical text uses a shorter
        indicator list and more lenient normalisation and thresholds.

        Returns:
            ``(level, score)`` with level in ``{"low", "medium", "high"}`` and
            score clamped to ``[0.0, 1.0]``.
        """
        if not text:
            return "low", 0.0

        text_lower = text.lower()

        if self._is_medical_query(text):
            # For medical content, only flag truly dangerous patterns and
            # normalise more leniently.
            indicators = _MEDICAL_RISK_INDICATORS
            divisor, high_cutoff, medium_cutoff = 15.0, 0.6, 0.2
        else:
            indicators = _GENERAL_RISK_INDICATORS
            divisor, high_cutoff, medium_cutoff = 10.0, 0.7, 0.3

        raw_score = sum(
            _RISK_WEIGHTS[level]
            for level, phrases in indicators.items()
            for phrase in phrases
            if phrase in text_lower
        )

        # Normalize score
        risk_score = min(raw_score / divisor, 1.0)
        if risk_score >= high_cutoff:
            return "high", risk_score
        if risk_score >= medium_cutoff:
            return "medium", risk_score
        return "low", risk_score

    def _validate_answer_with_guard(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
        """Run every chunk of *model_answer* through the guard.

        Only the first chunk of the user query is sent as context with each
        answer chunk. Returns ``(False, reason)`` on the first rejection,
        otherwise ``(True, "")``.
        """
        user_parts = self._chunk_text(user_query, chunk_size=2000)
        user_context = user_parts[0] if user_parts else ""
        for answer_part in self._chunk_text(model_answer):
            messages = [
                {"role": "user", "content": user_context},
                {"role": "assistant", "content": answer_part},
            ]
            reply = self._call_guard(messages, max_tokens=96)
            ok, reason = self._parse_guard_reply(reply)
            if not ok:
                return False, reason
        return True, ""

    def check_model_answer(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
        """Validate the model's answer is safe with medical context awareness.

        Order of checks:
        1. local pattern screen on the answer (rejects immediately);
        2. high-risk answers always go through the guard;
        3. medical queries or answers are accepted without the guard;
        4. medium-risk answers go through the guard;
        5. everything else is accepted.

        Returns:
            ``(False, reason)`` when rejected, otherwise ``(True, tag)`` with
            tag one of ``"high_risk_validated"``, ``"medical_content"``,
            ``"medium_risk_validated"`` or ``"low_risk"``.
        """
        uq = user_query or ""
        ans = model_answer or ""

        # Assess risk level first
        risk_level, risk_score = self._assess_risk_level(ans)
        logger.info(
            "[SafetyGuard] Risk assessment: %s (score: %.2f)", risk_level, risk_score
        )

        # Always check for harmful content first
        is_safe, reason = self._detect_harmful_content(ans)
        if not is_safe:
            return False, reason

        # For high-risk content, always use strict validation
        if risk_level == "high":
            logger.warning("[SafetyGuard] High-risk content detected, using strict validation")
            ok, reason = self._validate_answer_with_guard(uq, ans)
            return (True, "high_risk_validated") if ok else (False, reason)

        # For medical queries and answers, use relaxed validation
        if self._is_medical_query(uq) or self._is_medical_query(ans):
            logger.info("[SafetyGuard] Medical content detected, using relaxed validation")
            return True, "medical_content"

        # For medium-risk non-medical content, use guard validation
        if risk_level == "medium":
            logger.info("[SafetyGuard] Medium-risk content detected, using guard validation")
            ok, reason = self._validate_answer_with_guard(uq, ans)
            return (True, "medium_risk_validated") if ok else (False, reason)

        # For low-risk content, allow through
        logger.info("[SafetyGuard] Low-risk content detected, allowing through")
        return True, "low_risk"


# Global instance (optional convenience)
# NOTE: constructed at import time, so importing this module raises ValueError
# when NVIDIA_URI is not set.
safety_guard = SafetyGuard()