llm-projec-2 / app.py
lopezkor000's picture
init
c394ae1
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
app = Flask(__name__)
# Load model and tokenizer
MODEL_NAME = "viccon23/STU-Injection-aegis"
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
print("Model loaded successfully!")
def classify_prompt(text):
"""Classify if a prompt is safe or unsafe"""
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1).item()
confidence = predictions[0][predicted_class].item()
# Assuming class 0 is safe and class 1 is unsafe
label = "Safe" if predicted_class == 0 else "Unsafe"
return {
"label": label,
"confidence": round(confidence * 100, 2),
"predicted_class": predicted_class
}
@app.route('/')
def index():
return render_template('index.html')
@app.route('/classify', methods=['POST'])
def classify():
data = request.get_json()
prompt = data.get('prompt', '')
if not prompt:
return jsonify({"error": "No prompt provided"}), 400
result = classify_prompt(prompt)
return jsonify(result)
if __name__ == '__main__':
import os
port = int(os.environ.get('PORT', 7860))
app.run(debug=False, host='0.0.0.0', port=port)