import json
import os
import torch
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
import openai
import re
# Read the OpenAI API key from the environment; it is only needed for the
# Whisper voice-input path
openai.api_key = os.getenv("OPENAI_API_KEY")

def load_or_create_model_and_embeddings(model_name, data_file, output_dir):
    """Load a cached SentenceTransformer and corpus embeddings from output_dir
    if they exist; otherwise download the model, encode the corpus from
    data_file, and cache both for subsequent runs."""
    model_path = os.path.join(output_dir, 'saved_model')
    embeddings_path = os.path.join(output_dir, 'corpus_embeddings.pt')
    if os.path.exists(model_path) and os.path.exists(embeddings_path):
        model = SentenceTransformer(model_path)
        embeddings = torch.load(embeddings_path)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:
        model = SentenceTransformer(model_name)
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        texts = [item['text'] for item in data]
        embeddings = model.encode(texts, convert_to_tensor=True)
        model.save(model_path)
        torch.save(embeddings, embeddings_path)
    return model, embeddings, data
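
# Expected shape of labeled_cti_data.json, inferred from the fields accessed
# below (an assumption about the dataset, which is not included here):
#   [{"text": "...",
#     "entities": [{"start": 0, "end": 7, "entity_group": "HackOrg"}, ...]},
#    ...]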

model_name = 'sentence-transformers/all-MiniLM-L6-v2'
data_file = 'labeled_cti_data.json'
output_dir = '.'
model, embeddings, data = load_or_create_model_and_embeddings(model_name, data_file, output_dir)

# Build a flat L2 FAISS index over the corpus embeddings
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.cpu().numpy().astype('float32'))
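
# Note: IndexFlatL2 returns squared L2 distances, so the `1 - distance / 2`
# conversion in semantic_search matches cosine similarity only for
# unit-normalized embeddings, which encode() does not produce by default.
# A normalized inner-product variant would look like this (a sketch, not the
# app's original behavior):
#   vectors = embeddings.cpu().numpy().astype('float32')
#   faiss.normalize_L2(vectors)           # in-place L2 normalization
#   index = faiss.IndexFlatIP(dimension)  # inner product equals cosine on unit vectors
#   index.add(vectors)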

def get_entity_groups(entities):
    return list(set(entity['entity_group'] for entity in entities))

def get_color_for_entity(entity_group):
    colors = {
        'SamFile': '#EE8434',   # Orange (wheel)
        'Way': '#C95D63',       # Indian red
        'Idus': '#AE8799',      # Mountbatten pink
        'Tool': '#9083AE',      # African violet
        'Features': '#8181B9',  # Tropical indigo
        'HackOrg': '#496DDB',   # Royal blue (web color)
        'Purp': '#BCD8C1',      # Celadon
        'OffAct': '#D6DBB2',    # Vanilla
        'Org': '#E3D985',       # Flax
        'SecTeam': '#E57A44',   # Orange (Crayola)
        'Time': '#E3D985',      # Flax (same hex as Org)
        'Exp': '#5D76CF',       # Glaucous
        'Area': '#757FC1',      # Another shade of blue
    }
    return colors.get(entity_group, '#000000')  # Default to black if entity group not found

def semantic_search(query, top_k=5):
    query_embedding = model.encode([query], convert_to_tensor=True)
    distances, indices = index.search(query_embedding.cpu().numpy().astype('float32'), top_k)
    results = []
    for distance, idx in zip(distances[0], indices[0]):
        similarity_score = 1 - distance / 2  # Convert the distance to a similarity score
        if similarity_score >= 0.45:  # Only keep results with similarity >= 0.45
            results.append({
                'text': data[idx]['text'],
                'entities': data[idx]['entities'],
                'similarity_score': similarity_score,
                'entity_groups': get_entity_groups(data[idx]['entities'])
            })
    return results
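
# Example usage (hypothetical query; results depend on the loaded corpus):
#   hits = semantic_search("APT group targeting energy sector", top_k=3)
#   for hit in hits:
#       print(f"{hit['similarity_score']:.2f}  {hit['text'][:80]}")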

def search_and_format(query):
    results = semantic_search(query)
    if not results:
        return "<div class='search-result'><p>查無相關資訊。</p></div>"
    formatted_results = """
    <style>
    .search-result {
        font-size: 24px;
        line-height: 1.6;
    }
    .search-result h2 {
        font-size: 24px;
        color: #333;
    }
    .search-result h3 {
        font-size: 24px;
        color: #444;
    }
    .search-result p {
        margin-bottom: 24px;
    }
    .result-separator {
        border-top: 2px solid #ccc;
        margin: 20px 0;
    }
    </style>
    <div class="search-result">
    """
    for i, result in enumerate(results, 1):
        if i > 1:
            formatted_results += '<div class="result-separator"></div>'
        formatted_results += f"<p><strong>相似度分數:</strong> {result['similarity_score']:.4f}</p>"
        formatted_results += f"<p><strong>情資:</strong> {format_text_with_entities_markdown(result['text'], result['entities'])}</p>"
        formatted_results += f"<p><strong>命名實體:</strong> {'、'.join(result['entity_groups'])}</p>"
    formatted_results += "</div>"
    return formatted_results

def format_text_with_entities_markdown(text, entities):
    # Sort entities by start position
    entity_spans = sorted(entities, key=lambda x: x['start'])
    # Split the text into tokens (words and whitespace runs) with a regex
    words = re.findall(r'\S+|\s+', text)
    # Collect the entity groups overlapping each token, keyed by position
    # rather than by the word itself so repeated words keep their own tags
    token_entity_groups = []
    current_pos = 0
    for word in words:
        word_start = current_pos
        word_end = current_pos + len(word)
        groups = []
        # Check whether each entity overlaps the current token
        for entity in entity_spans:
            if entity['start'] < word_end and entity['end'] > word_start:
                groups.append(entity['entity_group'])
        token_entity_groups.append(groups)
        current_pos = word_end
    # Render each token
    formatted_text = []
    for word, groups in zip(words, token_entity_groups):
        if groups:
            unique_entity_groups = list(dict.fromkeys(groups))  # drop duplicate entity groups
            entity_tags = []
            for i, group in enumerate(unique_entity_groups):
                entity_tag = f'<sup style="color: {get_color_for_entity(group)}; font-size: 14px;">{group}</sup>'
                if i > 0:  # add a separator before every tag after the first
                    entity_tags.append('<sup style="font-size: 14px;">、</sup>')
                entity_tags.append(entity_tag)
            formatted_word = f'<strong>{word}</strong>{"".join(entity_tags)}'
        else:
            formatted_word = word
        formatted_text.append(formatted_word)
    return ''.join(formatted_text)
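
# Example (hypothetical data): for the text "Lazarus used Brambul" with the
# entity {'start': 0, 'end': 7, 'entity_group': 'HackOrg'}, the token
# "Lazarus" is rendered as <strong>Lazarus</strong> followed by a colored
# HackOrg superscript tag.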

def transcribe_audio(audio):
    try:
        with open(audio, "rb") as audio_file:
            transcript = openai.Audio.transcribe("whisper-1", audio_file)
        return transcript.text
    except Exception as e:
        return f"轉錄時發生錯誤: {str(e)}"

def audio_to_search(audio):
    transcription = transcribe_audio(audio)
    # Formatted results fill the HTML pane; the transcription fills the query box
    search_results = search_and_format(transcription)
    return search_results, transcription

# Example queries
example_queries = [
    "Tell me about recent cyber attacks from Russia",
    "What APT groups are targeting Ukraine?",
    "Explain the Log4j vulnerability",
    "Chinese state-sponsored hacking activities",
    "Toilet?",
    "Latest North Korean hacker",
    "Describe the SolarWinds supply chain attack",
    "What is the Lazarus Group known for?",
    "Common attack vectors used against critical infrastructure",
    "pls rick roll me"
]
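
# The off-topic entries above ("Toilet?", "pls rick roll me") presumably
# exercise the 0.45 similarity cutoff, which should return no results for
# unrelated queries.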

# Custom CSS
custom_css = """
.container {display: flex; flex-direction: row;}
.input-column {flex: 1; padding-right: 20px;}
.output-column {flex: 2;}
.examples-list {display: flex; flex-wrap: wrap; gap: 10px;}
.examples-list > * {flex-basis: calc(50% - 5px);}
footer {display:none !important}
.gradio-container {font-size: 16px;}
"""

# Build the Gradio interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# AskCTI", elem_classes=["text-3xl"])
    gr.Markdown("使用文字或使用語音輸入問題或關鍵字查詢相關威脅情資,右側會顯示前 5 個最相關的情資報告。", elem_classes=["text-xl"])
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            query_input = gr.Textbox(lines=3, label="", elem_classes=["text-lg"])
            with gr.Row():
                submit_btn = gr.Button("查詢", elem_classes=["text-lg"])
            audio_input = gr.Audio(type="filepath", label="語音輸入")
            #audio_input = gr.Audio(sources="microphone", label="錄音", elem_classes="small-button")
            gr.Markdown("### 範例查詢", elem_classes=["text-xl"])
            for i in range(0, len(example_queries), 2):
                with gr.Row():
                    for j in range(2):
                        if i + j < len(example_queries):
                            example = example_queries[i + j]
                            # Bind each example via a default argument so the
                            # button fills the query box with its own text
                            gr.Button(example, elem_classes=["text-lg"]).click(
                                lambda q=example: q, inputs=[], outputs=[query_input]
                            )
        with gr.Column(scale=2):
            output = gr.HTML(elem_classes=["text-lg"])
    submit_btn.click(search_and_format, inputs=[query_input], outputs=[output])
    audio_input.change(audio_to_search, inputs=[audio_input], outputs=[output, query_input])

# Launch the Gradio interface
iface.launch()
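
# To expose a temporary public URL instead (an option, not the original
# behavior): iface.launch(share=True)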