Commit · 812cc3b
1 Parent(s): d555d15
Upd MedSwin: falls back to manual formatting if the template is missing or fails

app.py CHANGED
@@ -281,6 +281,35 @@ def generate_speech(text: str):
         logger.error(f"TTS error: {e}")
         return None
 
+def format_prompt_manually(messages: list, tokenizer) -> str:
+    """Manually format prompt for models without chat template"""
+    prompt_parts = []
+
+    # Combine system and user messages into a single instruction
+    system_content = ""
+    user_content = ""
+
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+
+        if role == "system":
+            system_content = content
+        elif role == "user":
+            user_content = content
+        elif role == "assistant":
+            # Skip assistant messages in history for now (can be added if needed)
+            pass
+
+    # Format for MedAlpaca/LLaMA-based medical models
+    # Common format: Instruction + Input -> Response
+    if system_content:
+        prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
+    else:
+        prompt = f"Question: {user_content}\n\nAnswer:"
+
+    return prompt
+
 def detect_language(text: str) -> str:
     """Detect language of input text"""
     try:
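
For reference, the new helper flattens a chat history into a plain Question/Answer prompt. A quick sanity check (the message contents below are made up for illustration; the tokenizer argument is unused by the function, so None suffices):

    messages = [
        {"role": "system", "content": "You are a concise medical assistant."},
        {"role": "user", "content": "What are common symptoms of iron deficiency?"},
    ]
    print(format_prompt_manually(messages, tokenizer=None))
    # You are a concise medical assistant.
    #
    # Question: What are common symptoms of iron deficiency?
    #
    # Answer:

Note that only the last system and user messages survive: earlier turns are overwritten on each loop iteration, and assistant turns are skipped entirely.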

@@ -943,11 +972,21 @@ def stream_chat(
     max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
     max_new_tokens = max(max_new_tokens, 1024)  # Minimum 1024 tokens for medical answers
 
-
-
-
-
-
+    # Check if tokenizer has chat template, otherwise format manually
+    if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
+        try:
+            prompt = medical_tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        except Exception as e:
+            logger.warning(f"Chat template failed, using manual formatting: {e}")
+            # Fallback to manual formatting
+            prompt = format_prompt_manually(messages, medical_tokenizer)
+    else:
+        # Manual formatting for models without chat template
+        prompt = format_prompt_manually(messages, medical_tokenizer)
 
     inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
     prompt_length = inputs['input_ids'].shape[1]
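
The same branching pattern can be lifted into a reusable helper. A minimal sketch (the name build_prompt is hypothetical, not part of this commit; it uses only the standard tokenizers API):

    def build_prompt(tokenizer, messages):
        # Prefer the tokenizer's built-in chat template when one exists...
        if getattr(tokenizer, "chat_template", None):
            try:
                return tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except Exception:
                pass  # ...and fall through to manual formatting on failure
        return format_prompt_manually(messages, tokenizer)

getattr with a None default collapses the commit's hasattr-and-not-None check into one expression; the behavior is effectively the same.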

@@ -1207,7 +1246,7 @@ def create_demo():
                 minimum=0,
                 maximum=1,
                 step=0.1,
-                value=0.
+                value=0.2,
                 label="Temperature"
             )
             max_new_tokens = gr.Slider(

@@ -1222,7 +1261,7 @@ def create_demo():
                 minimum=0.0,
                 maximum=1.0,
                 step=0.1,
-                value=0.
+                value=0.7,
                 label="Top P"
             )
             top_k = gr.Slider(
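
For context on the new defaults (temperature 0.2, top-p 0.7): these slider values are typically passed straight through as sampling kwargs. A minimal sketch of how stream_chat plausibly consumes them, assuming the standard transformers generate API (the actual call site is outside this diff):

    outputs = medical_model_obj.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,           # temperature/top_p only take effect when sampling
        temperature=temperature,  # UI default now 0.2: mostly deterministic output
        top_p=top_p,              # UI default now 0.7: nucleus sampling over the top 70% of mass
        top_k=top_k,
    )

A low temperature with moderate top-p keeps medical answers focused while still allowing some lexical variety.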