# models/guard.py
import os
import requests
import logging
from typing import Tuple, List, Dict

logger = logging.getLogger(__name__)

class SafetyGuard:
    """
    Wrapper around NVIDIA Llama Guard (meta/llama-guard-4-12b) hosted at
    https://integrate.api.nvidia.com/v1/chat/completions

    Exposes helpers to validate:
      - user input safety
      - model output safety (in the context of the user question)
    """

    def __init__(self):
        self.api_key = os.getenv("NVIDIA_URI")
        if not self.api_key:
            raise ValueError("NVIDIA_URI environment variable not set for SafetyGuard")
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.model = "meta/llama-guard-4-12b"
        self.timeout_s = 30
    @staticmethod
    def _chunk_text(text: str, chunk_size: int = 2800, overlap: int = 200) -> List[str]:
        """Chunk long text to keep request payloads small enough for the guard.

        Uses a character-based approximation with a small overlap between chunks.
        """
        if not text:
            return [""]
        n = len(text)
        if n <= chunk_size:
            return [text]
        chunks: List[str] = []
        start = 0
        while start < n:
            end = min(start + chunk_size, n)
            chunks.append(text[start:end])
            if end == n:
                break
            start = max(0, end - overlap)
        return chunks
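
    # Illustrative note (not in the original file): with the defaults above, a
    # 6000-character input would be split into three overlapping chunks covering
    # roughly text[0:2800], text[2600:5400], and text[5200:6000]; the callers
    # below treat the whole text as UNSAFE if any single chunk is flagged.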
    def _call_guard(self, messages: List[Dict], max_tokens: int = 512) -> str:
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # Try the OpenAI-compatible schema first
        payload_chat = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": max_tokens,
            "stream": False,
        }
        # Alternative schema (some NVIDIA deployments require message content objects)
        alt_messages = []
        for m in messages:
            content = m.get("content", "")
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
            alt_messages.append({"role": m.get("role", "user"), "content": content})
        payload_alt = {
            "model": self.model,
            "messages": alt_messages,
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": max_tokens,
            "stream": False,
        }
        # Attempt the primary payload shape, then the fallback
        for payload in (payload_chat, payload_alt):
            try:
                resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
                if resp.status_code >= 400:
                    # Log the server message for debugging payload issues
                    try:
                        logger.error(f"[SafetyGuard] HTTP {resp.status_code}: {resp.text[:400]}")
                    except Exception:
                        pass
                resp.raise_for_status()
                data = resp.json()
                content = (
                    data.get("choices", [{}])[0]
                    .get("message", {})
                    .get("content", "")
                    .strip()
                )
                if content:
                    return content
            except Exception as e:
                # Try the next payload shape
                logger.error(f"[SafetyGuard] Guard API call failed: {e}")
                continue
        # All attempts failed
        return ""
    @staticmethod
    def _parse_guard_reply(text: str) -> Tuple[bool, str]:
        """Parse guard reply; expect 'SAFE' or 'UNSAFE: <reason>' (case-insensitive)."""
        if not text:
            # Fail open: treat as SAFE if the guard is unavailable, to avoid false blocks
            return True, "guard_unavailable"
        t = text.strip()
        upper = t.upper()
        if upper.startswith("SAFE") and not upper.startswith("SAFEGUARD"):
            return True, ""
        if upper.startswith("UNSAFE"):
            # Extract the reason after the first colon, if present
            parts = t.split(":", 1)
            reason = parts[1].strip() if len(parts) > 1 else "policy violation"
            return False, reason
        # Fallback: treat an unknown response as unsafe
        return False, t[:180]
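
    # Illustrative behavior of the parser above (the reply strings are assumed
    # examples, not guaranteed output shapes from the hosted guard):
    #   _parse_guard_reply("safe")        -> (True, "")
    #   _parse_guard_reply("unsafe: S1")  -> (False, "S1")
    #   _parse_guard_reply("unsafe")      -> (False, "policy violation")
    #   _parse_guard_reply("")            -> (True, "guard_unavailable")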
    def check_user_query(self, user_query: str) -> Tuple[bool, str]:
        """Validate that the user query is safe to process (provider-style single user message)."""
        text = user_query or ""
        # If the query is too long, validate each chunk; any UNSAFE chunk makes the overall result UNSAFE
        for part in self._chunk_text(text):
            messages = [{"role": "user", "content": part}]
            reply = self._call_guard(messages, max_tokens=64)
            ok, reason = self._parse_guard_reply(reply)
            if not ok:
                return False, reason
        return True, ""
    def check_model_answer(self, user_query: str, model_answer: str) -> Tuple[bool, str]:
        """Validate that the model's answer is safe using the provider's example schema (user + assistant turns)."""
        uq = user_query or ""
        ans = model_answer or ""
        # Chunk the assistant answer; if the user query is huge, use its first chunk as context
        user_parts = self._chunk_text(uq, chunk_size=2000)
        user_context = user_parts[0] if user_parts else ""
        for ans_part in self._chunk_text(ans):
            messages = [
                {"role": "user", "content": user_context},
                {"role": "assistant", "content": ans_part},
            ]
            reply = self._call_guard(messages, max_tokens=96)
            ok, reason = self._parse_guard_reply(reply)
            if not ok:
                return False, reason
        return True, ""


# Global instance (optional convenience)
safety_guard = SafetyGuard()
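
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of wiring both checks into a request/response flow. It
# assumes NVIDIA_URI is set to a valid API key for the endpoint above; the
# question and answer strings are placeholders, not data from this repo.
if __name__ == "__main__":
    question = "How do I reset my account password?"
    ok, reason = safety_guard.check_user_query(question)
    if not ok:
        print(f"Blocked user query: {reason}")
    else:
        # Placeholder for whatever the main model would have generated
        answer = "You can reset it from the account settings page."
        ok, reason = safety_guard.check_model_answer(question, answer)
        print("Answer allowed" if ok else f"Blocked answer: {reason}")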