| """ |
| Example usage of Online mode with warmup |
| |
| This demonstrates: |
| 1. Warmup phase (generate N sequences to calibrate threshold) |
| 2. Threshold computation (DeepConf-low or DeepConf-high) |
| 3. Final generation with calibrated early stopping |
| """ |
|
|
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig |
|
|
|
|
def extract_answer(text: str) -> Optional[str]:
    """
    Pull the last boxed answer out of generated text.

    Only the text following the final occurrence of "boxed" (as in
    \\boxed{answer}) is inspected:

    - if it opens with "{", everything up to the matching closing brace is
      returned, keeping nested braces; when the brace is never closed, the
      whole remainder is returned
    - otherwise the text up to the first "$" is returned

    Returns:
        The stripped answer, "" when nothing follows "boxed", or None when
        "boxed" does not occur in the text at all.
    """
    if "boxed" not in text:
        return None

    tail = text.rpartition("boxed")[2]
    if not tail:
        return ""

    if tail[0] != "{":
        # No braces: the answer runs until the end of the math span.
        return tail.split("$")[0].strip()

    # Walk the braces, tracking nesting depth; depth 0 closes the box.
    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        collected.append(ch)
    return "".join(collected).strip()
|
|
|
|
def compute_least_grouped(confs: list, group_size: int) -> list:
    """
    Average per-token confidences over a sliding window.

    Args:
        confs: List of per-token confidence values
        group_size: Number of consecutive tokens in each window

    Returns:
        One mean per window position, each rounded to 3 decimals. When there
        are fewer values than `group_size`, a single-element list is returned
        instead: the unrounded mean of all values, or [0] for empty input.
    """
    n = len(confs)
    if n < group_size:
        # Not enough tokens for even one full window.
        if not confs:
            return [0]
        return [sum(confs) / n]

    windows = (confs[start : start + group_size] for start in range(n - group_size + 1))
    return [round(sum(window) / len(window), 3) for window in windows]
|
|
|
|
def process_single_output(
    sequence, confidences, tokenizer, window_size: int, threshold: Optional[float] = None
) -> dict:
    """
    Turn one generated sequence into a trace record.

    Args:
        sequence: Generated token IDs
        confidences: Per-token confidence values (list or tensor)
        tokenizer: Tokenizer used to decode `sequence`
        window_size: Sliding window size for grouped confidence
        threshold: When given, the first window whose mean falls below it
            marks where early stopping would have triggered

    Returns:
        Dictionary with the decoded text, raw and windowed confidences, the
        minimum windowed confidence, early-stop flag and position, the
        extracted answer, the token count and the token IDs
    """
    # Objects exposing tolist() (e.g. tensors) are converted through it;
    # anything else is consumed as a plain iterable.
    confs = confidences.tolist() if hasattr(confidences, "tolist") else list(confidences)

    text = tokenizer.decode(sequence, skip_special_tokens=True)

    group_confs = compute_least_grouped(confs, window_size)
    lowest = min(group_confs) if group_confs else 0

    # Index of the first window below the threshold, if a threshold was given.
    stop_position = None
    if threshold is not None:
        first_low = next((idx for idx, mean in enumerate(group_confs) if mean < threshold), None)
        if first_low is not None:
            stop_position = first_low + window_size

    answer = extract_answer(text)
    token_ids = sequence.tolist() if hasattr(sequence, "tolist") else list(sequence)

    return {
        "text": text,
        "confs": confs,
        "group_confs": group_confs,
        "min_conf": lowest,
        "stopped_early": stop_position is not None,
        "stop_position": stop_position,
        "extracted_answer": answer,
        "num_tokens": len(confs),
        "token_ids": token_ids,
    }
|
|
|
|
def process_batch_results(outputs, tokenizer, window_size: int = 2048, threshold: Optional[float] = None) -> dict:
    """
    Post-process every sequence of a batched generation result.

    Each row of `outputs.sequences` is paired with the matching entry of
    `outputs.confidences` and converted into a trace, so confidence patterns
    and would-be early stops can be analysed after generation has finished.

    Args:
        outputs: Generation result exposing `sequences` and `confidences`
        tokenizer: Tokenizer for decoding sequences
        window_size: Sliding window size for confidence computation
        threshold: Optional threshold for locating where early stopping would occur

    Returns:
        Dictionary containing:
            - traces: List of processed trace dictionaries
            - min_confs: Minimum windowed confidence of each trace
            - total_tokens: Sum of `num_tokens` over all traces
            - num_traces: Number of traces processed

    Raises:
        ValueError: If `outputs` has no `sequences` attribute, or its
            `confidences` attribute is missing or None
    """
    if not hasattr(outputs, "sequences"):
        raise ValueError("outputs must have 'sequences' attribute")
    if getattr(outputs, "confidences", None) is None:
        raise ValueError("outputs must have 'confidences' attribute. Set output_confidences=True in generation_config")

    sequences, confidences = outputs.sequences, outputs.confidences

    traces = [
        process_single_output(sequences[row], confidences[row], tokenizer, window_size, threshold)
        for row in range(sequences.shape[0])
    ]

    return {
        "traces": traces,
        "min_confs": [trace["min_conf"] for trace in traces],
        "total_tokens": sum(trace["num_tokens"] for trace in traces),
        "num_traces": len(traces),
    }
|
|
|
|
def compute_warmup_threshold(min_confs: list, variant: str = "low", eta: Optional[float] = None) -> float:
    """
    Derive the early-stopping threshold from warmup confidences.

    The threshold is the (100 - eta * 100)-th percentile of `min_confs`,
    with the percentile rank clamped to the range [0, 100].

    Args:
        min_confs: Minimum confidences collected from the warmup sequences
        variant: "low" (eta 0.1) or "high" (eta 0.9); any other value uses
            eta 0.5. Ignored when `eta` is given.
        eta: Optional manual eta value (overrides the variant default)

    Returns:
        Computed threshold value
    """
    if eta is None:
        if variant == "low":
            eta = 0.1
        elif variant == "high":
            eta = 0.9
        else:
            eta = 0.5

    values = np.asarray(min_confs, dtype=np.float32)

    percentile_rank = 100.0 - (eta * 100.0)
    percentile_rank = max(0.0, min(100.0, percentile_rank))

    return float(np.percentile(values, percentile_rank))
|
|
|
|
| |
| |
| |
|
|
|
|
def prepare_prompt(question: str, tokenizer):
    """Render the question as a single user turn through the tokenizer's chat template.

    The template is applied with `tokenize=False` and
    `add_generation_prompt=True`, and its result is returned unchanged.
    """
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
|
|
|
|
def _print_banner(title: str) -> None:
    """Print a section title framed by 80-character rules, preceded by a blank line."""
    rule = "=" * 80
    print("\n" + rule)
    print(title)
    print(rule)


def run_online_mode_example(
    question: str,
    ground_truth: Optional[str] = None,
    warmup_traces: int = 8,
    confidence_variant: str = "low",
    window_size: int = 10,
    max_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.95,
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
):
    """
    Run DeepConf in online mode

    Three phases are executed and reported on stdout:
    1. Warmup: sample `warmup_traces` sequences without early stopping
    2. Threshold computation from the warmup minimum confidences
    3. Final generation with early stopping at the calibrated threshold

    Args:
        question: Question to answer
        ground_truth: Optional ground truth answer for evaluation
        warmup_traces: Number of warmup sequences (default: 8)
        confidence_variant: "low" (aggressive) or "high" (permissive)
        window_size: Sliding window size for confidence
        max_tokens: Max tokens per generation
        temperature: Sampling temperature
        top_p: Top-p sampling
        model_name: Model identifier passed to `from_pretrained`; it is loaded
            with `local_files_only=True`

    Raises:
        ValueError: If `warmup_traces` is smaller than 1
    """
    # The summary divides by warmup_traces, and calibration needs at least
    # one sample, so reject bad values before doing any expensive work.
    if warmup_traces < 1:
        raise ValueError("warmup_traces must be at least 1")

    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        local_files_only=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    prompt = prepare_prompt(question, tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    _print_banner("DEEPCONF ONLINE MODE - FOLLOWING OFFICIAL PATTERN")
    print(f"\nQuestion: {question}")
    if ground_truth:
        print(f"Ground truth: {ground_truth}")
    print("\nConfiguration:")
    print(f" - Warmup traces: {warmup_traces}")
    print(f" - Variant: DeepConf-{confidence_variant}")
    print(f" - Window size: {window_size}")
    print(f" - Max tokens: {max_tokens}")
    print(f" - Temperature: {temperature}")
    print(f" - Top-p: {top_p}")

    # ------------------------------------------------------------------
    # Phase 1: warmup
    # ------------------------------------------------------------------
    _print_banner(f"PHASE 1: WARMUP (Generating {warmup_traces} sequences for calibration)")

    warmup_config = GenerationConfig(
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=False,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Replicate the single prompt so all warmup traces are sampled in one batch.
    expanded_ids = inputs.input_ids.repeat(warmup_traces, 1)
    if "attention_mask" in inputs and inputs.attention_mask is not None:
        expanded_mask = inputs.attention_mask.repeat(warmup_traces, 1)
    else:
        expanded_mask = None

    print(f"Generating {warmup_traces} warmup sequences...")
    warmup_outputs = model.generate(
        input_ids=expanded_ids,
        attention_mask=expanded_mask,
        generation_config=warmup_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=window_size)

    print("\nWarmup complete!")
    print(f" - Total tokens: {warmup_results['total_tokens']}")
    print(f" - Min confidences: {[round(c, 3) for c in warmup_results['min_confs']]}")

    print("\nWarmup Traces:")
    print("-" * 80)
    for i, trace in enumerate(warmup_results["traces"]):
        # trace["text"] is decoded with special tokens stripped, while `prompt`
        # is the raw chat-template string, so their lengths need not match and
        # slicing the text by len(prompt) can cut the wrong number of
        # characters. Decode only the generated ids instead.
        # NOTE(review): assumes generate() returns prompt + completion ids,
        # the same assumption the final token count below relies on.
        text = tokenizer.decode(trace["token_ids"][prompt_len:], skip_special_tokens=True).strip()
        answer = extract_answer(text)
        print(f"\nTrace {i + 1}:")
        print(f" Tokens: {trace['num_tokens']}, Min conf: {trace['min_conf']:.3f}")
        print(f" Text: {text[:80]}..." if len(text) > 80 else f" Text: {text}")
        if answer:
            print(f" Answer: {answer}")
            if ground_truth:
                correct = answer.strip() == ground_truth.strip()
                print(f" Correct: {'✓' if correct else '✗'}")

    # ------------------------------------------------------------------
    # Phase 2: threshold computation
    # ------------------------------------------------------------------
    _print_banner("PHASE 2: THRESHOLD COMPUTATION")

    # Resolve eta once and pass it explicitly, so the value reported below is
    # exactly the one the threshold was computed with (0.5 for unknown variants).
    eta = {"low": 0.1, "high": 0.9}.get(confidence_variant, 0.5)
    threshold = compute_warmup_threshold(warmup_results["min_confs"], variant=confidence_variant, eta=eta)
    percentile = (1.0 - eta) * 100

    print("\nComputed threshold from warmup:")
    print(f" - Variant: DeepConf-{confidence_variant} (eta={eta})")
    print(f" - Percentile: {percentile:.0f}th")
    print(f" - Threshold: {threshold:.3f}")
    print("\nInterpretation:")
    if confidence_variant == "low":
        print(" DeepConf-low is AGGRESSIVE - stops early to save tokens")
    elif confidence_variant == "high":
        print(" DeepConf-high is PERMISSIVE - allows longer generation")
    else:
        print(f" Unrecognized variant {confidence_variant!r} - falling back to eta={eta}")

    # ------------------------------------------------------------------
    # Phase 3: final generation
    # ------------------------------------------------------------------
    _print_banner("PHASE 3: FINAL GENERATION (With calibrated early stopping)")

    final_config = GenerationConfig(
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=True,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    print(f"Generating with DeepConf-{confidence_variant} (threshold={threshold:.3f})...")
    final_output = model.generate(
        **inputs,
        generation_config=final_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    final_ids = final_output.sequences[0]
    final_text = tokenizer.decode(final_ids, skip_special_tokens=True)
    final_tokens = final_output.sequences.shape[1] - prompt_len
    # Extract the answer from the completion only; otherwise a "\boxed{...}"
    # that appears in the question itself could be reported as the answer.
    final_completion = tokenizer.decode(final_ids[prompt_len:], skip_special_tokens=True)
    final_answer = extract_answer(final_completion)

    if hasattr(final_output, "confidences") and final_output.confidences is not None:
        min_conf = final_output.confidences.min().item()
        mean_conf = final_output.confidences.mean().item()
    else:
        min_conf = None
        mean_conf = None

    print("\nFinal generation complete!")
    print(f" - Tokens generated: {final_tokens}")
    if min_conf is not None:
        print(f" - Min confidence: {min_conf:.3f}")
        print(f" - Mean confidence: {mean_conf:.3f}")

    print("\nGenerated text:")
    print("-" * 80)
    print(final_text)
    print("-" * 80)

    if final_answer:
        print(f"\nExtracted answer: {final_answer}")
        if ground_truth:
            correct = final_answer.strip() == ground_truth.strip()
            print(f"Correct: {'✓' if correct else '✗'}")

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    _print_banner("SUMMARY")

    total_warmup_tokens = warmup_results["total_tokens"]
    total_tokens = total_warmup_tokens + final_tokens

    print(f"Total tokens: {total_tokens}")
    print(f" - Warmup: {total_warmup_tokens} ({warmup_traces} sequences)")
    print(f" - Final: {final_tokens}")

    # Compare the early-stopped run against the average unconstrained warmup run.
    avg_warmup_tokens = total_warmup_tokens / warmup_traces
    potential_savings = avg_warmup_tokens - final_tokens
    if potential_savings > 0:
        print("\nToken savings from early stopping:")
        print(f" - Average warmup length: {avg_warmup_tokens:.1f} tokens")
        print(f" - Final length: {final_tokens} tokens")
        print(f" - Saved: {potential_savings:.1f} tokens ({potential_savings / avg_warmup_tokens * 100:.1f}%)")

    _print_banner("Example complete!")
|
|
|
|
if __name__ == "__main__":
    # U+2588 (full block) written as an escape: the literal character had been
    # corrupted into mojibake, and an ASCII-only source cannot be re-mangled.
    block_rule = "\u2588" * 80

    # Each entry is (banner title, keyword arguments for run_online_mode_example).
    examples = [
        (
            "EXAMPLE 1: Simple Math Problem",
            {
                "question": "What is 15 * 8? Show your work step by step.",
                "ground_truth": "120",
                "warmup_traces": 4,
                "confidence_variant": "low",
                "window_size": 5,
                "max_tokens": 64,
            },
        ),
        (
            "EXAMPLE 2: Square Root Problem",
            {
                "question": "What is the square root of 144? Express your answer in the form \\boxed{answer}.",
                "ground_truth": "12",
                "warmup_traces": 4,
                "confidence_variant": "high",
                "window_size": 5,
                "max_tokens": 64,
            },
        ),
        (
            "EXAMPLE 3: Word Problem",
            {
                "question": "If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
                "ground_truth": "150",
                "warmup_traces": 4,
                "confidence_variant": "low",
                "window_size": 5,
                "max_tokens": 96,
            },
        ),
    ]

    for title, example_kwargs in examples:
        print("\n\n" + block_rule)
        print(title)
        print(block_rule)
        run_online_mode_example(**example_kwargs)
|
|