import os
import re
import requests
import logging
from typing import Tuple, List, Dict
logger = logging.getLogger(__name__)
class SafetyGuard:
"""
Wrapper around NVIDIA Llama Guard (meta/llama-guard-4-12b) hosted at
https://integrate.api.nvidia.com/v1/chat/completions
Exposes helpers to validate:
- user input safety
- model output safety (in context of the user question)
"""
def __init__(self):
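        # Note: despite its name, NVIDIA_URI is read as the API key and sent as the
        # Bearer token to the fixed endpoint configured below.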
self.api_key = os.getenv("NVIDIA_URI")
if not self.api_key:
raise ValueError("NVIDIA_URI environment variable not set for SafetyGuard")
self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
self.model = "meta/llama-guard-4-12b"
self.timeout_s = 30
@staticmethod
def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> List[str]:
"""Chunk long text to keep request payloads small enough for the guard.
Uses character-based approximation with small overlap.
"""
if not text:
return [""]
n = len(text)
if n <= chunk_size:
return [text]
chunks: List[str] = []
start = 0
while start < n:
end = min(start + chunk_size, n)
chunks.append(text[start:end])
if end == n:
break
start = max(0, end - overlap)
return chunks
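
    # Illustrative behaviour (derived from the logic above, not executed): with the
    # defaults chunk_size=2800 and overlap=200, _chunk_text("a" * 3000) returns two
    # chunks covering [0:2800] and [2600:3000], so consecutive chunks share 200
    # characters of context.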
def _call_guard(self, messages: List[Dict], max_tokens: int = 512) -> str:
# Enhance messages with medical context if detected
enhanced_messages = self._enhance_messages_with_context(messages)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
# Try OpenAI-compatible schema first
payload_chat = {
"model": self.model,
"messages": enhanced_messages,
"temperature": 0.2,
"top_p": 0.7,
"max_tokens": max_tokens,
"stream": False,
}
# Alternative schema (some NVIDIA deployments require message content objects)
alt_messages = []
for m in enhanced_messages:
content = m.get("content", "")
if isinstance(content, str):
content = [{"type": "text", "text": content}]
alt_messages.append({"role": m.get("role", "user"), "content": content})
payload_alt = {
"model": self.model,
"messages": alt_messages,
"temperature": 0.2,
"top_p": 0.7,
"max_tokens": max_tokens,
"stream": False,
}
# Attempt primary, then fallback
for payload in (payload_chat, payload_alt):
try:
resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
if resp.status_code >= 400:
# Log server message for debugging payload issues
try:
logger.error(f"[SafetyGuard] HTTP {resp.status_code}: {resp.text[:400]}")
except Exception:
pass
resp.raise_for_status()
data = resp.json()
content = (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
)
if content:
return content
except Exception as e:
# Try next payload shape
logger.error(f"[SafetyGuard] Guard API call failed: {e}")
continue
# All attempts failed
return ""
@staticmethod
def _parse_guard_reply(text: str) -> Tuple[bool, str]:
"""Parse guard reply; expect 'SAFE' or 'UNSAFE: <reason>' (case-insensitive)."""
if not text:
# Fail-open: treat as SAFE if guard unavailable to avoid false blocks
return True, "guard_unavailable"
t = text.strip()
upper = t.upper()
if upper.startswith("SAFE") and not upper.startswith("SAFEGUARD"):
return True, ""
if upper.startswith("UNSAFE"):
# Extract reason after the first colon if present
parts = t.split(":", 1)
reason = parts[1].strip() if len(parts) > 1 else "policy violation"
return False, reason
# Fallback: treat unknown response as unsafe
return False, t[:180]
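
    # Examples of the mapping implemented above:
    #   "SAFE"                     -> (True, "")
    #   "unsafe: S6 self-harm"     -> (False, "S6 self-harm")
    #   "unsafe"                   -> (False, "policy violation")
    #   ""   (guard unreachable)   -> (True, "guard_unavailable")
    #   "I cannot help with that." -> (False, "I cannot help with that.")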
def _is_medical_query(self, query: str) -> bool:
"""Check if query is clearly medical in nature using comprehensive patterns."""
if not query:
return False
query_lower = query.lower()
# Medical keyword categories
medical_categories = {
'symptoms': [
'symptom', 'pain', 'ache', 'hurt', 'sore', 'tender', 'stiff', 'numb',
'headache', 'migraine', 'fever', 'cough', 'cold', 'flu', 'sneeze',
'nausea', 'vomit', 'diarrhea', 'constipation', 'bloating', 'gas',
'dizziness', 'vertigo', 'fatigue', 'weakness', 'tired', 'exhausted',
'shortness of breath', 'wheezing', 'chest pain', 'heart palpitations',
'joint pain', 'muscle pain', 'back pain', 'neck pain', 'stomach pain',
'abdominal pain', 'pelvic pain', 'menstrual pain', 'cramps'
],
'conditions': [
'disease', 'condition', 'disorder', 'syndrome', 'illness', 'sickness',
'infection', 'inflammation', 'allergy', 'asthma', 'diabetes', 'hypertension',
'depression', 'anxiety', 'stress', 'panic', 'phobia', 'ocd', 'ptsd',
'adhd', 'autism', 'dementia', 'alzheimer', 'parkinson', 'epilepsy',
'cancer', 'tumor', 'cancerous', 'malignant', 'benign', 'metastasis',
'heart disease', 'stroke', 'heart attack', 'coronary', 'arrhythmia',
'pneumonia', 'bronchitis', 'copd', 'emphysema', 'tuberculosis',
'migraine', 'headache', 'chronic migraine', 'cluster headache',
'tension headache', 'sinus headache', 'cure', 'treat', 'treatment'
],
'treatments': [
'treatment', 'therapy', 'medication', 'medicine', 'drug', 'pill', 'tablet',
'injection', 'vaccine', 'immunization', 'surgery', 'operation', 'procedure',
'chemotherapy', 'radiation', 'physical therapy', 'occupational therapy',
'psychotherapy', 'counseling', 'rehabilitation', 'recovery', 'healing',
'prescription', 'dosage', 'side effects', 'contraindications'
],
'body_parts': [
'head', 'brain', 'eye', 'ear', 'nose', 'mouth', 'throat', 'neck',
'chest', 'heart', 'lung', 'liver', 'kidney', 'stomach', 'intestine',
'back', 'spine', 'joint', 'muscle', 'bone', 'skin', 'hair', 'nail',
'arm', 'leg', 'hand', 'foot', 'finger', 'toe', 'pelvis', 'genital'
],
'medical_context': [
'doctor', 'physician', 'nurse', 'specialist', 'surgeon', 'dentist',
'medical', 'health', 'healthcare', 'hospital', 'clinic', 'emergency',
'ambulance', 'paramedic', 'pharmacy', 'pharmacist', 'lab', 'test',
'diagnosis', 'prognosis', 'examination', 'checkup', 'screening',
'patient', 'case', 'history', 'medical history', 'family history'
],
'life_stages': [
'pregnancy', 'pregnant', 'baby', 'infant', 'newborn', 'child', 'pediatric',
'teenager', 'adolescent', 'adult', 'elderly', 'senior', 'geriatric',
'menopause', 'puberty', 'aging', 'birth', 'delivery', 'miscarriage'
],
'vital_signs': [
'blood pressure', 'heart rate', 'pulse', 'temperature', 'fever',
'respiratory rate', 'oxygen saturation', 'weight', 'height', 'bmi',
'blood sugar', 'glucose', 'cholesterol', 'hemoglobin', 'white blood cell'
]
}
# Check for medical keywords
        for keywords in medical_categories.values():
if any(keyword in query_lower for keyword in keywords):
return True
# Check for medical question patterns
medical_patterns = [
r'\b(what|how|why|when|where)\s+(causes?|treats?|prevents?|symptoms?|signs?)\b',
r'\b(is|are)\s+(.*?)\s+(dangerous|serious|harmful|safe|normal)\b',
r'\b(should|can|may|might)\s+(i|you|we)\s+(take|use|do|avoid)\b',
r'\b(diagnosis|diagnosed|symptoms|treatment|medicine|drug)\b',
r'\b(medical|health|doctor|physician|hospital|clinic)\b',
r'\b(pain|hurt|ache|sore|fever|cough|headache)\b',
r'\b(which\s+medication|best\s+medication|how\s+to\s+cure|without\s+medications)\b',
r'\b(chronic\s+migraine|migraine\s+treatment|migraine\s+cure)\b',
r'\b(cure|treat|heal|relief|remedy|solution)\b'
]
for pattern in medical_patterns:
if re.search(pattern, query_lower):
return True
return False
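
    # Illustrative classifications under the rules above:
    #   "How do I treat a tension headache?" -> True  (symptom/condition keywords)
    #   "What is the capital of France?"     -> False (no medical keyword or pattern)
    # Note the keyword check is substring based, so unrelated words can match
    # (e.g. "hear" contains the body-part keyword "ear").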
def check_user_query(self, user_query: str) -> Tuple[bool, str]:
"""Validate the user query is safe to process with medical context awareness."""
text = user_query or ""
# For medical queries, be more permissive
if self._is_medical_query(text):
logger.info("[SafetyGuard] Medical query detected, skipping strict validation")
return True, "medical_query"
        # Validate each chunk (long queries are split); any UNSAFE chunk makes the whole query unsafe
for part in self._chunk_text(text):
messages = [{"role": "user", "content": part}]
reply = self._call_guard(messages, max_tokens=64)
ok, reason = self._parse_guard_reply(reply)
if not ok:
return False, reason
return True, ""
def _detect_harmful_content(self, text: str) -> Tuple[bool, str]:
"""Detect harmful content using sophisticated pattern matching."""
if not text:
return True, ""
text_lower = text.lower()
# First check if this is clearly medical content - be more permissive
if self._is_medical_query(text):
# For medical content, only check for truly dangerous patterns
dangerous_medical_patterns = {
'suicide_self_harm': [
r'\b(kill\s+yourself|suicide|end\s+your\s+life|take\s+your\s+life)\b',
r'\b(self\s*harm|self\s*injury|cut\s+yourself|hurt\s+yourself)\b',
r'\b(overdose|poison\s+yourself|hang\s+yourself)\b'
],
'dangerous_medical_advice': [
r'\b(overdose\s+on|take\s+too\s+much|excessive\s+dosage)\b',
r'\b(mix\s+drugs|combine\s+medications|contraindicated)\b',
r'\b(stop\s+taking\s+prescribed|discontinue\s+medication)\b',
r'\b(don\'t\s+call\s+911|avoid\s+emergency\s+room|ignore\s+severe\s+symptoms)\b'
],
'illegal_drugs': [
r'\b(cocaine|heroin|methamphetamine|ecstasy|lsd|marijuana\s+abuse)\b',
r'\b(prescription\s+abuse|drug\s+dealing|illegal\s+substances)\b'
]
}
# Check only dangerous medical patterns
for category, patterns in dangerous_medical_patterns.items():
for pattern in patterns:
if re.search(pattern, text_lower):
return False, f"contains {category}: {pattern}"
# Allow medical content through
return True, "medical_content"
# For non-medical content, use stricter patterns
critical_patterns = {
'suicide_self_harm': [
r'\b(kill\s+yourself|suicide|end\s+your\s+life|take\s+your\s+life)\b',
r'\b(self\s*harm|self\s*injury|cut\s+yourself|hurt\s+yourself)\b',
r'\b(overdose|poison\s+yourself|hang\s+yourself)\b'
],
'violence': [
r'\b(kill\s+someone|murder|assassinate|violence|harm\s+others)\b',
r'\b(bomb|explosive|weapon|gun|knife)\b',
r'\b(attack\s+(someone|people|others|innocent))\b' # More specific attack pattern
],
'illegal_drugs': [
r'\b(cocaine|heroin|methamphetamine|ecstasy|lsd|marijuana\s+abuse)\b',
r'\b(prescription\s+abuse|drug\s+dealing|illegal\s+substances)\b'
],
'dangerous_medical': [
r'\b(overdose\s+on|take\s+too\s+much|excessive\s+dosage)\b',
r'\b(mix\s+drugs|combine\s+medications|contraindicated)\b',
r'\b(stop\s+taking\s+prescribed|discontinue\s+medication)\b'
]
}
# Check critical patterns
for category, patterns in critical_patterns.items():
for pattern in patterns:
if re.search(pattern, text_lower):
return False, f"contains {category}: {pattern}"
# Context-aware medical safety checks
medical_safety_patterns = {
'dosage_warnings': [
r'\b(take\s+more\s+than\s+prescribed|exceed\s+recommended\s+dose)\b',
r'\b(double\s+up\s+on\s+medication|take\s+extra\s+pills)\b'
],
'emergency_advice': [
r'\b(don\'t\s+call\s+911|avoid\s+emergency\s+room|ignore\s+severe\s+symptoms)\b',
r'\b(self\s*treat\s+emergency|handle\s+at\s+home\s+when\s+critical)\b'
]
}
# Check medical safety patterns
for category, patterns in medical_safety_patterns.items():
for pattern in patterns:
if re.search(pattern, text_lower):
return False, f"contains dangerous medical advice: {pattern}"
return True, ""
def _enhance_messages_with_context(self, messages: List[Dict]) -> List[Dict]:
"""Enhance messages with medical context awareness for better guard performance."""
if not messages:
return messages
# Check if any message contains medical content
has_medical_content = any(
self._is_medical_query(msg.get("content", "")) for msg in messages
)
if has_medical_content:
# Add medical context to the system prompt
enhanced_messages = []
for i, msg in enumerate(messages):
if i == 0 and msg.get("role") == "user":
# Add medical context to first user message
original_content = msg.get("content", "")
enhanced_content = f"""MEDICAL CONTEXT: This is a medical consultation. The user is asking about health-related topics.
ORIGINAL MESSAGE: {original_content}
Please evaluate this message for safety, keeping in mind that:
- Medical questions and health concerns are legitimate
- General medical information sharing is acceptable
- Only block content that is clearly harmful, dangerous, or illegal
- Medical advice should be educational, not prescriptive"""
enhanced_messages.append({
"role": msg["role"],
"content": enhanced_content
})
else:
enhanced_messages.append(msg)
return enhanced_messages
return messages
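
    # Only the first user turn is rewritten with the medical-context preamble; all
    # other messages (including assistant answers being screened) pass through
    # unchanged, and non-medical conversations are returned exactly as received.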
def _assess_risk_level(self, text: str) -> Tuple[str, float]:
"""Assess the risk level of content using multiple indicators."""
if not text:
return "low", 0.0
text_lower = text.lower()
# If this is medical content, be more lenient
if self._is_medical_query(text):
# For medical content, only flag truly dangerous patterns
dangerous_medical_indicators = {
'high': [
'suicide', 'kill yourself', 'end your life', 'self harm',
'overdose', 'poison yourself', 'illegal drugs', 'violence'
],
'medium': [
'prescription abuse', 'excessive dosage', 'mix drugs',
'stop taking prescribed', 'ignore severe symptoms'
]
}
risk_score = 0.0
for level, indicators in dangerous_medical_indicators.items():
for indicator in indicators:
if indicator in text_lower:
if level == 'high':
risk_score += 3.0
elif level == 'medium':
risk_score += 1.5
# Normalize score for medical content (more lenient)
risk_score = min(risk_score / 15.0, 1.0)
if risk_score >= 0.6:
return "high", risk_score
elif risk_score >= 0.2:
return "medium", risk_score
else:
return "low", risk_score
# For non-medical content, use original risk assessment
risk_indicators = {
'high': [
'suicide', 'kill yourself', 'end your life', 'self harm',
'overdose', 'poison', 'illegal drugs', 'violence', 'harm others'
],
'medium': [
'prescription abuse', 'excessive dosage', 'mix drugs',
'stop medication', 'ignore symptoms', 'avoid doctor'
],
'low': [
'pain', 'headache', 'fever', 'cough', 'treatment',
'medicine', 'doctor', 'hospital', 'symptoms'
]
}
risk_score = 0.0
for level, indicators in risk_indicators.items():
for indicator in indicators:
if indicator in text_lower:
if level == 'high':
risk_score += 3.0
elif level == 'medium':
risk_score += 1.5
else:
risk_score += 0.5
# Normalize score
risk_score = min(risk_score / 10.0, 1.0)
if risk_score >= 0.7:
return "high", risk_score
elif risk_score >= 0.3:
return "medium", risk_score
else:
return "low", risk_score
def check_model_answer(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
"""Validate the model's answer is safe with medical context awareness."""
uq = user_query or ""
ans = model_answer or ""
# Assess risk level first
risk_level, risk_score = self._assess_risk_level(ans)
logger.info(f"[SafetyGuard] Risk assessment: {risk_level} (score: {risk_score:.2f})")
# Always check for harmful content first
is_safe, reason = self._detect_harmful_content(ans)
if not is_safe:
return False, reason
# For high-risk content, always use strict validation
if risk_level == "high":
logger.warning("[SafetyGuard] High-risk content detected, using strict validation")
user_parts = self._chunk_text(uq, chunk_size=2000)
user_context = user_parts[0] if user_parts else ""
for ans_part in self._chunk_text(ans):
messages = [
{"role": "user", "content": user_context},
{"role": "assistant", "content": ans_part},
]
reply = self._call_guard(messages, max_tokens=96)
ok, reason = self._parse_guard_reply(reply)
if not ok:
return False, reason
return True, "high_risk_validated"
# For medical queries and answers, use relaxed validation
if self._is_medical_query(uq) or self._is_medical_query(ans):
logger.info("[SafetyGuard] Medical content detected, using relaxed validation")
return True, "medical_content"
# For medium-risk non-medical content, use guard validation
if risk_level == "medium":
logger.info("[SafetyGuard] Medium-risk content detected, using guard validation")
user_parts = self._chunk_text(uq, chunk_size=2000)
user_context = user_parts[0] if user_parts else ""
for ans_part in self._chunk_text(ans):
messages = [
{"role": "user", "content": user_context},
{"role": "assistant", "content": ans_part},
]
reply = self._call_guard(messages, max_tokens=96)
ok, reason = self._parse_guard_reply(reply)
if not ok:
return False, reason
return True, "medium_risk_validated"
# For low-risk content, allow through
logger.info("[SafetyGuard] Low-risk content detected, allowing through")
return True, "low_risk"
# Global instance (optional convenience)
safety_guard = SafetyGuard()
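
# Minimal manual smoke test: a sketch, not part of the production flow. It assumes
# NVIDIA_URI is exported (it is required at import time above) and may make live
# calls to the guard endpoint, depending on how the inputs are classified.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    ok, why = safety_guard.check_user_query("What helps with a tension headache?")
    print(f"user query   -> safe={ok} reason={why!r}")

    ok, why = safety_guard.check_model_answer(
        "What helps with a tension headache?",
        "Rest, hydration, and over-the-counter pain relief can help; "
        "see a doctor if the headache persists or worsens.",
    )
    print(f"model answer -> safe={ok} reason={why!r}")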