LiamKhoaLe committed on
Commit 812cc3b · 1 Parent(s): d555d15

Upd MedSwin: fall back to manual formatting if the chat template is missing or fails

Files changed (1)
  1. app.py +46 -7
app.py CHANGED
@@ -281,6 +281,35 @@ def generate_speech(text: str):
         logger.error(f"TTS error: {e}")
         return None
 
+def format_prompt_manually(messages: list, tokenizer) -> str:
+    """Manually format prompt for models without chat template"""
+    prompt_parts = []
+
+    # Combine system and user messages into a single instruction
+    system_content = ""
+    user_content = ""
+
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+
+        if role == "system":
+            system_content = content
+        elif role == "user":
+            user_content = content
+        elif role == "assistant":
+            # Skip assistant messages in history for now (can be added if needed)
+            pass
+
+    # Format for MedAlpaca/LLaMA-based medical models
+    # Common format: Instruction + Input -> Response
+    if system_content:
+        prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
+    else:
+        prompt = f"Question: {user_content}\n\nAnswer:"
+
+    return prompt
+
 def detect_language(text: str) -> str:
     """Detect language of input text"""
     try:
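For reference, a minimal sketch of what the new helper produces; the example messages are invented, and note that both the tokenizer argument and the prompt_parts list are unused inside the helper:

messages = [
    {"role": "system", "content": "You are a helpful medical assistant."},
    {"role": "user", "content": "What are common symptoms of anemia?"},
]
prompt = format_prompt_manually(messages, medical_tokenizer)
# prompt == "You are a helpful medical assistant.\n\n"
#           "Question: What are common symptoms of anemia?\n\n"
#           "Answer:"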
@@ -943,11 +972,21 @@ def stream_chat(
     max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
     max_new_tokens = max(max_new_tokens, 1024)  # Minimum 1024 tokens for medical answers
 
-    prompt = medical_tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    # Check if tokenizer has chat template, otherwise format manually
+    if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
+        try:
+            prompt = medical_tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        except Exception as e:
+            logger.warning(f"Chat template failed, using manual formatting: {e}")
+            # Fallback to manual formatting
+            prompt = format_prompt_manually(messages, medical_tokenizer)
+    else:
+        # Manual formatting for models without chat template
+        prompt = format_prompt_manually(messages, medical_tokenizer)
 
     inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
     prompt_length = inputs['input_ids'].shape[1]
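To make the guard concrete, a rough sketch using a hypothetical StubTokenizer invented for this example; with no chat template, the manual branch is the one taken:

class StubTokenizer:
    chat_template = None  # simulates a model shipped without a chat template

messages = [{"role": "user", "content": "What causes iron deficiency?"}]
tok = StubTokenizer()
if hasattr(tok, 'chat_template') and tok.chat_template is not None:
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
    # Taken here: manual Question/Answer formatting
    prompt = format_prompt_manually(messages, tok)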
@@ -1207,7 +1246,7 @@ def create_demo():
             minimum=0,
             maximum=1,
             step=0.1,
-            value=0.7,
+            value=0.2,
             label="Temperature"
         )
         max_new_tokens = gr.Slider(
@@ -1222,7 +1261,7 @@ def create_demo():
             minimum=0.0,
             maximum=1.0,
             step=0.1,
-            value=0.95,
+            value=0.7,
             label="Top P"
         )
         top_k = gr.Slider(
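The two slider changes above lower the default sampling temperature from 0.7 to 0.2 and top-p from 0.95 to 0.7, presumably to make medical answers more deterministic. A minimal sketch of how such defaults typically reach generation; the actual generate() call in stream_chat is outside this diff, so these kwargs are an assumption:

outputs = medical_model_obj.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=0.2,  # new default: less randomness
    top_p=0.7,        # new default: tighter nucleus sampling
)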
 