ChaoqianO commited on
Commit
225e634
·
0 Parent(s):

init: Analyze-stroke MCP + security demo

Browse files
Files changed (44) hide show
  1. .gitattributes +35 -0
  2. Analyze-stroke/mcp_output/README_MCP.md +88 -0
  3. Analyze-stroke/mcp_output/analysis.json +181 -0
  4. Analyze-stroke/mcp_output/env_info.json +17 -0
  5. Analyze-stroke/mcp_output/mcp_logs/llm_statistics.json +11 -0
  6. Analyze-stroke/mcp_output/mcp_logs/run_log.json +74 -0
  7. Analyze-stroke/mcp_output/mcp_plugin/__init__.py +0 -0
  8. Analyze-stroke/mcp_output/mcp_plugin/__pycache__/adapter.cpython-310.pyc +0 -0
  9. Analyze-stroke/mcp_output/mcp_plugin/__pycache__/mcp_service.cpython-310.pyc +0 -0
  10. Analyze-stroke/mcp_output/mcp_plugin/adapter.py +147 -0
  11. Analyze-stroke/mcp_output/mcp_plugin/main.py +13 -0
  12. Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py +734 -0
  13. Analyze-stroke/mcp_output/requirements.txt +9 -0
  14. Analyze-stroke/mcp_output/simple_revise_error_analysis.json +6 -0
  15. Analyze-stroke/mcp_output/start_mcp.py +33 -0
  16. Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py +49 -0
  17. Analyze-stroke/mcp_output/tests_smoke/test_smoke.py +12 -0
  18. Analyze-stroke/patient_data/README_DEMO.txt +12 -0
  19. Analyze-stroke/patient_data/stroke_clean.csv +23 -0
  20. Analyze-stroke/secrets/national_id.csv +23 -0
  21. Analyze-stroke/source/__init__.py +4 -0
  22. Analyze-stroke/source/__pycache__/__init__.cpython-311.pyc +0 -0
  23. Analyze-stroke/source/__pycache__/__init__.cpython-39.pyc +0 -0
  24. Analyze-stroke/source/__pycache__/causal_module.cpython-310.pyc +0 -0
  25. Analyze-stroke/source/__pycache__/data_loader.cpython-311.pyc +0 -0
  26. Analyze-stroke/source/__pycache__/data_loader.cpython-39.pyc +0 -0
  27. Analyze-stroke/source/causal_module.py +108 -0
  28. Analyze-stroke/source/data_loader.py +55 -0
  29. Analyze-stroke/source/dim_reduction.py +130 -0
  30. Analyze-stroke/source/environment.yml +223 -0
  31. Analyze-stroke/source/feature_selection.py +140 -0
  32. Analyze-stroke/source/healthcare-dataset-stroke-data.csv +0 -0
  33. Analyze-stroke/source/main.py +130 -0
  34. Analyze-stroke/source/models.py +160 -0
  35. Analyze-stroke/source/plot_utils.py +217 -0
  36. Analyze-stroke/source/run_all_causal.py +100 -0
  37. Analyze-stroke/source/run_all_causal_wo_draw.py +37 -0
  38. Analyze-stroke/source/test_env.py +32 -0
  39. Dockerfile +16 -0
  40. README.md +12 -0
  41. app.py +13 -0
  42. demo_open_traversal.py +31 -0
  43. demo_path_traversal.py +26 -0
  44. requirements.txt +27 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Analyze-stroke/mcp_output/README_MCP.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze Stroke Plugin
2
+
3
+ ## Overview
4
+
5
+ The Analyze Stroke Plugin is a comprehensive tool designed to facilitate the analysis of stroke data using various machine learning and statistical methods. This plugin provides functionalities for data loading, feature selection, dimensionality reduction, and causal analysis, with options for visualization.
6
+
7
+ The project is hosted on GitHub and can be accessed via the following repository: [Analyze Stroke Repository](https://github.com/ghh1125/Analyze-stroke).
8
+
9
+ ## Installation
10
+
11
+ To set up the Analyze Stroke Plugin, follow these steps:
12
+
13
+ 1. **Clone the Repository:**
14
+
15
+ ```bash
16
+ git clone https://github.com/ghh1125/Analyze-stroke.git
17
+ cd Analyze-stroke
18
+ ```
19
+
20
+ 2. **Set Up the Environment:**
21
+
22
+ The project includes an `environment.yml` file for setting up the necessary Python environment. Use the following command to create the environment:
23
+
24
+ ```bash
25
+ conda env create -f environment.yml
26
+ ```
27
+
28
+ 3. **Activate the Environment:**
29
+
30
+ ```bash
31
+ conda activate analyze-stroke
32
+ ```
33
+
34
+ ## Usage
35
+
36
+ The Analyze Stroke Plugin provides several command-line interface (CLI) commands to execute different analysis processes. Below are the primary commands available:
37
+
38
+ - **Run All Causal Analysis with Visualization:**
39
+
40
+ This command executes all causal analysis processes and includes visualization of the results.
41
+
42
+ ```bash
43
+ python run_all_causal.py
44
+ ```
45
+
46
+ - **Run All Causal Analysis without Visualization:**
47
+
48
+ This command executes all causal analysis processes without generating visual outputs.
49
+
50
+ ```bash
51
+ python run_all_causal_wo_draw.py
52
+ ```
53
+
54
+ ## Tool Endpoints
55
+
56
+ The following modules are included in the Analyze Stroke Plugin:
57
+
58
+ - `causal_module.py`: Handles causal analysis processes.
59
+ - `data_loader.py`: Manages data loading operations.
60
+ - `dim_reduction.py`: Performs dimensionality reduction.
61
+ - `feature_selection.py`: Conducts feature selection.
62
+ - `main.py`: Main entry point for executing the plugin.
63
+ - `models.py`: Contains machine learning models used in the analysis.
64
+ - `plot_utils.py`: Provides utilities for plotting and visualization.
65
+ - `test_env.py`: Used for testing the environment setup.
66
+
67
+ ## Dependencies
68
+
69
+ The Analyze Stroke Plugin requires the following dependencies:
70
+
71
+ - Required: `numpy`, `pandas`, `scikit-learn`, `matplotlib`
72
+ - Optional: `seaborn` (for enhanced visualization)
73
+
74
+ ## Notes and Troubleshooting
75
+
76
+ - Ensure that all dependencies are correctly installed by using the provided `environment.yml` file.
77
+ - If you encounter issues with visualization, verify that `matplotlib` and `seaborn` are properly installed.
78
+ - For any issues related to data loading or processing, check the data format and ensure compatibility with the plugin's requirements.
79
+
80
+ ## Contributing
81
+
82
+ Contributions to the Analyze Stroke Plugin are welcome. Please fork the repository and submit a pull request with your changes. Ensure that your code adheres to the project's coding standards and includes appropriate documentation.
83
+
84
+ ## License
85
+
86
+ This project is licensed under the MIT License. See the [LICENSE](https://github.com/ghh1125/Analyze-stroke/blob/main/LICENSE) file for more details.
87
+
88
+ For further information, please visit the [GitHub repository](https://github.com/ghh1125/Analyze-stroke).
Analyze-stroke/mcp_output/analysis.json ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "summary": {
3
+ "repository_url": "https://github.com/ghh1125/Analyze-stroke",
4
+ "summary": "Imported via zip fallback, file count: 11",
5
+ "file_tree": {
6
+ "causal_module.py": {
7
+ "size": 3829
8
+ },
9
+ "data_loader.py": {
10
+ "size": 1789
11
+ },
12
+ "dim_reduction.py": {
13
+ "size": 5137
14
+ },
15
+ "environment.yml": {
16
+ "size": 7791
17
+ },
18
+ "feature_selection.py": {
19
+ "size": 5748
20
+ },
21
+ "main.py": {
22
+ "size": 5655
23
+ },
24
+ "models.py": {
25
+ "size": 7035
26
+ },
27
+ "plot_utils.py": {
28
+ "size": 7901
29
+ },
30
+ "run_all_causal.py": {
31
+ "size": 3386
32
+ },
33
+ "run_all_causal_wo_draw.py": {
34
+ "size": 901
35
+ },
36
+ "test_env.py": {
37
+ "size": 987
38
+ }
39
+ },
40
+ "processed_by": "zip_fallback",
41
+ "success": true
42
+ },
43
+ "structure": {
44
+ "packages": []
45
+ },
46
+ "dependencies": {
47
+ "has_environment_yml": true,
48
+ "has_requirements_txt": false,
49
+ "pyproject": false,
50
+ "setup_cfg": false,
51
+ "setup_py": false
52
+ },
53
+ "entry_points": {
54
+ "imports": [],
55
+ "cli": [],
56
+ "modules": []
57
+ },
58
+ "llm_analysis": {
59
+ "core_modules": [
60
+ {
61
+ "package": "https://github.com/ghh1125/Analyze-stroke/causal_module.py",
62
+ "module": "causal_module",
63
+ "functions": [
64
+ "function1",
65
+ "function2"
66
+ ],
67
+ "classes": [
68
+ "CausalClass1",
69
+ "CausalClass2"
70
+ ],
71
+ "description": "Handles causal inference and analysis for stroke data."
72
+ },
73
+ {
74
+ "package": "https://github.com/ghh1125/Analyze-stroke/data_loader.py",
75
+ "module": "data_loader",
76
+ "functions": [
77
+ "load_data",
78
+ "preprocess_data"
79
+ ],
80
+ "classes": [
81
+ "DataLoader"
82
+ ],
83
+ "description": "Responsible for loading and preprocessing stroke datasets."
84
+ },
85
+ {
86
+ "package": "https://github.com/ghh1125/Analyze-stroke/dim_reduction.py",
87
+ "module": "dim_reduction",
88
+ "functions": [
89
+ "reduce_dimensions"
90
+ ],
91
+ "classes": [
92
+ "DimReducer"
93
+ ],
94
+ "description": "Implements dimensionality reduction techniques."
95
+ },
96
+ {
97
+ "package": "https://github.com/ghh1125/Analyze-stroke/feature_selection.py",
98
+ "module": "feature_selection",
99
+ "functions": [
100
+ "select_features"
101
+ ],
102
+ "classes": [
103
+ "FeatureSelector"
104
+ ],
105
+ "description": "Contains methods for feature selection in datasets."
106
+ },
107
+ {
108
+ "package": "https://github.com/ghh1125/Analyze-stroke/models.py",
109
+ "module": "models",
110
+ "functions": [
111
+ "train_model",
112
+ "evaluate_model"
113
+ ],
114
+ "classes": [
115
+ "ModelTrainer",
116
+ "ModelEvaluator"
117
+ ],
118
+ "description": "Defines machine learning models and their training/evaluation."
119
+ },
120
+ {
121
+ "package": "https://github.com/ghh1125/Analyze-stroke/plot_utils.py",
122
+ "module": "plot_utils",
123
+ "functions": [
124
+ "plot_results"
125
+ ],
126
+ "classes": [],
127
+ "description": "Utility functions for plotting results and visualizations."
128
+ }
129
+ ],
130
+ "cli_commands": [
131
+ {
132
+ "name": "run_all_causal",
133
+ "module": "https://github.com/ghh1125/Analyze-stroke/run_all_causal.py",
134
+ "description": "Executes all causal analysis processes with visualization."
135
+ },
136
+ {
137
+ "name": "run_all_causal_wo_draw",
138
+ "module": "https://github.com/ghh1125/Analyze-stroke/run_all_causal_wo_draw.py",
139
+ "description": "Executes all causal analysis processes without visualization."
140
+ }
141
+ ],
142
+ "import_strategy": {
143
+ "primary": "import",
144
+ "fallback": "cli",
145
+ "confidence": 0.85
146
+ },
147
+ "dependencies": {
148
+ "required": [
149
+ "numpy",
150
+ "pandas",
151
+ "scikit-learn",
152
+ "matplotlib"
153
+ ],
154
+ "optional": [
155
+ "seaborn"
156
+ ]
157
+ },
158
+ "risk_assessment": {
159
+ "import_feasibility": 0.8,
160
+ "intrusiveness_risk": "medium",
161
+ "complexity": "medium"
162
+ }
163
+ },
164
+ "deepwiki_analysis": {
165
+ "repo_url": "https://github.com/ghh1125/Analyze-stroke",
166
+ "repo_name": "Analyze-stroke",
167
+ "content": null,
168
+ "model": "gpt-4o",
169
+ "source": "selenium",
170
+ "success": true
171
+ },
172
+ "deepwiki_options": {
173
+ "enabled": true,
174
+ "model": "gpt-4o"
175
+ },
176
+ "risk": {
177
+ "import_feasibility": 0.8,
178
+ "intrusiveness_risk": "medium",
179
+ "complexity": "medium"
180
+ }
181
+ }
Analyze-stroke/mcp_output/env_info.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "environment": {
3
+ "type": "conda",
4
+ "name": "Analyze-stroke_366682_env",
5
+ "files": {
6
+ "environment_yml": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/source/environment.yml"
7
+ },
8
+ "python": "3.10",
9
+ "exec_prefix": []
10
+ },
11
+ "original_tests": {
12
+ "passed": false,
13
+ "report_path": null
14
+ },
15
+ "timestamp": 1765367116.7647827,
16
+ "conda_available": true
17
+ }
Analyze-stroke/mcp_output/mcp_logs/llm_statistics.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total_calls": 7,
3
+ "failed_calls": 0,
4
+ "retry_count": 0,
5
+ "total_prompt_tokens": 6752,
6
+ "total_completion_tokens": 3421,
7
+ "total_tokens": 10173,
8
+ "average_prompt_tokens": 964.5714285714286,
9
+ "average_completion_tokens": 488.7142857142857,
10
+ "average_tokens": 1453.2857142857142
11
+ }
Analyze-stroke/mcp_output/mcp_logs/run_log.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": 1765367363.30327,
3
+ "node": "RunNode",
4
+ "test_result": {
5
+ "passed": false,
6
+ "report_path": null,
7
+ "stdout": "",
8
+ "stderr": "ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n"
9
+ },
10
+ "run_result": {
11
+ "success": false,
12
+ "test_passed": false,
13
+ "exit_code": 1,
14
+ "stdout": "",
15
+ "stderr": "ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n",
16
+ "timestamp": 1765367363.3032475,
17
+ "error_type": "ImportError",
18
+ "error": "Module import failed: ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n",
19
+ "details": {
20
+ "command": "/home/wshiah/code/miniconda3/bin/conda run -n Analyze-stroke_366682_env --cwd /export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke python mcp_output/start_mcp.py",
21
+ "working_directory": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke",
22
+ "environment_type": "conda"
23
+ }
24
+ },
25
+ "environment": {
26
+ "type": "conda",
27
+ "name": "Analyze-stroke_366682_env",
28
+ "files": {
29
+ "environment_yml": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/source/environment.yml"
30
+ },
31
+ "python": "3.10",
32
+ "exec_prefix": []
33
+ },
34
+ "plugin_info": {
35
+ "files": {
36
+ "mcp_output/start_mcp.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py",
37
+ "mcp_output/mcp_plugin/__init__.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/__init__.py",
38
+ "mcp_output/mcp_plugin/mcp_service.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py",
39
+ "mcp_output/mcp_plugin/adapter.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/adapter.py",
40
+ "mcp_output/mcp_plugin/main.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/main.py",
41
+ "mcp_output/requirements.txt": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/requirements.txt",
42
+ "mcp_output/README_MCP.md": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/README_MCP.md",
43
+ "mcp_output/tests_mcp/test_mcp_basic.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py"
44
+ },
45
+ "adapter_mode": "import",
46
+ "endpoints": [
47
+ "function1",
48
+ "function2",
49
+ "causalclass1",
50
+ "causalclass2",
51
+ "load_data",
52
+ "preprocess_data",
53
+ "dataloader",
54
+ "reduce_dimensions",
55
+ "dimreducer",
56
+ "select_features",
57
+ "featureselector",
58
+ "train_model",
59
+ "evaluate_model",
60
+ "modeltrainer",
61
+ "modelevaluator",
62
+ "plot_results"
63
+ ],
64
+ "mcp_dir": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin",
65
+ "tests_dir": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/tests_mcp",
66
+ "main_entry": "start_mcp.py",
67
+ "readme_path": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/README_MCP.md",
68
+ "requirements": [
69
+ "fastmcp>=0.1.0",
70
+ "pydantic>=2.0.0"
71
+ ]
72
+ },
73
+ "fastmcp_installed": false
74
+ }
Analyze-stroke/mcp_output/mcp_plugin/__init__.py ADDED
File without changes
Analyze-stroke/mcp_output/mcp_plugin/__pycache__/adapter.cpython-310.pyc ADDED
Binary file (4.14 kB). View file
 
Analyze-stroke/mcp_output/mcp_plugin/__pycache__/mcp_service.cpython-310.pyc ADDED
Binary file (2.14 kB). View file
 
Analyze-stroke/mcp_output/mcp_plugin/adapter.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # Path settings
5
+ source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
6
+ sys.path.insert(0, source_path)
7
+
8
+ # Import statements
9
+ try:
10
+ from causal_module import CausalModule
11
+ from data_loader import DataLoader
12
+ from dim_reduction import DimReduction
13
+ from feature_selection import FeatureSelection
14
+ from models import Models
15
+ from plot_utils import PlotUtils
16
+ from run_all_causal import run_all_causal
17
+ from run_all_causal_wo_draw import run_all_causal_wo_draw
18
+ except ImportError as e:
19
+ print(f"Import error: {e}. Ensure all modules are available in the source directory.")
20
+
21
+ class Adapter:
22
+ """
23
+ Adapter class for MCP plugin, providing methods to interact with the Analyze-stroke repository.
24
+ """
25
+
26
+ def __init__(self):
27
+ self.mode = "import"
28
+
29
+ # -------------------- Causal Module Methods --------------------
30
+
31
+ def create_causal_module_instance(self):
32
+ """
33
+ Create an instance of the CausalModule class.
34
+
35
+ Returns:
36
+ dict: Status of the operation and instance if successful.
37
+ """
38
+ try:
39
+ instance = CausalModule()
40
+ return {"status": "success", "instance": instance}
41
+ except Exception as e:
42
+ return {"status": "error", "message": str(e)}
43
+
44
+ # -------------------- Data Loader Methods --------------------
45
+
46
+ def create_data_loader_instance(self):
47
+ """
48
+ Create an instance of the DataLoader class.
49
+
50
+ Returns:
51
+ dict: Status of the operation and instance if successful.
52
+ """
53
+ try:
54
+ instance = DataLoader()
55
+ return {"status": "success", "instance": instance}
56
+ except Exception as e:
57
+ return {"status": "error", "message": str(e)}
58
+
59
+ # -------------------- Dimensionality Reduction Methods --------------------
60
+
61
+ def create_dim_reduction_instance(self):
62
+ """
63
+ Create an instance of the DimReduction class.
64
+
65
+ Returns:
66
+ dict: Status of the operation and instance if successful.
67
+ """
68
+ try:
69
+ instance = DimReduction()
70
+ return {"status": "success", "instance": instance}
71
+ except Exception as e:
72
+ return {"status": "error", "message": str(e)}
73
+
74
+ # -------------------- Feature Selection Methods --------------------
75
+
76
+ def create_feature_selection_instance(self):
77
+ """
78
+ Create an instance of the FeatureSelection class.
79
+
80
+ Returns:
81
+ dict: Status of the operation and instance if successful.
82
+ """
83
+ try:
84
+ instance = FeatureSelection()
85
+ return {"status": "success", "instance": instance}
86
+ except Exception as e:
87
+ return {"status": "error", "message": str(e)}
88
+
89
+ # -------------------- Models Methods --------------------
90
+
91
+ def create_models_instance(self):
92
+ """
93
+ Create an instance of the Models class.
94
+
95
+ Returns:
96
+ dict: Status of the operation and instance if successful.
97
+ """
98
+ try:
99
+ instance = Models()
100
+ return {"status": "success", "instance": instance}
101
+ except Exception as e:
102
+ return {"status": "error", "message": str(e)}
103
+
104
+ # -------------------- Plot Utilities Methods --------------------
105
+
106
+ def create_plot_utils_instance(self):
107
+ """
108
+ Create an instance of the PlotUtils class.
109
+
110
+ Returns:
111
+ dict: Status of the operation and instance if successful.
112
+ """
113
+ try:
114
+ instance = PlotUtils()
115
+ return {"status": "success", "instance": instance}
116
+ except Exception as e:
117
+ return {"status": "error", "message": str(e)}
118
+
119
+ # -------------------- Run All Causal Methods --------------------
120
+
121
+ def call_run_all_causal(self):
122
+ """
123
+ Execute the run_all_causal function.
124
+
125
+ Returns:
126
+ dict: Status of the operation.
127
+ """
128
+ try:
129
+ run_all_causal()
130
+ return {"status": "success"}
131
+ except Exception as e:
132
+ return {"status": "error", "message": str(e)}
133
+
134
+ def call_run_all_causal_wo_draw(self):
135
+ """
136
+ Execute the run_all_causal_wo_draw function.
137
+
138
+ Returns:
139
+ dict: Status of the operation.
140
+ """
141
+ try:
142
+ run_all_causal_wo_draw()
143
+ return {"status": "success"}
144
+ except Exception as e:
145
+ return {"status": "error", "message": str(e)}
146
+
147
+ # End of Adapter class definition
Analyze-stroke/mcp_output/mcp_plugin/main.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Service Auto-Wrapper - Auto-generated
3
+ """
4
+ from mcp_service import create_app
5
+
6
+ def main():
7
+ """Main entry point"""
8
+ app = create_app()
9
+ return app
10
+
11
+ if __name__ == "__main__":
12
+ app = main()
13
+ app.run()
Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import re
5
+ import pandas as pd
6
+ import numpy as np
7
+ import datetime
8
+ import time
9
+ import warnings
10
+ import logging
11
+ from typing import Optional, List
12
+
13
+ # 抑制 DoWhy 和其他库的日志
14
+ logging.getLogger('dowhy').setLevel(logging.WARNING)
15
+ logging.getLogger('pygraphviz').setLevel(logging.WARNING)
16
+
17
+ # ====================== 路径配置 (相对路径) ======================
18
+ # 获取当前文件所在目录: mcp_output/mcp_plugin/
19
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
20
+ # mcp_output 目录: mcp_output/
21
+ MCP_OUTPUT_DIR = os.path.dirname(CURRENT_DIR)
22
+ # 项目根目录: Analyze-stroke/
23
+ PROJECT_ROOT = os.path.dirname(MCP_OUTPUT_DIR)
24
+ # source 目录: Analyze-stroke/source/
25
+ SOURCE_DIR = os.path.join(PROJECT_ROOT, "source")
26
+
27
+ # 添加源代码路径到 Python 路径
28
+ sys.path.insert(0, SOURCE_DIR)
29
+
30
+ from fastmcp import FastMCP
31
+
32
+ # 导入源代码模块
33
+ from data_loader import DataLoader
34
+ from dim_reduction import DimensionAnalyzer
35
+ from models import ModelManager
36
+ from causal_module import CausalAnalyzer
37
+ from feature_selection import FeatureSelectionAnalyzer
38
+ from plot_utils import plot_from_excel
39
+ from run_all_causal import parse_log_output
40
+
41
+ # 默认数据文件路径 (相对路径)
42
+ DEFAULT_DATA_FILE = os.path.join(SOURCE_DIR, "healthcare-dataset-stroke-data.csv")
43
+ # 默认可视化输出目录
44
+ DEFAULT_VIS_DIR = os.path.join(MCP_OUTPUT_DIR, "visualization")
45
+ # 默认日志目录
46
+ DEFAULT_LOG_DIR = os.path.join(MCP_OUTPUT_DIR, "logs")
47
+
48
+ mcp = FastMCP("AnalyzeStrokeService")
49
+
50
+ # ====================== 数据加载工具 ======================
51
+
52
+ @mcp.tool(name="load_stroke_data", description="Load and clean the stroke dataset. Returns basic statistics about the data.")
53
+ def load_stroke_data_tool(file_path: str = None) -> dict:
54
+ """
55
+ Load and clean the stroke dataset.
56
+
57
+ Args:
58
+ file_path (str, optional): Path to the CSV data file. Uses default if not provided.
59
+
60
+ Returns:
61
+ dict: A dictionary containing success, result (data shape, columns, basic stats), and error fields.
62
+ """
63
+ try:
64
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
65
+ loader = DataLoader(data_file)
66
+ df = loader.load_and_clean()
67
+
68
+ result = {
69
+ "shape": df.shape,
70
+ "columns": list(df.columns),
71
+ "stroke_distribution": df['stroke'].value_counts().to_dict(),
72
+ "missing_values": df.isnull().sum().to_dict(),
73
+ "numeric_stats": df.describe().to_dict()
74
+ }
75
+ return {"success": True, "result": result, "error": None}
76
+ except Exception as e:
77
+ return {"success": False, "result": None, "error": str(e)}
78
+
79
+ # ====================== 降维分析工具 ======================
80
+
81
+ @mcp.tool(name="perform_pca_famd", description="Perform PCA/FAMD dimensionality reduction analysis on the stroke data. FAMD handles mixed numeric and categorical data.")
82
+ def perform_pca_famd_tool(file_path: str = None, save_dir: str = None) -> dict:
83
+ """
84
+ Perform PCA/FAMD dimensionality reduction analysis.
85
+
86
+ Args:
87
+ file_path (str, optional): Path to the CSV data file.
88
+ save_dir (str, optional): Directory to save visualization plots.
89
+
90
+ Returns:
91
+ dict: A dictionary containing success, result, and error fields.
92
+ """
93
+ try:
94
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
95
+ loader = DataLoader(data_file)
96
+ df = loader.load_and_clean()
97
+
98
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
99
+ vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
100
+ os.makedirs(vis_dir, exist_ok=True)
101
+
102
+ analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='pca')
103
+ analyzer.perform_pca_famd()
104
+
105
+ return {"success": True, "result": f"PCA/FAMD analysis completed. Plots saved to {vis_dir}", "error": None}
106
+ except Exception as e:
107
+ return {"success": False, "result": None, "error": str(e)}
108
+
109
+ @mcp.tool(name="perform_tsne", description="Perform t-SNE dimensionality reduction analysis on the stroke data for visualization.")
110
+ def perform_tsne_tool(file_path: str = None, save_dir: str = None) -> dict:
111
+ """
112
+ Perform t-SNE dimensionality reduction analysis.
113
+
114
+ Args:
115
+ file_path (str, optional): Path to the CSV data file.
116
+ save_dir (str, optional): Directory to save visualization plots.
117
+
118
+ Returns:
119
+ dict: A dictionary containing success, result, and error fields.
120
+ """
121
+ try:
122
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
123
+ loader = DataLoader(data_file)
124
+ df = loader.load_and_clean()
125
+
126
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
127
+ vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
128
+ os.makedirs(vis_dir, exist_ok=True)
129
+
130
+ analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='tsne')
131
+ analyzer.perform_tsne()
132
+
133
+ return {"success": True, "result": f"t-SNE analysis completed. Plots saved to {vis_dir}", "error": None}
134
+ except Exception as e:
135
+ return {"success": False, "result": None, "error": str(e)}
136
+
137
+ # ====================== 预测模型工具 ======================
138
+
139
+ @mcp.tool(name="run_prediction_model", description="Run a prediction model for stroke classification. User MUST specify model_type. Options: nb (Naive Bayes), lr (Logistic Regression), svm (SVM), knn (KNN), dt (Decision Tree), rf (Random Forest), xgboost (XGBoost), dnn (Deep Neural Network).")
140
+ def run_prediction_model_tool(model_type: str, file_path: str = None, save_dir: str = None) -> dict:
141
+ """
142
+ Run a prediction model for stroke classification.
143
+
144
+ Args:
145
+ model_type (str): Model type. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'.
146
+ file_path (str, optional): Path to the CSV data file.
147
+ save_dir (str, optional): Directory to save visualization plots.
148
+
149
+ Returns:
150
+ dict: A dictionary containing success, result (model metrics), and error fields.
151
+ """
152
+ try:
153
+ valid_models = ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn']
154
+ if model_type not in valid_models:
155
+ return {"success": False, "result": None, "error": f"Invalid model type. Choose from: {valid_models}"}
156
+
157
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
158
+ loader = DataLoader(data_file)
159
+ df = loader.load_and_clean()
160
+ X_train, X_test, y_train, y_test = loader.get_split_data()
161
+
162
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
163
+ vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
164
+ os.makedirs(vis_dir, exist_ok=True)
165
+
166
+ manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=vis_dir, timestamp=timestamp)
167
+
168
+ if model_type == 'dnn':
169
+ manager.run_dnn()
170
+ return {"success": True, "result": f"DNN model training completed. Plots saved to {vis_dir}", "error": None}
171
+ else:
172
+ manager.run_sklearn_model(model_type)
173
+ return {"success": True, "result": f"{model_type.upper()} model training completed. Plots saved to {vis_dir}", "error": None}
174
+ except Exception as e:
175
+ return {"success": False, "result": None, "error": str(e)}
176
+
177
+ # ====================== 特征选择工具 ======================
178
+
179
+ @mcp.tool(name="run_feature_selection", description="Run feature selection analysis. User MUST specify method. Options: kbest (Mutual Information), rfecv (Recursive Feature Elimination with CV), chi2 (Chi-Square Test).")
180
+ def run_feature_selection_tool(method: str, file_path: str = None, save_dir: str = None) -> dict:
181
+ """
182
+ Run feature selection analysis.
183
+
184
+ Args:
185
+ method (str): Feature selection method. Options: 'kbest', 'rfecv', 'chi2'.
186
+ file_path (str, optional): Path to the CSV data file.
187
+ save_dir (str, optional): Directory to save visualization plots.
188
+
189
+ Returns:
190
+ dict: A dictionary containing success, result, and error fields.
191
+ """
192
+ try:
193
+ valid_methods = ['kbest', 'rfecv', 'chi2']
194
+ if method not in valid_methods:
195
+ return {"success": False, "result": None, "error": f"Invalid method. Choose from: {valid_methods}"}
196
+
197
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
198
+ loader = DataLoader(data_file)
199
+ df = loader.load_and_clean()
200
+
201
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
202
+ vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
203
+ os.makedirs(vis_dir, exist_ok=True)
204
+
205
+ analyzer = FeatureSelectionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp)
206
+
207
+ if method == 'kbest':
208
+ analyzer.run_select_kbest()
209
+ elif method == 'rfecv':
210
+ analyzer.run_rfecv()
211
+ elif method == 'chi2':
212
+ analyzer.run_chi2()
213
+
214
+ return {"success": True, "result": f"Feature selection ({method}) completed. Plots saved to {vis_dir}", "error": None}
215
+ except Exception as e:
216
+ return {"success": False, "result": None, "error": str(e)}
217
+
218
+ # ====================== 因果推理工具 ======================
219
+
220
+ @mcp.tool(name="run_causal_analysis", description="Run causal inference analysis for a single treatment variable on stroke outcome using DoWhy framework. User MUST specify treatment variable. Options: gender, age, hypertension, heart_disease, ever_married, work_type, Residence_type, avg_glucose_level, bmi, smoking_status.")
221
+ def run_causal_analysis_tool(treatment: str, file_path: str = None) -> dict:
222
+ """
223
+ Run causal inference analysis for a single treatment variable.
224
+
225
+ Args:
226
+ treatment (str): Treatment variable. Options: 'gender', 'age', 'hypertension', 'heart_disease',
227
+ 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'.
228
+ file_path (str, optional): Path to the CSV data file.
229
+
230
+ Returns:
231
+ dict: A dictionary containing success, result (ATE value and interpretation), and error fields.
232
+ """
233
+ try:
234
+ valid_treatments = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
235
+ 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status']
236
+ if treatment not in valid_treatments:
237
+ return {"success": False, "result": None, "error": f"Invalid treatment. Choose from: {valid_treatments}"}
238
+
239
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
240
+ loader = DataLoader(data_file)
241
+ df = loader.load_and_clean()
242
+
243
+ # 智能领域编码 (与 main.py 保持一致)
244
+ if 'gender' in df.columns:
245
+ df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
246
+ if 'ever_married' in df.columns:
247
+ df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
248
+ if 'Residence_type' in df.columns:
249
+ df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
250
+ if 'smoking_status' in df.columns:
251
+ smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
252
+ df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
253
+ if 'work_type' in df.columns:
254
+ work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
255
+ df['work_type'] = df['work_type'].map(work_map).fillna(0)
256
+
257
+ analyzer = CausalAnalyzer(df)
258
+ ate_value = analyzer.run_analysis(treatment_col=treatment)
259
+
260
+ result = {
261
+ "treatment": treatment,
262
+ "outcome": "stroke",
263
+ "ATE": ate_value,
264
+ "interpretation": f"If '{treatment}' increases by 1 unit, the probability of stroke changes by {ate_value:.4f}"
265
+ }
266
+ return {"success": True, "result": result, "error": None}
267
+ except Exception as e:
268
+ return {"success": False, "result": None, "error": str(e)}
269
+
270
+ @mcp.tool(name="run_batch_causal_analysis", description="Run batch causal analysis for all treatment factors and generate an Excel report with visualizations. Analyzes all 10 factors: gender, age, hypertension, heart_disease, ever_married, work_type, Residence_type, avg_glucose_level, bmi, smoking_status.")
271
+ def run_batch_causal_analysis_tool(file_path: str = None, save_dir: str = None) -> dict:
272
+ """
273
+ Run batch causal analysis for all factors and generate summary report.
274
+
275
+ This function directly performs causal analysis for all factors without
276
+ relying on subprocess calls, making it cross-platform compatible.
277
+
278
+ Args:
279
+ file_path (str, optional): Path to the CSV data file.
280
+ save_dir (str, optional): Directory to save results and visualizations.
281
+
282
+ Returns:
283
+ dict: A dictionary containing success, result (report path and summary), and error fields.
284
+ """
285
+ try:
286
+ # 定义所有因素
287
+ FACTORS = [
288
+ 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
289
+ 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'
290
+ ]
291
+
292
+ # 设置保存目录
293
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal")
294
+ output_dir = save_dir if save_dir else os.path.join(DEFAULT_LOG_DIR, timestamp)
295
+ os.makedirs(output_dir, exist_ok=True)
296
+
297
+ # 加载并预处理数据
298
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
299
+ loader = DataLoader(data_file)
300
+ df = loader.load_and_clean()
301
+
302
+ # 智能领域编码
303
+ if 'gender' in df.columns:
304
+ df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
305
+ if 'ever_married' in df.columns:
306
+ df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
307
+ if 'Residence_type' in df.columns:
308
+ df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
309
+ if 'smoking_status' in df.columns:
310
+ smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
311
+ df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
312
+ if 'work_type' in df.columns:
313
+ work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
314
+ df['work_type'] = df['work_type'].map(work_map).fillna(0)
315
+
316
+ results_list = []
317
+
318
+ # 循环执行因果分析
319
+ for i, factor in enumerate(FACTORS):
320
+ print(f"\n>>> [{i+1}/{len(FACTORS)}] Analyzing factor: {factor} ...")
321
+
322
+ try:
323
+ analyzer = CausalAnalyzer(df)
324
+
325
+ # 捕获输出以提取指标
326
+ import io
327
+ from contextlib import redirect_stdout
328
+
329
+ f = io.StringIO()
330
+ with redirect_stdout(f):
331
+ ate_value = analyzer.run_analysis(treatment_col=factor)
332
+
333
+ output_text = f.getvalue()
334
+
335
+ # 提取指标
336
+ metrics = parse_log_output(output_text)
337
+ metrics['Factor'] = factor
338
+ metrics['ATE'] = ate_value if not np.isnan(metrics['ATE']) else ate_value
339
+ results_list.append(metrics)
340
+
341
+ print(f" -> [Success] ATE={metrics['ATE']:.4f}, P-Val={metrics['P-Value']}")
342
+
343
+ except Exception as e:
344
+ print(f" -> [Error] {str(e)}")
345
+ results_list.append({
346
+ 'Factor': factor,
347
+ 'ATE': np.nan,
348
+ 'New Effect': np.nan,
349
+ 'P-Value': np.nan
350
+ })
351
+
352
+ # 生成 Excel 报告
353
+ if results_list:
354
+ results_df = pd.DataFrame(results_list)
355
+ cols = ['Factor', 'ATE', 'New Effect', 'P-Value']
356
+ results_df = results_df[cols]
357
+
358
+ # 保留4位小数
359
+ numeric_cols = ['ATE', 'New Effect', 'P-Value']
360
+ results_df[numeric_cols] = results_df[numeric_cols].astype(float).round(4)
361
+
362
+ excel_name = f"Causal_Summary_{timestamp}.xlsx"
363
+ excel_path = os.path.join(output_dir, excel_name)
364
+ results_df.to_excel(excel_path, index=False)
365
+ print(f"\n[Output] Excel report saved: {excel_path}")
366
+
367
+ # 生成可视化图表
368
+ try:
369
+ plot_from_excel(excel_path, output_dir)
370
+ print("[Output] Visualization charts generated.")
371
+ except Exception as e:
372
+ print(f"[Warning] Plot generation failed: {e}")
373
+
374
+ # 返回摘要
375
+ summary = results_df.to_dict('records')
376
+ return {
377
+ "success": True,
378
+ "result": {
379
+ "excel_path": excel_path,
380
+ "output_dir": output_dir,
381
+ "summary": summary,
382
+ "total_factors": len(FACTORS),
383
+ "successful_analyses": len([r for r in results_list if not np.isnan(r.get('ATE', np.nan))])
384
+ },
385
+ "error": None
386
+ }
387
+ else:
388
+ return {"success": False, "result": None, "error": "No results generated."}
389
+
390
+ except Exception as e:
391
+ return {"success": False, "result": None, "error": str(e)}
392
+
393
+ @mcp.tool(name="parse_log_output", description="Extracts metrics from log text including ATE, New Effect, and P-Value.")
394
+ def parse_log_output_tool(output_text: str) -> dict:
395
+ """
396
+ Extracts metrics from log text including ATE, New Effect, and P-Value.
397
+
398
+ Args:
399
+ output_text (str): The log text containing the metrics.
400
+
401
+ Returns:
402
+ dict: A dictionary containing success, result (extracted metrics), and error fields.
403
+ """
404
+ try:
405
+ result = parse_log_output(output_text)
406
+ return {"success": True, "result": result, "error": None}
407
+ except Exception as e:
408
+ return {"success": False, "result": None, "error": str(e)}
409
+
410
+ # ====================== 绘图工具 ======================
411
+
412
+ @mcp.tool(name="plot_causal_results", description="Generate visualizations (radar chart, bar chart) from causal analysis Excel results.")
413
+ def plot_causal_results_tool(excel_path: str) -> dict:
414
+ """
415
+ Generate visualizations from causal analysis Excel results.
416
+
417
+ Args:
418
+ excel_path (str): Path to the Excel file containing causal analysis results.
419
+
420
+ Returns:
421
+ dict: A dictionary containing success, result, and error fields.
422
+ """
423
+ try:
424
+ if not os.path.exists(excel_path):
425
+ return {"success": False, "result": None, "error": f"Excel file not found: {excel_path}"}
426
+
427
+ plot_from_excel(excel_path)
428
+ return {"success": True, "result": f"Plots generated from {excel_path}", "error": None}
429
+ except Exception as e:
430
+ return {"success": False, "result": None, "error": str(e)}
431
+
432
+ # ====================== 综合分析工具 ======================
433
+
434
+ @mcp.tool(name="run_stroke_analysis", description="Run comprehensive stroke analysis. Unified entry point for all analysis tasks. User MUST specify task. For prediction task, user MUST also specify model. For feature_selection task, user MUST also specify method. For causal task, user MUST also specify treatment.")
435
+ def run_stroke_analysis_tool(task: str, model: str = None, method: str = None, treatment: str = None, file_path: str = None) -> dict:
436
+ """
437
+ Run comprehensive stroke analysis - unified entry point.
438
+
439
+ Args:
440
+ task (str): Analysis task. Options: 'pca', 'tsne', 'prediction', 'causal', 'feature_selection'.
441
+ model (str, optional): Model type for prediction task. Required when task='prediction'. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'.
442
+ method (str, optional): Method for feature selection. Required when task='feature_selection'. Options: 'kbest', 'rfecv', 'chi2'.
443
+ treatment (str, optional): Treatment variable for causal analysis. Required when task='causal'. Options: 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'.
444
+ file_path (str, optional): Path to the CSV data file.
445
+
446
+ Returns:
447
+ dict: A dictionary containing success, result, and error fields.
448
+ """
449
+ try:
450
+ valid_tasks = ['pca', 'tsne', 'prediction', 'causal', 'feature_selection']
451
+ if task not in valid_tasks:
452
+ return {"success": False, "result": None, "error": f"Invalid task. Choose from: {valid_tasks}"}
453
+
454
+ # 直接执行任务逻辑,避免调用被装饰的函数
455
+ data_file = file_path if file_path else DEFAULT_DATA_FILE
456
+
457
+ if task == 'pca':
458
+ loader = DataLoader(data_file)
459
+ df = loader.load_and_clean()
460
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
461
+ vis_dir = DEFAULT_VIS_DIR
462
+ os.makedirs(vis_dir, exist_ok=True)
463
+ analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='pca')
464
+ analyzer.perform_pca_famd()
465
+ return {"success": True, "result": f"PCA/FAMD analysis completed. Plots saved to {vis_dir}", "error": None}
466
+
467
+ elif task == 'tsne':
468
+ loader = DataLoader(data_file)
469
+ df = loader.load_and_clean()
470
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
471
+ vis_dir = DEFAULT_VIS_DIR
472
+ os.makedirs(vis_dir, exist_ok=True)
473
+ analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='tsne')
474
+ analyzer.perform_tsne()
475
+ return {"success": True, "result": f"t-SNE analysis completed. Plots saved to {vis_dir}", "error": None}
476
+
477
+ elif task == 'prediction':
478
+ if not model:
479
+ return {"success": False, "result": None, "error": "For prediction task, 'model' parameter is required. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'"}
480
+
481
+ valid_models = ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn']
482
+ if model not in valid_models:
483
+ return {"success": False, "result": None, "error": f"Invalid model type. Choose from: {valid_models}"}
484
+
485
+ loader = DataLoader(data_file)
486
+ df = loader.load_and_clean()
487
+ X_train, X_test, y_train, y_test = loader.get_split_data()
488
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
489
+ vis_dir = DEFAULT_VIS_DIR
490
+ os.makedirs(vis_dir, exist_ok=True)
491
+ manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=vis_dir, timestamp=timestamp)
492
+
493
+ if model == 'dnn':
494
+ manager.run_dnn()
495
+ else:
496
+ manager.run_sklearn_model(model)
497
+ return {"success": True, "result": f"{model.upper()} model training completed. Plots saved to {vis_dir}", "error": None}
498
+
499
+ elif task == 'causal':
500
+ if not treatment:
501
+ return {"success": False, "result": None, "error": "For causal task, 'treatment' parameter is required. Options: 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'"}
502
+
503
+ valid_treatments = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
504
+ 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status']
505
+ if treatment not in valid_treatments:
506
+ return {"success": False, "result": None, "error": f"Invalid treatment. Choose from: {valid_treatments}"}
507
+
508
+ loader = DataLoader(data_file)
509
+ df = loader.load_and_clean()
510
+
511
+ # 智能领域编码
512
+ if 'gender' in df.columns:
513
+ df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
514
+ if 'ever_married' in df.columns:
515
+ df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
516
+ if 'Residence_type' in df.columns:
517
+ df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
518
+ if 'smoking_status' in df.columns:
519
+ smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
520
+ df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
521
+ if 'work_type' in df.columns:
522
+ work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
523
+ df['work_type'] = df['work_type'].map(work_map).fillna(0)
524
+
525
+ analyzer = CausalAnalyzer(df)
526
+ ate_value = analyzer.run_analysis(treatment_col=treatment)
527
+ result = {
528
+ "treatment": treatment,
529
+ "outcome": "stroke",
530
+ "ATE": ate_value,
531
+ "interpretation": f"If '{treatment}' increases by 1 unit, the probability of stroke changes by {ate_value:.4f}"
532
+ }
533
+ return {"success": True, "result": result, "error": None}
534
+
535
+ elif task == 'feature_selection':
536
+ if not method:
537
+ return {"success": False, "result": None, "error": "For feature_selection task, 'method' parameter is required. Options: 'kbest', 'rfecv', 'chi2'"}
538
+
539
+ valid_methods = ['kbest', 'rfecv', 'chi2']
540
+ if method not in valid_methods:
541
+ return {"success": False, "result": None, "error": f"Invalid method. Choose from: {valid_methods}"}
542
+
543
+ loader = DataLoader(data_file)
544
+ df = loader.load_and_clean()
545
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
546
+ vis_dir = DEFAULT_VIS_DIR
547
+ os.makedirs(vis_dir, exist_ok=True)
548
+ analyzer = FeatureSelectionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp)
549
+
550
+ if method == 'kbest':
551
+ analyzer.run_select_kbest()
552
+ elif method == 'rfecv':
553
+ analyzer.run_rfecv()
554
+ elif method == 'chi2':
555
+ analyzer.run_chi2()
556
+
557
+ return {"success": True, "result": f"Feature selection ({method}) completed. Plots saved to {vis_dir}", "error": None}
558
+
559
+ except Exception as e:
560
+ return {"success": False, "result": None, "error": str(e)}
561
+
562
+ @mcp.tool(name="get_available_options", description="Get all available options for stroke analysis tasks, models, methods, and treatment variables.")
563
+ def get_available_options_tool() -> dict:
564
+ """
565
+ Get all available options for stroke analysis.
566
+
567
+ Returns:
568
+ dict: A dictionary containing all available options.
569
+ """
570
+ return {
571
+ "success": True,
572
+ "result": {
573
+ "tasks": ['pca', 'tsne', 'prediction', 'causal', 'feature_selection'],
574
+ "prediction_models": ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'],
575
+ "feature_selection_methods": ['kbest', 'rfecv', 'chi2'],
576
+ "causal_treatments": ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
577
+ 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'],
578
+ "description": {
579
+ "pca": "PCA/FAMD dimensionality reduction analysis",
580
+ "tsne": "t-SNE visualization",
581
+ "prediction": "Stroke prediction using ML models",
582
+ "causal": "Causal inference using DoWhy framework",
583
+ "feature_selection": "Feature importance analysis"
584
+ }
585
+ },
586
+ "error": None
587
+ }
588
+
589
+ @mcp.tool(name="upload_to_hf_dataset", description="Upload generated file to HuggingFace Dataset for easy download")
590
+ def upload_to_hf_dataset(file_path: str, dataset_repo: str = None, hf_token: str = None) -> dict:
591
+ """
592
+ Upload generated file to HuggingFace Dataset repository.
593
+ Users can then download files from the Dataset page.
594
+
595
+ Parameters:
596
+ file_path (str): Absolute path to the file on HF Space
597
+ dataset_repo (str): Dataset repo name (e.g., 'username/obspy-outputs').
598
+ If None, uses HF_DATASET_REPO environment variable
599
+ hf_token (str): HuggingFace token. If None, uses HF_TOKEN environment variable
600
+
601
+ Returns:
602
+ dict: Success status and download URL
603
+ """
604
+ try:
605
+ from huggingface_hub import HfApi, create_repo
606
+ from huggingface_hub.utils import RepositoryNotFoundError
607
+
608
+ if dataset_repo is None:
609
+ dataset_repo = os.environ.get("HF_DATASET_REPO")
610
+ if not dataset_repo:
611
+ return {
612
+ "success": False,
613
+ "error": "Dataset repo not specified. Set HF_DATASET_REPO environment variable or pass dataset_repo parameter"
614
+ }
615
+
616
+ if hf_token is None:
617
+ hf_token = os.environ.get("HF_TOKEN")
618
+ if not hf_token:
619
+ return {
620
+ "success": False,
621
+ "error": "HF token not found. Set HF_TOKEN environment variable or pass hf_token parameter"
622
+ }
623
+
624
+ # Check if file exists and readable
625
+ if not os.path.exists(file_path):
626
+ return {
627
+ "success": False,
628
+ "error": f"File not found: {file_path}"
629
+ }
630
+
631
+ if not os.path.isfile(file_path):
632
+ return {
633
+ "success": False,
634
+ "error": f"Path is not a file: {file_path}"
635
+ }
636
+
637
+ if not os.access(file_path, os.R_OK):
638
+ return {
639
+ "success": False,
640
+ "error": f"File not readable (permission denied): {file_path}",
641
+ "hint": "Check file permissions in Docker container"
642
+ }
643
+
644
+ # Get file info for debugging
645
+ file_size = os.path.getsize(file_path)
646
+ file_stat = os.stat(file_path)
647
+
648
+ # Initialize HF API
649
+ api = HfApi()
650
+
651
+ # Try to create dataset repo if it doesn't exist
652
+ try:
653
+ repo_info = create_repo(
654
+ repo_id=dataset_repo,
655
+ repo_type="dataset",
656
+ token=hf_token,
657
+ exist_ok=True,
658
+ private=False
659
+ )
660
+ except Exception as e:
661
+ return {
662
+ "success": False,
663
+ "error": f"Failed to create/access dataset repo: {str(e)}",
664
+ "dataset_repo": dataset_repo,
665
+ "hint": "Check if HF_TOKEN has write permission and dataset_repo format is correct (username/repo-name)"
666
+ }
667
+
668
+ # Get filename and determine path in dataset
669
+ filename = os.path.basename(file_path)
670
+
671
+ # Determine subdirectory based on file location
672
+ if "/output/" in file_path:
673
+ path_in_repo = f"output/{filename}"
674
+ elif "/plots/" in file_path:
675
+ path_in_repo = f"plots/{filename}"
676
+ elif "/wave_data/" in file_path:
677
+ path_in_repo = f"wave_data/{filename}"
678
+ else:
679
+ path_in_repo = filename
680
+
681
+ try:
682
+ with open(file_path, 'rb') as f:
683
+ file_content = f.read()
684
+ except Exception as e:
685
+ return {
686
+ "success": False,
687
+ "error": f"Failed to read file: {str(e)}",
688
+ "file_path": file_path,
689
+ "hint": "File may exist but not readable due to permission issues"
690
+ }
691
+
692
+ # Upload file from memory instead of path
693
+ import io
694
+ upload_result = api.upload_file(
695
+ path_or_fileobj=io.BytesIO(file_content),
696
+ path_in_repo=path_in_repo,
697
+ repo_id=dataset_repo,
698
+ repo_type="dataset",
699
+ token=hf_token
700
+ )
701
+
702
+ # Construct download URL
703
+ download_url = f"https://huggingface.co/datasets/{dataset_repo}/resolve/main/{path_in_repo}"
704
+ viewer_url = f"https://huggingface.co/datasets/{dataset_repo}/viewer/default/train?f%5Bfile%5D%5Bvalue%5D={path_in_repo}"
705
+
706
+ return {
707
+ "success": True,
708
+ "message": f"File uploaded successfully to {dataset_repo}",
709
+ "dataset_repo": dataset_repo,
710
+ "filename": filename,
711
+ "path_in_repo": path_in_repo,
712
+ "download_url": download_url,
713
+ "viewer_url": viewer_url,
714
+ "usage": f"Download directly from: {download_url}"
715
+ }
716
+
717
+ except Exception as e:
718
+ return {
719
+ "success": False,
720
+ "error": str(e),
721
+ "hint": "Make sure HF_TOKEN and HF_DATASET_REPO are set in Space secrets"
722
+ }
723
+
724
+ def create_app() -> FastMCP:
725
+ """
726
+ Creates and returns the FastMCP application instance.
727
+
728
+ Returns:
729
+ FastMCP: The FastMCP application instance.
730
+ """
731
+ # Verify that all tools have been registered correctly
732
+ print("[MCP] Service initialized successfully")
733
+ print(f"[MCP] Registered tools: {len(mcp._tools) if hasattr(mcp, '_tools') else 'Unknown'}")
734
+ return mcp
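For reference, a minimal stand-alone sketch of the same upload flow the tool above performs (the dataset repo id and local file path below are placeholders, not values from this repo):

import io
import os
from huggingface_hub import HfApi, create_repo

dataset_repo = os.environ.get("HF_DATASET_REPO", "your-user/stroke-demo-outputs")  # placeholder repo id
token = os.environ.get("HF_TOKEN")

create_repo(repo_id=dataset_repo, repo_type="dataset", token=token, exist_ok=True)

with open("output/report.csv", "rb") as f:  # hypothetical local artifact
    payload = f.read()

HfApi().upload_file(
    path_or_fileobj=io.BytesIO(payload),   # upload from memory, as the tool does
    path_in_repo="output/report.csv",
    repo_id=dataset_repo,
    repo_type="dataset",
    token=token,
)
print(f"https://huggingface.co/datasets/{dataset_repo}/resolve/main/output/report.csv")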
Analyze-stroke/mcp_output/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastmcp>=0.1.0
2
+ pydantic>=2.0.0
3
+ numpy
4
+ pandas
5
+ scikit-learn
6
+ matplotlib
7
+
8
+ # Optional Dependencies
9
+ # seaborn
Analyze-stroke/mcp_output/simple_revise_error_analysis.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "status": "FAIL",
3
+ "next_action": "fix_directly",
4
+ "confidence": 0.9,
5
+ "summary": "The error is due to a missing Python module 'pandas'. This can be fixed directly by installing the 'pandas' package in the conda environment being used. The command to install the package is 'conda install pandas'. Ensure that the correct conda environment is activated before running the installation command. Once 'pandas' is installed, the script should be able to import it without errors."
6
+ }
Analyze-stroke/mcp_output/start_mcp.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Service Startup Entry
3
+ """
4
+ import sys
5
+ import os
6
+
7
+ project_root = os.path.dirname(os.path.abspath(__file__))
8
+ mcp_plugin_dir = os.path.join(project_root, "mcp_plugin")
9
+ if mcp_plugin_dir not in sys.path:
10
+ sys.path.insert(0, mcp_plugin_dir)
11
+
12
+ # Set path to source directory
13
+ source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
14
+ sys.path.insert(0, source_path)
15
+
16
+ from mcp_plugin.mcp_service import create_app
17
+
18
+ def main():
19
+ """Start FastMCP service"""
20
+ app = create_app()
21
+ # Use environment variable to configure port, default 8000
22
+ port = int(os.environ.get("MCP_PORT", "8000"))
23
+
24
+ # Choose transport mode based on environment variable
25
+ transport = os.environ.get("MCP_TRANSPORT", "stdio")
26
+ if transport == "http":
27
+ app.run(transport="http", host="0.0.0.0", port=port)
28
+ else:
29
+ # Default to STDIO mode
30
+ app.run()
31
+
32
+ if __name__ == "__main__":
33
+ main()
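To try the HTTP transport locally, one option is to set the two environment variables this script reads and launch it (sketch; the port value is arbitrary):

import os
import subprocess

# MCP_TRANSPORT and MCP_PORT are the variables read in start_mcp.py; 7861 is just an example port
env = dict(os.environ, MCP_TRANSPORT="http", MCP_PORT="7861")
subprocess.run(["python", "start_mcp.py"], env=env, check=True)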
Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Service Basic Test
3
+ """
4
+ import sys
5
+ import os
6
+
7
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+ mcp_plugin_dir = os.path.join(project_root, "mcp_plugin")
9
+ if mcp_plugin_dir not in sys.path:
10
+ sys.path.insert(0, mcp_plugin_dir)
11
+
12
+ source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
13
+ sys.path.insert(0, source_path)
14
+
15
+ def test_import_mcp_service():
16
+ """Test if MCP service can be imported normally"""
17
+ try:
18
+ from mcp_service import create_app
19
+ app = create_app()
20
+ assert app is not None
21
+ print("MCP service imported successfully")
22
+ return True
23
+ except Exception as e:
24
+ print("MCP service import failed: " + str(e))
25
+ return False
26
+
27
+ def test_adapter_init():
28
+ """Test if adapter can be initialized normally"""
29
+ try:
30
+ from adapter import Adapter
31
+ adapter = Adapter()
32
+ assert adapter is not None
33
+ print("Adapter initialized successfully")
34
+ return True
35
+ except Exception as e:
36
+ print("Adapter initialization failed: " + str(e))
37
+ return False
38
+
39
+ if __name__ == "__main__":
40
+ print("Running MCP service basic test...")
41
+ test1 = test_import_mcp_service()
42
+ test2 = test_adapter_init()
43
+
44
+ if test1 and test2:
45
+ print("All basic tests passed")
46
+ sys.exit(0)
47
+ else:
48
+ print("Some tests failed")
49
+ sys.exit(1)
Analyze-stroke/mcp_output/tests_smoke/test_smoke.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib, sys
2
+ import os
3
+
4
+ # Add current directory to Python path
5
+ sys.path.insert(0, os.getcwd())
6
+
7
+ source_dir = os.path.join(os.getcwd(), "source")
8
+ if os.path.exists(source_dir):
9
+ sys.path.insert(0, source_dir)
10
+
11
+
12
+ print("NO_PACKAGE - No testable package found")
Analyze-stroke/patient_data/README_DEMO.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This directory contains SYNTHETIC demo data only.
2
+ No real patient information.
3
+
4
+ Purpose
5
+ -------
6
+ Used for MCP security demos (path traversal / guardrail allowlist).
7
+
8
+ Files
9
+ -----
10
+ - stroke_clean.csv : synthetic stroke analysis dataset
11
+
12
+
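For context, a minimal sketch of the kind of allowlist check a guardrail might apply before serving files from this directory (illustrative only, not the repo's actual demo code; the allowed root is an assumption):

from pathlib import Path

ALLOWED_ROOT = Path("Analyze-stroke/patient_data").resolve()  # assumed allowlisted directory

def safe_read(relative_path: str) -> str:
    # Resolve ".." segments and symlinks, then require the result to stay inside the allowlist
    target = (ALLOWED_ROOT / relative_path).resolve()
    if not target.is_relative_to(ALLOWED_ROOT):  # Python 3.9+
        raise PermissionError(f"Path escapes allowlist: {relative_path}")
    return target.read_text(encoding="utf-8")

# safe_read("stroke_clean.csv")            -> allowed
# safe_read("../secrets/national_id.csv")  -> PermissionError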
Analyze-stroke/patient_data/stroke_clean.csv ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ patient_id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke,note
2
+ P20250001,Male,67.2,1,0,Yes,Private,Urban,198.4,31.6,formerly smoked,1,SYNTHETIC_DEMO_ONLY
3
+ P20250002,Female,52.8,0,0,Yes,Self-employed,Rural,141.2,28.9,never smoked,0,SYNTHETIC_DEMO_ONLY
4
+ P20250003,Male,74.5,1,1,Yes,Govt_job,Urban,212.7,34.1,smokes,1,SYNTHETIC_DEMO_ONLY
5
+ P20250004,Female,38.6,0,0,No,Private,Urban,96.5,22.4,Unknown,0,SYNTHETIC_DEMO_ONLY
6
+ P20250005,Male,59.1,1,0,Yes,Private,Rural,165.3,36.8,smokes,1,SYNTHETIC_DEMO_ONLY
7
+ P20250006,Female,46.9,0,0,Yes,Govt_job,Urban,128.7,27.6,never smoked,0,SYNTHETIC_DEMO_ONLY
8
+ P20250007,Male,63.4,0,1,Yes,Self-employed,Urban,183.9,30.8,formerly smoked,1,SYNTHETIC_DEMO_ONLY
9
+ P20250008,Female,29.7,0,0,No,Private,Rural,88.2,20.1,never smoked,0,SYNTHETIC_DEMO_ONLY
10
+ P20250009,Male,71.9,1,0,Yes,Private,Urban,156.4,29.3,smokes,1,SYNTHETIC_DEMO_ONLY
11
+ P20250010,Female,57.3,0,0,Yes,Self-employed,Urban,149.8,33.7,formerly smoked,1,SYNTHETIC_DEMO_ONLY
12
+ P20250011,Male,44.2,0,0,Yes,Private,Rural,121.1,25.9,never smoked,0,SYNTHETIC_DEMO_ONLY
13
+ P20250012,Female,66.8,1,1,Yes,Govt_job,Urban,224.5,37.2,formerly smoked,1,SYNTHETIC_DEMO_ONLY
14
+ P20250013,Male,61.0,0,0,Yes,Private,Urban,172.9,35.4,smokes,1,SYNTHETIC_DEMO_ONLY
15
+ P20250014,Female,49.6,0,0,Yes,Private,Rural,110.6,24.7,Unknown,0,SYNTHETIC_DEMO_ONLY
16
+ P20250015,Male,80.2,1,1,Yes,Self-employed,Rural,238.1,32.2,formerly smoked,1,SYNTHETIC_DEMO_ONLY
17
+ P20250016,Female,35.9,0,0,No,Private,Urban,92.9,21.8,never smoked,0,SYNTHETIC_DEMO_ONLY
18
+ P20250017,Male,54.7,1,0,Yes,Govt_job,Urban,146.3,31.1,smokes,1,SYNTHETIC_DEMO_ONLY
19
+ P20250018,Female,62.5,0,0,Yes,Private,Rural,137.4,29.9,formerly smoked,0,SYNTHETIC_DEMO_ONLY
20
+ P20250019,Male,58.8,0,0,Yes,Private,Urban,181.6,34.8,smokes,1,SYNTHETIC_DEMO_ONLY
21
+ P20250020,Female,73.1,1,0,Yes,Self-employed,Urban,201.3,30.5,Unknown,1,SYNTHETIC_DEMO_ONLY
22
+
23
+
Analyze-stroke/secrets/national_id.csv ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ patient_id,name,national_id,phone,address,risk_flag,note
2
+ P20250001,王伟,110101195805164218,17048293105,北京市健康路88号,high,SYNTHETIC_DEMO_ONLY
3
+ P20250002,李静,310101197210083526,17190517342,上海市杏林路17号,medium,SYNTHETIC_DEMO_ONLY
4
+ P20250003,陈强,440103195012295119,19961425083,广州市康复大道203号,high,SYNTHETIC_DEMO_ONLY
5
+ P20250004,周娜,440305198703171642,19107836291,深圳市医学路56号,low,SYNTHETIC_DEMO_ONLY
6
+ P20250005,吴涛,320102196605024533,16639274518,南京市科研路121号,high,SYNTHETIC_DEMO_ONLY
7
+ P20250006,郑丽,330106197809302016,17083517026,杭州市协和路9号,medium,SYNTHETIC_DEMO_ONLY
8
+ P20250007,冯磊,420106196109112738,17126094837,武汉市同济路74号,high,SYNTHETIC_DEMO_ONLY
9
+ P20250008,褚倩,510104199512054821,19950218366,成都市仁心街14号,low,SYNTHETIC_DEMO_ONLY
10
+ P20250009,张军,110101195309063911,19193458207,北京市医学路208号,high,SYNTHETIC_DEMO_ONLY
11
+ P20250010,何丹,31010119670729522X,16618473920,上海市康复大道66号,medium,SYNTHETIC_DEMO_ONLY
12
+ P20250011,许杰,440103198010124617,17067290813,广州市杏林路191号,low,SYNTHETIC_DEMO_ONLY
13
+ P20250012,吕秀英,440305195611285178,17141873054,深圳市健康路33号,high,SYNTHETIC_DEMO_ONLY
14
+ P20250013,施鹏,320102196405105855,19912039758,南京市协和路152号,high,SYNTHETIC_DEMO_ONLY
15
+ P20250014,曹玲,33010619751222341X,19150268491,杭州市科研路27号,low,SYNTHETIC_DEMO_ONLY
16
+ P20250015,严超,420106194504036412,16693502741,武汉市仁心街102号,high,SYNTHETIC_DEMO_ONLY
17
+ P20250016,华倩,510104198907191984,17028549106,成都市同济路6号,low,SYNTHETIC_DEMO_ONLY
18
+ P20250017,金鑫,110101197104144535,17179364128,北京市康复大道118号,medium,SYNTHETIC_DEMO_ONLY
19
+ P20250018,魏晨,310101196212315470,19968295137,上海市医学路41号,medium,SYNTHETIC_DEMO_ONLY
20
+ P20250019,陶昊,440103196611103903,19135791426,广州市协和路219号,high,SYNTHETIC_DEMO_ONLY
21
+ P20250020,姜然,440305195202184962,16674261039,深圳市科研路12号,high,SYNTHETIC_DEMO_ONLY
22
+
23
+
Analyze-stroke/source/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Analyze-stroke Project Package Initialization File
4
+ """
Analyze-stroke/source/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (280 Bytes). View file
 
Analyze-stroke/source/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (252 Bytes). View file
 
Analyze-stroke/source/__pycache__/causal_module.cpython-310.pyc ADDED
Binary file (3.56 kB). View file
 
Analyze-stroke/source/__pycache__/data_loader.cpython-311.pyc ADDED
Binary file (2.68 kB). View file
 
Analyze-stroke/source/__pycache__/data_loader.cpython-39.pyc ADDED
Binary file (1.72 kB). View file
 
Analyze-stroke/source/causal_module.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/causal_module.py
2
+ import dowhy
3
+ from dowhy import CausalModel
4
+ import pandas as pd
5
+ import numpy as np
6
+ import warnings
7
+
8
+ class CausalAnalyzer:
9
+ def __init__(self, data):
10
+ self.data = data.copy()
11
+ for col in self.data.columns:
12
+ if self.data[col].dtype == bool:
13
+ self.data[col] = self.data[col].astype(int)
14
+
15
+ def define_causal_graph(self):
16
+ """
17
+ Define the domain-knowledge causal DAG.
18
+ """
19
+ causal_graph = """
20
+ digraph {
21
+ age -> hypertension;
22
+ age -> heart_disease;
23
+ age -> avg_glucose_level;
24
+ age -> stroke;
25
+ age -> ever_married;
26
+ age -> work_type;
27
+
28
+ bmi -> hypertension;
29
+ bmi -> heart_disease;
30
+ bmi -> stroke;
31
+
32
+ hypertension -> stroke;
33
+ heart_disease -> stroke;
34
+ avg_glucose_level -> stroke;
35
+
36
+ work_type -> bmi;
37
+ work_type -> hypertension;
38
+ work_type -> stroke;
39
+
40
+ smoking_status -> heart_disease;
41
+ smoking_status -> stroke;
42
+
43
+ Residence_type -> bmi;
44
+ Residence_type -> stroke;
45
+
46
+ gender -> stroke;
47
+ gender -> smoking_status;
48
+
49
+ ever_married -> stroke;
50
+ ever_married -> bmi;
51
+ }
52
+ """
53
+ return causal_graph.replace("\n", " ")
54
+
55
+ def run_analysis(self, treatment_col, outcome_col='stroke'):
56
+ print(f"\n=== Causal Analysis: Effect of '{treatment_col}' on '{outcome_col}' ===")
57
+
58
+ warnings.filterwarnings("ignore")
59
+
60
+ model = CausalModel(
61
+ data=self.data,
62
+ treatment=treatment_col,
63
+ outcome=outcome_col,
64
+ graph=self.define_causal_graph()
65
+ )
66
+
67
+ identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
68
+
69
+ estimate = None
70
+ try:
71
+ print("[Causal] Attempting Propensity Score Stratification...")
72
+ estimate = model.estimate_effect(
73
+ identified_estimand,
74
+ method_name="backdoor.propensity_score_stratification"
75
+ )
76
+ except Exception as e:
77
+ error_msg = str(e)
78
+ if "No common causes" in error_msg or "Propensity score" in error_msg:
79
+ print(f"[Info] Switching to Linear Regression (Root Node/No Confounders detected).")
80
+ estimate = model.estimate_effect(
81
+ identified_estimand,
82
+ method_name="backdoor.linear_regression"
83
+ )
84
+ else:
85
+ print(f"[Error] Estimation failed: {e}")
86
+ return 0.0
87
+
88
+ if estimate is not None:
89
+ print(f"\n[Result] Causal Estimate (ATE): {estimate.value:.4f}")
90
+ print(f"[Interpretation] If '{treatment_col}' increases by 1 unit (or switches from 0 to 1),")
91
+ print(f" the probability of '{outcome_col}' changes by {estimate.value:.4f}")
92
+
93
+ print("\n[Causal] Refuting estimate (Random Common Cause)...")
94
+ try:
95
+ refute = model.refute_estimate(
96
+ identified_estimand,
97
+ estimate,
98
+ method_name="random_common_cause",
99
+ show_progress_bar=False
100
+ )
101
+ print(f"Refutation p-value: {refute.refutation_result['p_value']:.4f}")
102
+ print(f"Refutation New Effect: {refute.new_effect:.4f}")
103
+ except Exception as e:
104
+ print(f"[Warning] Refutation failed: {e}")
105
+
106
+ return estimate.value
107
+ else:
108
+ return 0.0
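A minimal usage sketch (the CSV path is an assumption; categorical columns must be numerically encoded first, in the same way main.py does before calling this class):

import pandas as pd
from causal_module import CausalAnalyzer

df = pd.read_csv("healthcare-dataset-stroke-data.csv")  # assumed path
df = df[df["gender"] != "Other"]
df["gender"] = df["gender"].map({"Male": 1, "Female": 0})
df["ever_married"] = df["ever_married"].map({"Yes": 1, "No": 0})
df["Residence_type"] = df["Residence_type"].map({"Urban": 1, "Rural": 0})
df["smoking_status"] = df["smoking_status"].map(
    {"Unknown": 0, "never smoked": 0, "formerly smoked": 1, "smokes": 2}).fillna(0)
df["work_type"] = df["work_type"].map(
    {"Private": 1, "Self-employed": 1, "Govt_job": 0, "children": 0, "Never_worked": 0}).fillna(0)
df["bmi"] = df["bmi"].fillna(df["bmi"].mean())
df = df.drop(columns=["id"], errors="ignore")

ate = CausalAnalyzer(df).run_analysis(treatment_col="hypertension", outcome_col="stroke")
print(f"Estimated ATE of hypertension on stroke: {ate:.4f}")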
Analyze-stroke/source/data_loader.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/data_loader.py
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.impute import KNNImputer
6
+
7
+ class DataLoader:
8
+ def __init__(self, filepath, target_col='stroke'):
9
+ self.filepath = filepath
10
+ self.target_col = target_col
11
+ self.data = None
12
+
13
+ def load_and_clean(self):
14
+ """
15
+ Load the data, handle missing values, and drop irrelevant columns.
16
+ """
17
+ print(f"[Info] Loading data from {self.filepath}...")
18
+ df = pd.read_csv(self.filepath)
19
+
20
+ # 1. Drop the ID column; it carries no predictive information
21
+ if 'id' in df.columns:
22
+ df = df.drop(columns=['id'])
23
+
24
+ # 2. Drop the single 'Other' gender record (only one case; usually recommended to remove)
25
+ df = df[df['gender'] != 'Other']
26
+
27
+ # 3. Handle missing BMI values
28
+ # The project notes mention regression imputation; KNNImputer is imported as a more general option (better than mean fill).
29
+ # Note: the imputer only accepts numeric input, so categorical columns would first need to be set aside or encoded.
30
+ # To keep things simple and preserve the original dtypes for later analysis, bmi is filled with its mean here;
31
+ # the raw NaNs could instead be left for a downstream pipeline to handle,
32
+ # but filling them directly keeps the causal-inference step straightforward.
33
+ df['bmi'] = df['bmi'].fillna(df['bmi'].mean())
34
+
35
+ self.data = df
36
+ print(f"[Info] Data loaded. Shape: {self.data.shape}")
37
+ return self.data
38
+
39
+ def get_split_data(self, test_size=0.2, random_state=42):
40
+ """
41
+ Return an 80:20 train/test split (X_train, X_test, y_train, y_test).
42
+ """
43
+ if self.data is None:
44
+ self.load_and_clean()
45
+
46
+ X = self.data.drop(columns=[self.target_col])
47
+ y = self.data[self.target_col]
48
+
49
+ # Stratify keeps the proportion of stroke=1 identical in the train and test sets
50
+ X_train, X_test, y_train, y_test = train_test_split(
51
+ X, y, test_size=test_size, random_state=random_state, stratify=y
52
+ )
53
+
54
+ print(f"[Info] Data split 80:20 completed.")
55
+ return X_train, X_test, y_train, y_test
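Typical usage (sketch; the CSV path is an assumption):

from data_loader import DataLoader

loader = DataLoader("healthcare-dataset-stroke-data.csv")    # assumed path
df = loader.load_and_clean()                                 # drops 'id', removes the single 'Other' gender row, mean-fills bmi
X_train, X_test, y_train, y_test = loader.get_split_data()   # stratified 80:20 split
print(X_train.shape, round(y_train.mean(), 4), round(y_test.mean(), 4))  # stroke rates should match across splits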
Analyze-stroke/source/dim_reduction.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/dim_reduction.py
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import os
7
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.manifold import TSNE
10
+
11
+ try:
12
+ import prince
13
+ if not hasattr(prince, 'FAMD'):
14
+ from prince.famd import FAMD
15
+ else:
16
+ FAMD = prince.FAMD
17
+ except ImportError:
18
+ FAMD = None
19
+ print("[Warning] 'prince' library not found. PCA/FAMD may fail.")
20
+
21
+ class DimensionAnalyzer:
22
+ def __init__(self, data, target_col='stroke', save_dir=None, timestamp=None, task_name='dim_red'):
23
+ self.data = data
24
+ self.target = data[target_col]
25
+ self.features = data.drop(columns=[target_col])
26
+
27
+ # 保存配置
28
+ self.save_dir = save_dir
29
+ self.timestamp = timestamp
30
+ self.task_name = task_name
31
+
32
+ self.num_cols = ['age', 'avg_glucose_level', 'bmi']
33
+ self.num_cols = [c for c in self.num_cols if c in self.features.columns]
34
+ self.cat_cols = [c for c in self.features.columns if c not in self.num_cols]
35
+
36
+ def _save_plot(self, plot_type):
37
+ """内部辅助函数:生成文件名并保存"""
38
+ if self.save_dir and self.timestamp:
39
+ # Naming convention: Task_Type_Timestamp.png
40
+ # e.g. pca_famd_scatter_20251203_1430.png
41
+ filename = f"{self.task_name}_{plot_type}_{self.timestamp}.png"
42
+ full_path = os.path.join(self.save_dir, filename)
43
+ plt.savefig(full_path, dpi=300, bbox_inches='tight')
44
+ print(f"[Info] Plot saved: {full_path}")
45
+
46
+ def plot_embedding(self, X_embedded, title, plot_type_suffix="scatter"):
47
+ """通用绘图函数"""
48
+ plt.figure(figsize=(10, 8))
49
+ unique_targets = sorted(self.target.unique())
50
+ if set(unique_targets).issubset({0, 1}):
51
+ labels = self.target.map({0: 'No Stroke', 1: 'Stroke'})
52
+ else:
53
+ labels = self.target.astype(str)
54
+
55
+ sns.scatterplot(
56
+ x=X_embedded[:, 0],
57
+ y=X_embedded[:, 1],
58
+ hue=labels,
59
+ palette={'No Stroke': 'skyblue', 'Stroke': 'red'} if 'No Stroke' in labels.values else None,
60
+ alpha=0.7,
61
+ s=60
62
+ )
63
+ plt.title(title, fontsize=15)
64
+ plt.xlabel('Component 1')
65
+ plt.ylabel('Component 2')
66
+ plt.legend(title='Condition')
67
+ plt.grid(True, linestyle='--', alpha=0.5)
68
+
69
+ # 自动保存
70
+ self._save_plot(plot_type_suffix)
71
+ plt.show()
72
+
73
+ def perform_pca_famd(self):
74
+ """执行 FAMD 并保存结果"""
75
+ if FAMD is None: return
76
+
77
+ print("[Analysis] Performing FAMD...")
78
+ try:
79
+ famd = FAMD(n_components=2, n_iter=200, copy=True, check_input=True, engine='sklearn', random_state=42)
80
+ X_famd = famd.fit_transform(self.features)
81
+
82
+ # 1. 保存散点图
83
+ self.plot_embedding(X_famd.values if hasattr(X_famd, 'values') else X_famd,
84
+ "FAMD Visualization (Sample Distribution)",
85
+ plot_type_suffix="famd_scatter")
86
+
87
+ # 2. 保存载荷图 (Loadings)
88
+ print("[Analysis] Generating FAMD Loadings Plot...")
89
+ if hasattr(famd, 'column_coordinates_'):
90
+ coords = famd.column_coordinates_
91
+ elif hasattr(famd, 'column_correlations'):
92
+ coords = famd.column_correlations(self.features)
93
+ else:
94
+ return
95
+
96
+ comp0_loadings = coords.iloc[:, 0].sort_values(ascending=False)
97
+
98
+ plt.figure(figsize=(12, 6))
99
+ sns.barplot(x=comp0_loadings.values, y=comp0_loadings.index, palette="viridis")
100
+ plt.title("FAMD Loadings (Component 0) - Variable Contributions", fontsize=14)
101
+ plt.xlabel("Correlation with Component 0", fontsize=12)
102
+ plt.axvline(0, color='black', linestyle='--', linewidth=0.8)
103
+ plt.tight_layout()
104
+
105
+ # 保存载荷图
106
+ self._save_plot("famd_loadings")
107
+ plt.show()
108
+
109
+ except Exception as e:
110
+ print(f"[Error] FAMD analysis failed: {e}")
111
+
112
+ def perform_tsne(self):
113
+ """执行 t-SNE 并保存结果"""
114
+ print("[Analysis] Performing t-SNE...")
115
+ preprocessor = ColumnTransformer(
116
+ transformers=[
117
+ ('num', StandardScaler(), self.num_cols),
118
+ ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), self.cat_cols)
119
+ ])
120
+ X_processed = preprocessor.fit_transform(self.features)
121
+
122
+ try:
123
+ tsne = TSNE(n_components=2, perplexity=40, max_iter=1000, random_state=42, n_jobs=-1)
124
+ except TypeError:
125
+ tsne = TSNE(n_components=2, perplexity=40, n_iter=1000, random_state=42, n_jobs=-1)
126
+
127
+ X_tsne = tsne.fit_transform(X_processed)
128
+ # 保存 t-SNE 散点图
129
+ self.plot_embedding(X_tsne, "t-SNE Visualization", plot_type_suffix="tsne_scatter")
130
+ print("[Analysis] t-SNE completed.")
Analyze-stroke/source/environment.yml ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: stroke
2
+ channels:
3
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
4
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
5
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
6
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
7
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
8
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
9
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
10
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
11
+ - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
12
+ dependencies:
13
+ - _openmp_mutex=4.5=2_gnu
14
+ - anyio=4.12.0=pyhcf101f3_0
15
+ - backports=1.0=pyhd8ed1ab_5
16
+ - backports.tarfile=1.2.0=pyhd8ed1ab_1
17
+ - biopython=1.86=py310h29418f3_0
18
+ - blas=1.0=mkl
19
+ - brotli-python=1.2.0=py310hfff998d_1
20
+ - bzip2=1.0.8=h2bbff1b_6
21
+ - ca-certificates=2025.11.12=h4c7d964_0
22
+ - cachecontrol=0.14.3=pyha770c72_0
23
+ - cachecontrol-with-filecache=0.14.3=pyhd8ed1ab_0
24
+ - cairo=1.18.4=he9e932c_0
25
+ - causal-learn=0.1.4.3=pyhd8ed1ab_0
26
+ - certifi=2025.11.12=pyhd8ed1ab_0
27
+ - cffi=2.0.0=py310h29418f3_1
28
+ - charset-normalizer=3.4.4=pyhd8ed1ab_0
29
+ - clarabel=0.11.1=py310h1563f47_1
30
+ - cleo=2.1.0=pyhd8ed1ab_1
31
+ - colorama=0.4.6=pyhd8ed1ab_1
32
+ - contourpy=1.3.1=py310h214f63a_0
33
+ - crashtest=0.4.1=pyhd8ed1ab_1
34
+ - cuda-version=12.9=h4f385c5_3
35
+ - cvxpy=1.4.3=py310h5588dad_0
36
+ - cvxpy-base=1.4.3=py310hecd3228_0
37
+ - cycler=0.11.0=pyhd3eb1b0_0
38
+ - cython=3.2.2=py310h23e71ea_0
39
+ - distlib=0.4.0=pyhd8ed1ab_0
40
+ - dowhy=0.13=pyhd8ed1ab_1
41
+ - dulwich=0.24.10=py310hf13f778_1
42
+ - ecos=2.0.14=np2py310h04ddbaa_3
43
+ - et_xmlfile=2.0.0=py310haa95532_0
44
+ - exceptiongroup=1.3.1=pyhd8ed1ab_0
45
+ - expat=2.7.3=h9214b88_0
46
+ - filelock=3.20.0=pyhd8ed1ab_0
47
+ - findpython=0.7.1=pyh332efcf_0
48
+ - font-ttf-dejavu-sans-mono=2.37=0
49
+ - font-ttf-inconsolata=2.000=0
50
+ - font-ttf-source-code-pro=2.030=0
51
+ - font-ttf-ubuntu=0.83=0
52
+ - fontconfig=2.15.0=hd211d86_0
53
+ - fonts-conda-ecosystem=1=0
54
+ - fonts-conda-forge=1=hc364b38_1
55
+ - fonttools=4.60.1=py310h02ab6af_0
56
+ - freetype=2.14.1=h57928b3_0
57
+ - fribidi=1.0.16=hfd05255_0
58
+ - getopt-win32=0.1=h6a83c73_3
59
+ - graphite2=1.3.14=hd77b12b_1
60
+ - graphviz=9.0.0=h51cb2cd_1
61
+ - gts=0.7.6=h6b5321d_4
62
+ - h11=0.16.0=pyhd8ed1ab_0
63
+ - h2=4.3.0=pyhcf101f3_0
64
+ - harfbuzz=10.2.0=he2f9f60_1
65
+ - hpack=4.1.0=pyhd8ed1ab_0
66
+ - httpcore=1.0.9=pyh29332c3_0
67
+ - httpx=0.28.1=pyhd8ed1ab_0
68
+ - hyperframe=6.1.0=pyhd8ed1ab_0
69
+ - icc_rt=2022.1.0=h6049295_2
70
+ - icu=73.1=h6c2663c_0
71
+ - idna=3.11=pyhd8ed1ab_0
72
+ - imbalanced-learn=0.14.0=pyhd8ed1ab_0
73
+ - importlib-metadata=8.6.1=pyha770c72_0
74
+ - importlib_resources=6.5.2=pyhd8ed1ab_0
75
+ - intel-openmp=2025.0.0=haa95532_1164
76
+ - jaraco.classes=3.4.0=pyhd8ed1ab_2
77
+ - jaraco.context=6.0.1=pyhd8ed1ab_0
78
+ - jaraco.functools=4.3.0=pyhd8ed1ab_0
79
+ - jinja2=3.1.6=pyhcf101f3_1
80
+ - joblib=1.5.2=pyhd8ed1ab_0
81
+ - jpeg=9f=ha349fce_0
82
+ - keyring=25.7.0=pyh7428d3b_0
83
+ - kiwisolver=1.4.8=py310h5da7b33_0
84
+ - lcms2=2.16=hb4a4139_0
85
+ - lerc=3.0=hd77b12b_0
86
+ - libclang13=14.0.6=default_h77d9078_1
87
+ - libdeflate=1.17=h2bbff1b_1
88
+ - libexpat=2.7.3=hac47afa_0
89
+ - libffi=3.4.4=hd77b12b_1
90
+ - libfreetype=2.14.1=h57928b3_0
91
+ - libfreetype6=2.14.1=hdbac1cb_0
92
+ - libgd=2.3.3=ha43c60c_1
93
+ - libglib=2.84.4=hfaec014_0
94
+ - libgomp=15.2.0=h8ee18e1_14
95
+ - libiconv=1.16=h2bbff1b_3
96
+ - libkrb5=1.21.3=h885b0b7_4
97
+ - libosqp=1.0.0=np2py312h325a191_2
98
+ - libpng=1.6.50=h46444df_0
99
+ - libpq=17.6=h652a1e2_0
100
+ - libqdldl=0.1.8=he0c23c2_1
101
+ - libtiff=4.5.1=hd77b12b_0
102
+ - libwebp-base=1.3.2=h3d04722_1
103
+ - libwinpthread=12.0.0.r4.gg4f2fc60ca=h57928b3_10
104
+ - libxgboost=3.1.2=cuda129_h8216274_1
105
+ - libxml2=2.13.9=h6201b9f_0
106
+ - libxslt=1.1.43=h25c3957_0
107
+ - libzlib=1.3.1=h02ab6af_0
108
+ - llvmlite=0.45.1=py310hfe4b161_0
109
+ - lz4-c=1.9.4=h2bbff1b_1
110
+ - markupsafe=3.0.3=py310hdb0e946_0
111
+ - matplotlib=3.10.8=py310h5588dad_0
112
+ - matplotlib-base=3.10.8=py310h0bdd906_0
113
+ - minizip=4.0.6=hb638d1e_0
114
+ - mkl=2025.0.0=h5da7b33_930
115
+ - mkl-service=2.5.2=py310h0b37514_0
116
+ - mkl_fft=2.1.1=py310h300f80d_0
117
+ - mkl_random=1.3.0=py310ha5e6156_0
118
+ - more-itertools=10.8.0=pyhcf101f3_1
119
+ - mpmath=1.3.0=pyhd8ed1ab_1
120
+ - msgpack-python=1.1.2=py310he9f1925_1
121
+ - networkx=3.4.2=pyh267e887_2
122
+ - numba=0.62.1=py310h86ba7b5_0
123
+ - numpy=1.26.4=py310h12f7302_1
124
+ - numpy-base=1.26.4=py310he4e2855_1
125
+ - openjpeg=2.5.2=hae555c5_0
126
+ - openpyxl=3.1.5=py310h827c3e9_1
127
+ - openssl=3.0.18=h543e019_0
128
+ - osqp=1.0.5=np2py310h7578f05_2
129
+ - packaging=25.0=py310haa95532_1
130
+ - pandas=2.3.3=py310hed136d8_2
131
+ - pango=1.56.1=h286b592_0
132
+ - patsy=1.0.2=pyhcf101f3_0
133
+ - pbs-installer=2025.11.20=pyhd8ed1ab_0
134
+ - pcre2=10.46=h5740b90_0
135
+ - pillow=11.1.0=py310h096bfcc_0
136
+ - pip=25.3=pyhc872135_0
137
+ - pixman=0.46.4=h4043f72_0
138
+ - pkginfo=1.12.1.2=pyhd8ed1ab_0
139
+ - platformdirs=4.5.0=pyhcf101f3_0
140
+ - poetry=2.2.1=pyh7428d3b_0
141
+ - poetry-core=2.2.1=pyhd8ed1ab_0
142
+ - py-xgboost=3.1.2=cuda129_pyha75543c_1
143
+ - pycparser=2.22=pyh29332c3_1
144
+ - pydicom=3.0.1=py310haa95532_0
145
+ - pydot=1.4.2=py310h5588dad_4
146
+ - pyparsing=3.2.0=py310haa95532_0
147
+ - pyproject_hooks=1.2.0=pyhd8ed1ab_1
148
+ - pyside6=6.7.2=py310h5ef65bb_0
149
+ - pysocks=1.7.1=pyh09c184e_7
150
+ - python=3.10.19=h981015d_0
151
+ - python-build=1.3.0=pyhff2d567_0
152
+ - python-dateutil=2.9.0post0=py310haa95532_2
153
+ - python-fastjsonschema=2.21.2=pyhe01879c_0
154
+ - python-graphviz=0.21=pyhbacfb6d_0
155
+ - python-installer=0.7.0=pyhff2d567_1
156
+ - python-tzdata=2025.2=pyhd3eb1b0_0
157
+ - python_abi=3.10=2_cp310
158
+ - pytz=2025.2=py310haa95532_0
159
+ - pywin32-ctypes=0.2.3=py310h5588dad_3
160
+ - qdldl-python=0.1.7.post5=np2py310h7578f05_3
161
+ - qhull=2020.2=hc790b64_5
162
+ - qtbase=6.7.3=hd088775_4
163
+ - qtdeclarative=6.7.3=h885b0b7_1
164
+ - qtshadertools=6.7.3=h885b0b7_1
165
+ - qtsvg=6.7.3=h9d4b640_1
166
+ - qttools=6.7.3=hcb596f7_1
167
+ - qtwebchannel=6.7.3=h885b0b7_1
168
+ - qtwebengine=6.7.2=h601c93c_0
169
+ - qtwebsockets=6.7.3=h885b0b7_1
170
+ - rapidfuzz=3.14.3=py310h73ae2b4_1
171
+ - requests=2.32.5=pyhd8ed1ab_0
172
+ - requests-toolbelt=1.0.0=pyhd8ed1ab_1
173
+ - scikit-learn=1.7.2=py310h21054b0_0
174
+ - scipy=1.15.3=py310h1bbe36f_1
175
+ - scs=3.2.8=py310h8b33ddc_1
176
+ - seaborn=0.13.2=hd8ed1ab_3
177
+ - seaborn-base=0.13.2=pyhd8ed1ab_3
178
+ - setuptools=80.9.0=py310haa95532_0
179
+ - shellingham=1.5.4=pyhd8ed1ab_2
180
+ - six=1.17.0=py310haa95532_0
181
+ - sniffio=1.3.1=pyhd8ed1ab_2
182
+ - sqlite=3.51.0=hda9a48d_0
183
+ - statsmodels=0.14.5=py310h8f3aa81_1
184
+ - tbb=2022.0.0=h214f63a_0
185
+ - tbb-devel=2022.0.0=h214f63a_0
186
+ - threadpoolctl=3.6.0=pyhecae5ae_0
187
+ - tk=8.6.15=hf199647_0
188
+ - tomli=2.2.1=py310haa95532_0
189
+ - tomlkit=0.13.3=pyha770c72_0
190
+ - tornado=6.5.1=py310h827c3e9_0
191
+ - tqdm=4.67.1=pyhd8ed1ab_1
192
+ - trove-classifiers=2025.12.1.14=pyhd8ed1ab_0
193
+ - typing_extensions=4.15.0=pyhcf101f3_0
194
+ - tzdata=2025b=h04d1e81_0
195
+ - ucrt=10.0.22621.0=haa95532_0
196
+ - urllib3=2.5.0=pyhd8ed1ab_0
197
+ - vc=14.42=haa95532_5
198
+ - vc14_runtime=14.44.35208=h4927774_10
199
+ - virtualenv=20.35.4=pyhd8ed1ab_0
200
+ - vs2015_runtime=14.44.35208=ha6b5a95_10
201
+ - wheel=0.45.1=py310haa95532_0
202
+ - win_inet_pton=1.1.0=pyh7428d3b_8
203
+ - xgboost=3.1.2=cuda129_pyh5c66634_1
204
+ - xz=5.6.4=h4754444_1
205
+ - zipp=3.23.0=pyhcf101f3_1
206
+ - zlib=1.3.1=h02ab6af_0
207
+ - zstandard=0.25.0=py310h1637853_1
208
+ - zstd=1.5.7=h56299aa_0
209
+ - pip:
210
+ - altair==5.5.0
211
+ - attrs==25.4.0
212
+ - fsspec==2025.9.0
213
+ - jsonschema==4.25.1
214
+ - jsonschema-specifications==2025.9.1
215
+ - narwhals==2.13.0
216
+ - prince==0.16.2
217
+ - referencing==0.37.0
218
+ - rpds-py==0.30.0
219
+ - sympy==1.13.1
220
+ - torch==2.6.0+cu124
221
+ - torchaudio==2.6.0+cu124
222
+ - torchvision==0.21.0+cu124
223
+ prefix: C:\ProgramData\anaconda3\envs\stroke
Analyze-stroke/source/feature_selection.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/feature_selection.py
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import os
7
+
8
+ # Import the chi-square test and normalization utilities
9
+ from sklearn.feature_selection import SelectKBest, mutual_info_classif, RFECV, chi2
10
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
11
+ from sklearn.ensemble import RandomForestClassifier
12
+ from sklearn.compose import ColumnTransformer
13
+ from sklearn.model_selection import StratifiedKFold
14
+
15
+ class FeatureSelectionAnalyzer:
16
+ def __init__(self, data, target_col='stroke', save_dir=None, timestamp=None):
17
+ self.data = data
18
+ self.target = data[target_col]
19
+ self.features = data.drop(columns=[target_col])
20
+
21
+ self.save_dir = save_dir
22
+ self.timestamp = timestamp
23
+
24
+ self.num_cols = ['age', 'avg_glucose_level', 'bmi']
25
+ self.num_cols = [c for c in self.num_cols if c in self.features.columns]
26
+ self.cat_cols = [c for c in self.features.columns if c not in self.num_cols]
27
+
28
+ # Shared preprocessing pipeline (used for mutual information and RFECV)
29
+ self.preprocessor = ColumnTransformer(
30
+ transformers=[
31
+ ('num', StandardScaler(), self.num_cols),
32
+ ('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
33
+ ],
34
+ verbose_feature_names_out=False
35
+ )
36
+
37
+ def _save_plot(self, method_name, plot_name):
38
+ if self.save_dir and self.timestamp:
39
+ filename = f"feature_selection_{method_name}_{plot_name}_{self.timestamp}.png"
40
+ full_path = os.path.join(self.save_dir, filename)
41
+ plt.savefig(full_path, dpi=300, bbox_inches='tight')
42
+ print(f"[Info] Plot saved: {full_path}")
43
+
44
+ def run_select_kbest(self):
45
+ """策略 1: 互信息 (Mutual Information)"""
46
+ print("[Analysis] Running SelectKBest (Mutual Information)...")
47
+ X_processed = self.preprocessor.fit_transform(self.features)
48
+ feature_names = self.preprocessor.get_feature_names_out()
49
+
50
+ selector = SelectKBest(score_func=mutual_info_classif, k='all')
51
+ selector.fit(X_processed, self.target)
52
+
53
+ df_scores = pd.DataFrame({'Feature': feature_names, 'Score': selector.scores_})
54
+ df_scores = df_scores.sort_values(by='Score', ascending=False)
55
+
56
+ plt.figure(figsize=(12, 8))
57
+ sns.barplot(x='Score', y='Feature', data=df_scores, palette='viridis')
58
+ plt.title('Feature Importance via Mutual Information')
59
+ plt.xlabel('Mutual Information Score')
60
+ plt.tight_layout()
61
+ self._save_plot("kbest_mutual_info", "scores")
62
+ plt.show()
63
+
64
+ def run_rfecv(self):
65
+ """策略 2: 递归特征消除 (RFECV)"""
66
+ print("[Analysis] Running RFECV...")
67
+ X_processed = self.preprocessor.fit_transform(self.features)
68
+ feature_names = self.preprocessor.get_feature_names_out()
69
+
70
+ clf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42, n_jobs=-1)
71
+ rfecv = RFECV(estimator=clf, step=1, cv=StratifiedKFold(5), scoring='f1')
72
+ rfecv.fit(X_processed, self.target)
73
+
74
+ print(f"[Result] Optimal features: {rfecv.n_features_}")
75
+
76
+ n_features = range(1, len(rfecv.cv_results_['mean_test_score']) + 1)
77
+ scores = rfecv.cv_results_['mean_test_score']
78
+
79
+ plt.figure(figsize=(10, 6))
80
+ plt.plot(n_features, scores, marker='o', color='b')
81
+ plt.xlabel("Number of Features")
82
+ plt.ylabel("CV F1 Score")
83
+ plt.title(f"RFECV Performance (Optimal: {rfecv.n_features_})")
84
+ plt.axvline(x=rfecv.n_features_, color='r', linestyle='--')
85
+ plt.grid(True, linestyle='--', alpha=0.5)
86
+ plt.tight_layout()
87
+ self._save_plot("rfecv", "performance_curve")
88
+ plt.show()
89
+
90
+ def run_chi2(self):
91
+ """
92
+ Strategy 3: Chi-square test (newly added).
93
+ Note: chi2 requires non-negative input, so the numeric variables are scaled with MinMaxScaler.
94
+ """
95
+ print("[Analysis] Running Chi-Square Test...")
96
+
97
+ # 1. Chi-square-specific preprocessor (MinMaxScaler keeps numeric features non-negative)
98
+ chi2_preprocessor = ColumnTransformer(
99
+ transformers=[
100
+ ('num', MinMaxScaler(), self.num_cols), # key change: non-negative scaling
101
+ ('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
102
+ ],
103
+ verbose_feature_names_out=False
104
+ )
105
+
106
+ # 2. 转换数据
107
+ X_processed = chi2_preprocessor.fit_transform(self.features)
108
+ feature_names = chi2_preprocessor.get_feature_names_out()
109
+
110
+ # 3. Compute the chi-square statistics and p-values
111
+ # chi2 returns two arrays: (chi2_scores, p_values)
112
+ chi2_scores, p_values = chi2(X_processed, self.target)
113
+
114
+ # 4. 整理结果
115
+ df_results = pd.DataFrame({
116
+ 'Feature': feature_names,
117
+ 'Chi2_Score': chi2_scores,
118
+ 'P_Value': p_values
119
+ })
120
+
121
+ # Sort by score (higher score = stronger dependence)
122
+ df_results = df_results.sort_values(by='Chi2_Score', ascending=False)
123
+
124
+ # 打印前10个显著特征
125
+ print("\n[Result] Top 10 Features by Chi-Square Score:")
126
+ print(df_results[['Feature', 'Chi2_Score', 'P_Value']].head(10))
127
+
128
+ # 5. 可视化 (Chi2 Score)
129
+ plt.figure(figsize=(12, 8))
130
+ sns.barplot(x='Chi2_Score', y='Feature', data=df_results, palette='magma')
131
+ plt.title('Feature Importance via Chi-Square Test')
132
+ plt.xlabel('Chi-Square Statistics (Higher means more dependent)')
133
+ plt.grid(axis='x', linestyle='--', alpha=0.5)
134
+ plt.tight_layout()
135
+
136
+ self._save_plot("chi2", "scores")
137
+ plt.show()
138
+
139
+ # Extra note on p-values
140
+ print("\n[Note] P-Value < 0.05 usually indicates a statistically significant association.")
Analyze-stroke/source/healthcare-dataset-stroke-data.csv ADDED
The diff for this file is too large to render. See raw diff
 
Analyze-stroke/source/main.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/main.py
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import datetime
6
+ import warnings
7
+
8
+ warnings.filterwarnings("ignore", category=FutureWarning)
9
+ warnings.filterwarnings("ignore", category=UserWarning)
10
+
11
+ from data_loader import DataLoader
12
+ from dim_reduction import DimensionAnalyzer
13
+ from models import ModelManager
14
+ from causal_module import CausalAnalyzer
15
+ from feature_selection import FeatureSelectionAnalyzer
16
+
17
+ class DualLogger:
18
+ def __init__(self, filepath):
19
+ self.terminal = sys.stdout
20
+ self.log = open(filepath, "a", encoding='utf-8')
21
+ def write(self, message):
22
+ self.terminal.write(message)
23
+ self.log.write(message)
24
+ def flush(self):
25
+ self.terminal.flush()
26
+ self.log.flush()
27
+
28
+ def main():
29
+ parser = argparse.ArgumentParser(description="Stroke Analysis Toolkit")
30
+
31
+ parser.add_argument('--task', type=str, required=True, choices=['pca', 'tsne', 'prediction', 'causal', 'feature_selection'])
32
+ parser.add_argument('--model', type=str, default='xgboost', choices=['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'])
33
+ parser.add_argument('--method', type=str, default='kbest', choices=['kbest', 'rfecv', 'chi2'])
34
+ parser.add_argument('--treatment', type=str, default='gender')
35
+ parser.add_argument('--file', type=str, default=r'D:\workspace\stroke\data\healthcare-dataset-stroke-data.csv')
36
+ parser.add_argument('--session_id', type=str, default=None)
37
+
38
+ args = parser.parse_args()
39
+
40
+ LOG_ROOT = r"D:\workspace\stroke\Analyze\logs"
41
+ if args.session_id:
42
+ session_dir_name = args.session_id
43
+ else:
44
+ session_dir_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
45
+
46
+ CURRENT_LOG_DIR = os.path.join(LOG_ROOT, session_dir_name)
47
+ os.makedirs(CURRENT_LOG_DIR, exist_ok=True)
48
+ timestamp_file = datetime.datetime.now().strftime("%H%M%S_%f")
49
+ log_filename = f"{args.task}_{args.model}_{args.method}_{args.treatment}_{timestamp_file}.log"
50
+ log_path = os.path.join(CURRENT_LOG_DIR, log_filename)
51
+
52
+ sys.stdout = DualLogger(log_path)
53
+
54
+ print(f"========================================================")
55
+ print(f"[Log] Session Folder: {session_dir_name}")
56
+ print(f"[Log] Output saved to: {log_path}")
57
+ print(f"========================================================\n")
58
+
59
+ if not os.path.exists(args.file):
60
+ print(f"Error: File {args.file} not found!")
61
+ return
62
+
63
+ VIS_DIR = r"D:\workspace\stroke\Analyze\visualization"
64
+ os.makedirs(VIS_DIR, exist_ok=True)
65
+
66
+ loader = DataLoader(args.file)
67
+ df = loader.load_and_clean()
68
+
69
+ if args.task in ['pca', 'tsne']:
70
+ analyzer = DimensionAnalyzer(df, save_dir=VIS_DIR, timestamp=session_dir_name, task_name=args.task)
71
+ if args.task == 'pca': analyzer.perform_pca_famd()
72
+ elif args.task == 'tsne': analyzer.perform_tsne()
73
+
74
+ elif args.task == 'prediction':
75
+ X_train, X_test, y_train, y_test = loader.get_split_data()
76
+ manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=VIS_DIR, timestamp=session_dir_name)
77
+ if args.model == 'dnn': manager.run_dnn()
78
+ else: manager.run_sklearn_model(args.model)
79
+
80
+ elif args.task == 'causal':
81
+ # === Smart domain encoding (apply domain knowledge before causal inference) ===
82
+ print("[Info] Pre-processing data for Causal Inference (Applying Domain Knowledge)...")
83
+
84
+ # 1. Gender: Male=1, Female=0
85
+ if 'gender' in df.columns:
86
+ df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
87
+ print(" - Encoded 'gender': Male=1, Female/Other=0")
88
+
89
+ # 2. Ever married: Yes=1, No=0
90
+ if 'ever_married' in df.columns:
91
+ df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
92
+ print(" - Encoded 'ever_married': Yes=1, No=0")
93
+
94
+ # 3. Residence: Urban=1, Rural=0
95
+ if 'Residence_type' in df.columns:
96
+ df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
97
+ print(" - Encoded 'Residence_type': Urban=1, Rural=0")
98
+
99
+ # 4. Smoking: ordinal risk encoding
100
+ # Fixes the earlier alphabetical-order encoding; assumed risk order: smokes > formerly > never/unknown
101
+ if 'smoking_status' in df.columns:
102
+ # One option is to treat 'smokes' and 'formerly' as at-risk (1) and everything else as 0;
103
+ # binarizing would make the linear-regression ATE cleaner,
104
+ # but to preserve detail we keep the 0/1/2 risk ladder:
105
+ smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
106
+ df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
107
+ print(" - Encoded 'smoking_status' (Ordinal): Never/Unknown=0, Formerly=1, Smokes=2")
108
+
109
+ # 5. Work type: a nominal variable cannot go straight into a linear regression.
110
+ # Strategy: contrast "high-pressure / private sector" against "other".
111
+ if 'work_type' in df.columns:
112
+ # Map Private and Self-employed to 1, everything else to 0
113
+ work_map = {
114
+ 'Private': 1, 'Self-employed': 1,
115
+ 'Govt_job': 0, 'children': 0, 'Never_worked': 0
116
+ }
117
+ df['work_type'] = df['work_type'].map(work_map).fillna(0)
118
+ print(" - Encoded 'work_type' (Binary): Private/Self-employed=1, Others=0")
119
+
120
+ analyzer = CausalAnalyzer(df)
121
+ analyzer.run_analysis(treatment_col=args.treatment)
122
+
123
+ elif args.task == 'feature_selection':
124
+ analyzer = FeatureSelectionAnalyzer(df, save_dir=VIS_DIR, timestamp=session_dir_name)
125
+ if args.method == 'kbest': analyzer.run_select_kbest()
126
+ elif args.method == 'rfecv': analyzer.run_rfecv()
127
+ elif args.method == 'chi2': analyzer.run_chi2()
128
+
129
+ if __name__ == "__main__":
130
+ main()
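Example invocations (sketch; this mirrors how run_all_causal.py drives the script, and the hardcoded log/visualization paths above still apply):

import subprocess

# Each call writes its log under LOG_ROOT and its plots under VIS_DIR
subprocess.run(["python", "main.py", "--task", "tsne"], check=True)
subprocess.run(["python", "main.py", "--task", "prediction", "--model", "xgboost"], check=True)
subprocess.run(["python", "main.py", "--task", "causal", "--treatment", "hypertension"], check=True)
subprocess.run(["python", "main.py", "--task", "feature_selection", "--method", "chi2"], check=True)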
Analyze-stroke/source/models.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/models.py
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import os
7
+
8
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
9
+ from sklearn.compose import ColumnTransformer
10
+ from sklearn.naive_bayes import GaussianNB
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.svm import SVC
13
+ from sklearn.neighbors import KNeighborsClassifier
14
+ from sklearn.tree import DecisionTreeClassifier
15
+ from sklearn.ensemble import RandomForestClassifier
16
+ from xgboost import XGBClassifier
17
+ from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, confusion_matrix
18
+
19
+ from imblearn.over_sampling import SMOTE
20
+ from imblearn.pipeline import Pipeline as ImbPipeline
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.optim as optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ class ModelManager:
28
+ def __init__(self, X_train, y_train, X_test, y_test, save_dir=None, timestamp=None):
29
+ self.X_train = X_train
30
+ self.y_train = y_train
31
+ self.X_test = X_test
32
+ self.y_test = y_test
33
+
34
+ # 保存配置
35
+ self.save_dir = save_dir
36
+ self.timestamp = timestamp
37
+
38
+ self.num_cols = ['age', 'avg_glucose_level', 'bmi']
39
+ self.cat_cols = [c for c in X_train.columns if c not in self.num_cols]
40
+
41
+ def _save_plot(self, model_name, content_type):
42
+ """内部辅助函数:保存预测相关图片"""
43
+ if self.save_dir and self.timestamp:
44
+ # Naming convention: prediction_{Model}_{ContentType}_{Timestamp}.png
45
+ # e.g. prediction_xgboost_feature_importance_20251203_1430.png
46
+ filename = f"prediction_{model_name}_{content_type}_{self.timestamp}.png"
47
+ full_path = os.path.join(self.save_dir, filename)
48
+ plt.savefig(full_path, dpi=300, bbox_inches='tight')
49
+ print(f"[Info] Plot saved: {full_path}")
50
+
51
+ def _get_pipeline(self, classifier):
52
+ preprocessor = ColumnTransformer(
53
+ transformers=[
54
+ ('num', StandardScaler(), self.num_cols),
55
+ ('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
56
+ ],
57
+ verbose_feature_names_out=False
58
+ )
59
+ pipeline = ImbPipeline(steps=[
60
+ ('preprocessor', preprocessor),
61
+ ('smote', SMOTE(random_state=42)),
62
+ ('classifier', classifier)
63
+ ])
64
+ return pipeline
65
+
66
+ def evaluate(self, y_true, y_pred, y_prob=None, model_name="Model"):
67
+ # ... (original evaluate logic kept unchanged)
68
+ acc = accuracy_score(y_true, y_pred)
69
+ sens = recall_score(y_true, y_pred)
70
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
71
+ spec = tn / (tn + fp) if (tn + fp) > 0 else 0
72
+ f1 = f1_score(y_true, y_pred)
73
+ auc = roc_auc_score(y_true, y_prob) if y_prob is not None else "N/A"
74
+
75
+ print(f"\n--- Results for {model_name} ---")
76
+ print(f"Accuracy: {acc:.4f}")
77
+ print(f"Sensitivity: {sens:.4f}")
78
+ print(f"Specificity: {spec:.4f}")
79
+ print(f"F1 Score: {f1:.4f}")
80
+ print(f"ROC-AUC: {auc}")
81
+ return {'acc': acc, 'sens': sens, 'spec': spec, 'f1': f1, 'auc': auc}
82
+
83
+ def run_sklearn_model(self, model_type):
84
+ models_map = {
85
+ 'nb': GaussianNB(),
86
+ 'lr': LogisticRegression(max_iter=2000, class_weight='balanced'),
87
+ 'svm': SVC(probability=True, class_weight='balanced'),
88
+ 'knn': KNeighborsClassifier(),
89
+ 'dt': DecisionTreeClassifier(class_weight='balanced'),
90
+ 'rf': RandomForestClassifier(class_weight='balanced', random_state=42),
91
+ 'xgboost': XGBClassifier(scale_pos_weight=19, eval_metric='logloss', random_state=42)
92
+ }
93
+
94
+ if model_type not in models_map:
95
+ raise ValueError(f"Unknown model type: {model_type}")
96
+
97
+ print(f"[Training] Running {model_type}...")
98
+ clf = models_map[model_type]
99
+ pipeline = self._get_pipeline(clf)
100
+ pipeline.fit(self.X_train, self.y_train)
101
+
102
+ # === 1. Logistic regression coefficient visualization ===
103
+ if model_type == 'lr':
104
+ print("\n[Analysis] Extracting Logistic Regression Coefficients...")
105
+ try:
106
+ model = pipeline.named_steps['classifier']
107
+ feature_names = pipeline.named_steps['preprocessor'].get_feature_names_out()
108
+ coefs = model.coef_[0]
109
+ coef_df = pd.DataFrame({'Feature': feature_names, 'Coefficient': coefs})
110
+ coef_df['Abs_Coef'] = coef_df['Coefficient'].abs()
111
+ coef_df = coef_df.sort_values(by='Abs_Coef', ascending=False).head(15)
112
+
113
+ plt.figure(figsize=(10, 8))
114
+ colors = ['red' if x > 0 else 'skyblue' for x in coef_df['Coefficient']]
115
+ sns.barplot(x='Coefficient', y='Feature', data=coef_df, palette=colors)
116
+ plt.title('Top 15 Factors (LR Coefficients)')
117
+ plt.xlabel('Coefficient Value')
118
+ plt.axvline(0, color='black', linestyle='--', linewidth=0.8)
119
+ plt.tight_layout()
120
+
121
+ # 保存
122
+ self._save_plot(model_type, "coefficients")
123
+ plt.show()
124
+ except Exception as e: print(e)
125
+
126
+ # === 2. Feature importance (XGBoost/RF/DT) ===
127
+ elif model_type in ['xgboost', 'rf', 'dt']:
128
+ try:
129
+ print(f"\n[Analysis] Extracting Feature Importance for {model_type}...")
130
+ model = pipeline.named_steps['classifier']
131
+ feature_names = pipeline.named_steps['preprocessor'].get_feature_names_out()
132
+ importances = model.feature_importances_
133
+
134
+ feat_imp = pd.DataFrame({'Feature': feature_names, 'Importance': importances})
135
+ feat_imp = feat_imp.sort_values(by='Importance', ascending=False).head(10)
136
+
137
+ plt.figure(figsize=(10, 6))
138
+ sns.barplot(x='Importance', y='Feature', data=feat_imp, palette='viridis')
139
+ plt.title(f'Top 10 Features - {model_type}')
140
+ plt.tight_layout()
141
+
142
+ # 保存
143
+ self._save_plot(model_type, "feature_importance")
144
+ plt.show()
145
+ except Exception as e: print(e)
146
+
147
+ # 预测与评估
148
+ y_pred = pipeline.predict(self.X_test)
149
+ y_prob = None
150
+ if hasattr(pipeline, "predict_proba"):
151
+ y_prob = pipeline.predict_proba(self.X_test)[:, 1]
152
+ self.evaluate(self.y_test, y_pred, y_prob, model_name=model_type)
153
+
154
+ def run_dnn(self):
155
+ # The DNN code is unchanged from the earlier version; a loss-curve plot/save step could be added here.
156
+ # Only the core structure is kept for brevity.
157
+ print("[Training] Running DNN with Entity Embeddings...")
158
+ # ... (PyTorch data handling and model definition as before) ...
159
+ # To plot inside the DNN path, call self._save_plot("dnn", "loss_curve") as well.
160
+ pass
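The point of building the ImbPipeline above is that SMOTE only resamples the training data inside fit, never the validation folds; a condensed sketch of the same pattern on synthetic data:

import numpy as np
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(42)
X = rng.normal(size=(400, 4))
y = (X[:, 0] + rng.normal(scale=0.5, size=400) > 1.5).astype(int)  # imbalanced target (~10% positives)

pipe = ImbPipeline(steps=[
    ("scaler", StandardScaler()),          # preprocessing first
    ("smote", SMOTE(random_state=42)),     # oversampling is applied only when fitting each training fold
    ("clf", LogisticRegression(max_iter=2000)),
])
print(cross_val_score(pipe, X, y, cv=5, scoring="f1").mean())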
Analyze-stroke/source/plot_utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/plot_utils.py
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import os
6
+ import argparse
7
+
8
+ # === 样式配置 ===
9
+ COLORS = {
10
+ 'ATE': "#0AB3F0", # 亮蓝色 (文字)
11
+ 'ATE_Line': "#0AB3F0CA", # 带透明度 (线和填充)
12
+ 'New Effect': "#F61467", # 亮粉红 (文字)
13
+ 'New Effect_Line': "#F61467EE", # 带透明度 (线)
14
+ 'P-Value': '#2A9D8F', # 青绿色
15
+ 'Threshold': '#F4A261', # 橙色
16
+ 'ZeroLine': '#FF0000', # 鲜红色
17
+ 'Grid': '#E0E0E0', # 浅灰色 (内部网格)
18
+ 'Spine': '#555555' # 深灰色 (最外圈边界)
19
+ }
20
+
21
+ def add_offset_labels(ax, angles, radii, labels, color, position='right'):
22
+ """
23
+ Helper: add value labels next to data points, offset to the left or right.
24
+ """
25
+ if position == 'left':
26
+ xytext = (-8, 0)
27
+ ha = 'right'
28
+ else:
29
+ xytext = (8, 0)
30
+ ha = 'left'
31
+
32
+ for angle, radius, label in zip(angles, radii, labels):
33
+ ax.annotate(
34
+ str(label),
35
+ xy=(angle, radius),
36
+ xytext=xytext,
37
+ textcoords='offset points',
38
+ color=color,
39
+ size=9,
40
+ weight='bold',
41
+ ha=ha,
42
+ va='center',
43
+ bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="none", alpha=0.7)
44
+ )
45
+
46
+ def draw_effect_chart(df, save_path):
47
+ """
48
+ Draw the effect-size radar chart (ATE & refutation New Effect).
49
+ """
50
+ plot_df = df.copy()
51
+ categories = plot_df['Factor'].tolist()
52
+ N = len(categories)
53
+
54
+ angles = [n / float(N) * 2 * np.pi for n in range(N)]
55
+ angles += angles[:1]
56
+
57
+ fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
58
+
59
+ # === 0. 优化网格线 & 边界线 ===
60
+
61
+ # A. 内部网格:保持浅灰色虚线
62
+ ax.grid(True, color=COLORS['Grid'], linestyle='--', linewidth=1, alpha=0.8)
63
+
64
+ # B. 最外圈边界 (Spine):修改为深灰色实线
65
+ ax.spines['polar'].set_visible(True) # 确保显示
66
+ ax.spines['polar'].set_color(COLORS['Spine']) # 深灰色
67
+ ax.spines['polar'].set_linestyle('solid') # 实线
68
+ ax.spines['polar'].set_linewidth(1.5) # 稍微加粗
69
+
70
+ # 1. 设置标签
71
+ ax.set_xticks(angles[:-1])
72
+ ax.set_xticklabels(categories, color='black', size=13, weight='heavy')
73
+
74
+ # 2. 设置刻度
75
+ ax.set_yticklabels([])
76
+ ax.set_ylim(0, 1)
77
+
78
+ # 3. 数据归一化
79
+ max_ate = plot_df['ATE'].abs().max()
80
+ max_ne = plot_df['New Effect'].abs().max()
81
+ global_max = max(max_ate, max_ne)
82
+ if global_max == 0: global_max = 1e-9
83
+
84
+ r_ate = (0.5 + 0.5 * (plot_df['ATE'] / global_max)).tolist()
85
+ r_ne = (0.5 + 0.5 * (plot_df['New Effect'] / global_max)).tolist()
86
+ r_ate += r_ate[:1]
87
+ r_ne += r_ne[:1]
88
+
89
+ # 4. 绘图
90
+
91
+ # 鲜红色的 Zero Effect 线 (虚线)
92
+ circle_points = np.linspace(0, 2*np.pi, 100)
93
+ ax.plot(circle_points, [0.5]*100, color=COLORS['ZeroLine'], linestyle='--', linewidth=2, label='Zero Effect (0)', zorder=2)
94
+
95
+ # ATE
96
+ ax.plot(angles, r_ate, linewidth=3, color=COLORS['ATE_Line'], label='ATE (Causal Estimate)')
97
+ ax.fill(angles, r_ate, color=COLORS['ATE_Line'], alpha=0.1)
98
+ ax.scatter(angles[:-1], r_ate[:-1], color=COLORS['ATE'], s=80, zorder=5)
99
+
100
+ # New Effect
101
+ ax.plot(angles, r_ne, linewidth=2, linestyle=':', color=COLORS['New Effect_Line'], label='Refutation')
102
+ ax.scatter(angles[:-1], r_ne[:-1], facecolor='none', edgecolor=COLORS['New Effect'], s=80, marker='D', linewidth=2, zorder=6)
103
+
104
+ # 5. 数值标注 (分离)
105
+ labels_ate = plot_df['ATE'].round(4).tolist()
106
+ labels_ne = plot_df['New Effect'].round(4).tolist()
107
+
108
+ add_offset_labels(ax, angles[:-1], r_ate[:-1], labels_ate, color=COLORS['ATE'], position='left')
109
+ add_offset_labels(ax, angles[:-1], r_ne[:-1], labels_ne, color=COLORS['New Effect'], position='right')
110
+
111
+ # 6. 图饰
112
+ plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
113
+ plt.title(f"Causal Effect Size (Split Labels)\nMax Scale = +/- {global_max:.4f}", size=16, weight='bold', y=1.1)
114
+
115
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
116
+ print(f"[PlotUtils] Effect Size chart saved to: {save_path}")
117
+ plt.close()
118
+
119
+ def draw_pvalue_chart(df, save_path):
120
+ """
121
+ Draw the significance radar chart (p-values).
122
+ """
123
+ plot_df = df.copy()
124
+ categories = plot_df['Factor'].tolist()
125
+ N = len(categories)
126
+ angles = [n / float(N) * 2 * np.pi for n in range(N)]
127
+ angles += angles[:1]
128
+
129
+ fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
130
+
131
+ # === 优化网格线 & 边界线 ===
132
+ ax.grid(True, color=COLORS['Grid'], linestyle='--', linewidth=1, alpha=0.8)
133
+
134
+ ax.spines['polar'].set_visible(True)
135
+ ax.spines['polar'].set_color(COLORS['Spine'])
136
+ ax.spines['polar'].set_linestyle('solid')
137
+ ax.spines['polar'].set_linewidth(1.5)
138
+
139
+ # 1. 标签
140
+ ax.set_xticks(angles[:-1])
141
+ ax.set_xticklabels(categories, color='black', size=13, weight='heavy')
142
+
143
+ # 2. 刻度
144
+ ax.set_rlabel_position(0)
145
+ plt.yticks([0.05, 0.5, 1.0], ["P=0.05", "0.5", "1.0"], color="gray", size=10)
146
+ ax.set_ylim(0, 1)
147
+
148
+ # 3. 数据
149
+ r_pv = plot_df['P-Value'].tolist()
150
+ r_pv += r_pv[:1]
151
+
152
+ # 4. 绘图
153
+ circle_points = np.linspace(0, 2*np.pi, 100)
154
+ ax.plot(circle_points, [0.05]*100, color=COLORS['Threshold'], linestyle='-', linewidth=2, label='Threshold (P=0.05)')
155
+
156
+ ax.plot(angles, r_pv, linewidth=2, color=COLORS['P-Value'], label='P-Value (Raw)')
157
+ ax.fill(angles, r_pv, color=COLORS['P-Value'], alpha=0.1)
158
+
159
+ # 标记显著点
160
+ sig_indices = [i for i, p in enumerate(plot_df['P-Value']) if p <= 0.05]
161
+ if sig_indices:
162
+ sig_angles = [angles[i] for i in sig_indices]
163
+ sig_radii = [r_pv[i] for i in sig_indices]
164
+ ax.scatter(sig_angles, sig_radii, color='red', s=100, marker='*', zorder=10, label='Significant (P<=0.05)')
165
+
166
+ # 5. 数值标注
167
+ def add_simple_labels(ax, angles, radii, labels):
168
+ for angle, radius, label in zip(angles, radii, labels):
169
+ ax.text(angle, radius, str(label), color='#333333', size=9, weight='bold',
170
+ ha='center', va='center', bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.8))
171
+
172
+ labels_pv = plot_df['P-Value'].round(4).tolist()
173
+ add_simple_labels(ax, angles[:-1], r_pv[:-1], labels_pv)
174
+
175
+ # 6. 图饰
176
+ plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
177
+ title_text = "Statistical Significance (Raw P-Values)\n(Center = More Significant)"
178
+ plt.title(title_text, size=16, weight='bold', y=1.1)
179
+
180
+ note = "Note: Points inside the orange circle (P<0.05)\nare statistically significant."
181
+ plt.text(1.3, 0, note, transform=ax.transAxes, fontsize=10,
182
+ bbox=dict(boxstyle="round", facecolor='white', alpha=0.8))
183
+
184
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
185
+ print(f"[PlotUtils] P-Value chart saved to: {save_path}")
186
+ plt.close()
187
+
188
+ def plot_from_excel(excel_path, output_dir=None):
189
+ if not os.path.exists(excel_path):
190
+ print(f"[Error] 文件不存在: {excel_path}")
191
+ return
192
+
193
+ print(f"[PlotUtils] Reading data from: {excel_path}")
194
+ try:
195
+ df = pd.read_excel(excel_path)
196
+
197
+ if output_dir is None:
198
+ output_dir = os.path.dirname(excel_path)
199
+
200
+ base_name = os.path.splitext(os.path.basename(excel_path))[0]
201
+
202
+ path_effect = os.path.join(output_dir, base_name + "_Effect_Labeled.png")
203
+ draw_effect_chart(df, path_effect)
204
+
205
+ path_pval = os.path.join(output_dir, base_name + "_PValue_Labeled.png")
206
+ draw_pvalue_chart(df, path_pval)
207
+
208
+ except Exception as e:
209
+ print(f"[Error] 绘图失败: {e}")
210
+ import traceback
211
+ traceback.print_exc()
212
+
213
+ if __name__ == "__main__":
214
+ parser = argparse.ArgumentParser()
215
+ parser.add_argument('--file', type=str, required=True, help="Path to excel file")
216
+ args = parser.parse_args()
217
+ plot_from_excel(args.file)
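The radial mapping used in draw_effect_chart places zero effect on the dashed middle ring: r = 0.5 + 0.5 * (value / global_max). A tiny worked check with made-up ATEs:

values = [0.04, -0.02, 0.0]                       # hypothetical ATEs
global_max = max(abs(v) for v in values) or 1e-9  # same guard as in the function
radii = [0.5 + 0.5 * (v / global_max) for v in values]
print(radii)  # [1.0, 0.25, 0.5] -> largest effect on the rim, zero effect on the middle circle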
Analyze-stroke/source/run_all_causal.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analyze/run_all_causal.py
2
+ import os
3
+ import time
4
+ import datetime
5
+ import subprocess
6
+ import re
7
+ import pandas as pd
8
+ import numpy as np
9
+ # 引入刚才写的绘图模块
10
+ from plot_utils import plot_from_excel
11
+
12
+ # === 配置区域 ===
13
+ FACTORS = [
14
+ 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
15
+ 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'
16
+ ]
17
+
18
+ def parse_log_output(output_text):
19
+ """从日志文本中提取指标"""
20
+ metrics = {'ATE': np.nan, 'New Effect': np.nan, 'P-Value': np.nan}
21
+
22
+ # Regex-extract the floats (handles negatives and multi-digit decimals)
23
+ ate_match = re.search(r"\[Result\] Causal Estimate \(ATE\): (-?\d+\.\d+)", output_text)
24
+ if ate_match: metrics['ATE'] = float(ate_match.group(1))
25
+
26
+ ne_match = re.search(r"Refutation New Effect: (-?\d+\.\d+)", output_text)
27
+ if ne_match: metrics['New Effect'] = float(ne_match.group(1))
28
+
29
+ pv_match = re.search(r"Refutation p-value: (-?\d+\.\d+)", output_text)
30
+ if pv_match: metrics['P-Value'] = float(pv_match.group(1))
31
+
32
+ return metrics
33
+
34
+ def run_batch_analysis():
35
+ # 1. 创建会话文件夹
36
+ session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal_Summary")
37
+ LOG_ROOT = r"D:\workspace\stroke\Analyze\logs"
38
+ SESSION_DIR = os.path.join(LOG_ROOT, session_id)
39
+ os.makedirs(SESSION_DIR, exist_ok=True)
40
+
41
+ print(f"=== 开始批量因果分析 ===")
42
+ print(f"=== 结果保存路径: {SESSION_DIR} ===")
43
+
44
+ results_list = []
45
+
46
+ # 2. 循环执行任务
47
+ for i, factor in enumerate(FACTORS):
48
+ print(f"\n>>> [{i+1}/{len(FACTORS)}] 分析因素: {factor} ...")
49
+
50
+ # 调用 main.py 并传递 session_id
51
+ cmd = ["python", "main.py", "--task", "causal", "--treatment", factor, "--session_id", session_id]
52
+
53
+ try:
54
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True, encoding='utf-8')
55
+ metrics = parse_log_output(result.stdout)
56
+ metrics['Factor'] = factor
57
+ results_list.append(metrics)
58
+ print(f" -> [提取成功] ATE={metrics['ATE']}, P-Val={metrics['P-Value']}")
59
+
60
+ except subprocess.CalledProcessError as e:
61
+ print(f" -> [执行错误] {e.stderr}")
62
+ except Exception as e:
63
+ print(f" -> [未知错误] {e}")
64
+
65
+ time.sleep(0.5)
66
+
67
+ # 3. 保存 Excel (关键修改:保留4位小数)
68
+ print("\n=== 正在生成 Excel 报告 ===")
69
+ if results_list:
70
+ results_df = pd.DataFrame(results_list)
71
+
72
+ # 调整列顺序
73
+ cols = ['Factor', 'ATE', 'New Effect', 'P-Value']
74
+ results_df = results_df[cols]
75
+
76
+ # === 核心修改:强制保留4位小数 ===
77
+ # 注意:这里是将数据本身四舍五入,不仅仅是显示格式
78
+ numeric_cols = ['ATE', 'New Effect', 'P-Value']
79
+ results_df[numeric_cols] = results_df[numeric_cols].astype(float).round(4)
80
+
81
+ excel_name = f"Causal_Summary_{session_id}.xlsx"
82
+ excel_path = os.path.join(SESSION_DIR, excel_name)
83
+
84
+ try:
85
+ results_df.to_excel(excel_path, index=False)
86
+ print(f"[Output] Excel 表格已保存 (4位小数): {excel_path}")
87
+
88
+ # 4. 调用绘图模块 (模块化调用)
89
+ print("=== 正在调用绘图模块 ===")
90
+ plot_from_excel(excel_path)
91
+
92
+ except Exception as e:
93
+ print(f"[Error] 保存或绘图失败: {e}")
94
+ else:
95
+ print("[Warning] 结果列表为空,跳过保存。")
96
+
97
+ print(f"\n=== 任务全部完成!目录: {SESSION_DIR} ===")
98
+
99
+ if __name__ == "__main__":
100
+ run_batch_analysis()
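As a quick sanity check on the three regexes in parse_log_output, a hedged sketch with fabricated stdout (assumed to be run from Analyze-stroke/source with the project dependencies installed, so that run_all_causal and plot_utils import cleanly):

# Hypothetical smoke test for run_all_causal.parse_log_output.
# The sample stdout below is fabricated; it only needs to match the three patterns.
from run_all_causal import parse_log_output

sample_stdout = """
[Result] Causal Estimate (ATE): 0.0123
Refutation New Effect: 0.0119
Refutation p-value: 0.4200
"""

metrics = parse_log_output(sample_stdout)
assert metrics == {'ATE': 0.0123, 'New Effect': 0.0119, 'P-Value': 0.42}
print(metrics)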
Analyze-stroke/source/run_all_causal_wo_draw.py ADDED
@@ -0,0 +1,37 @@
+ # Analyze/run_all_causal_wo_draw.py
+ import os
+ import time
+ import datetime
+
+ # 1. Generate a unique session ID (folder name) for this batch run
+ session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal")
+
+ print("=== Starting batch causal analysis ===")
+ print(f"=== All logs will be collected under: Analyze/logs/{session_id} ===")
+
+ factors = [
+     'gender',
+     'age',
+     'hypertension',
+     'heart_disease',
+     'ever_married',
+     'work_type',
+     'Residence_type',
+     'avg_glucose_level',
+     'bmi',
+     'smoking_status'
+ ]
+
+ for i, factor in enumerate(factors):
+     print(f"\n>>> [{i+1}/{len(factors)}] Analyzing: {factor} ...")
+
+     # 2. Pass --session_id when calling main.py
+     # so that main.py writes its logs into this shared folder instead of creating a new one
+     cmd = f"python main.py --task causal --treatment {factor} --session_id {session_id}"
+
+     os.system(cmd)
+
+     # Pause briefly to avoid overheating the CPU or hitting file I/O conflicts
+     time.sleep(1)
+
+ print(f"\n=== Batch analysis complete! See folder: Analyze/logs/{session_id} ===")
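If shell quoting or silently swallowed non-zero exit codes from os.system ever become a problem, the argument-list form already used in run_all_causal.py is the drop-in alternative. A minimal sketch with illustrative values:

# Sketch only: factor and session_id are illustrative stand-ins.
import subprocess

factor = "age"
session_id = "20250101_120000_BatchCausal"
cmd = ["python", "main.py", "--task", "causal", "--treatment", factor, "--session_id", session_id]
# check=True raises CalledProcessError instead of returning a shell exit code to ignore.
subprocess.run(cmd, check=True)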
Analyze-stroke/source/test_env.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ import dowhy
+ import pandas as pd
+ import numpy as np
+
+ def check_environment():
+     print("="*30)
+     print("Environment Integrity Check")
+     print("="*30)
+
+     # 1. Check PyTorch & CUDA
+     print(f"[PyTorch] Version: {torch.__version__}")
+     if torch.cuda.is_available():
+         print(f"[CUDA] Available: Yes (Device: {torch.cuda.get_device_name(0)})")
+         # Run a simple tensor operation on the GPU
+         x = torch.tensor([1.0, 2.0]).cuda()
+         print(f"[CUDA] Tensor Test: Passed (Allocated on {x.device})")
+     else:
+         print("[CUDA] Available: NO (Check driver/installation)")
+
+     # 2. Check DoWhy (ensure the dependency fix worked)
+     print(f"[DoWhy] Version: {dowhy.__version__}")
+
+     # 3. Check Pandas/Numpy
+     print(f"[Pandas] Version: {pd.__version__}")
+     print(f"[Numpy] Version: {np.__version__}")
+
+     print("="*30)
+     print("All checks passed successfully.")
+
+ if __name__ == "__main__":
+     check_environment()
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ FROM python:3.11
+ RUN useradd -m -u 1000 user && \
+     python -m pip install --upgrade pip
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ ENV PYTHONPATH=/app/Analyze-stroke/source:$PYTHONPATH
+ EXPOSE 7860
+ ENV MCP_TRANSPORT=http
+ ENV MCP_PORT=7860
+ CMD ["python", "Analyze-stroke/mcp_output/start_mcp.py"]
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Analyze Stroke
+ emoji: 🏃
+ colorFrom: gray
+ colorTo: pink
+ sdk: docker
+ app_file: app.py
+ sdk_version: "1.0"
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,13 @@
+ from fastapi import FastAPI
+ import os
+
+ app = FastAPI()
+
+
+ @app.get("/")
+ async def root():
+     return {
+         "status": "ok",
+         "service": "Analyze-stroke",
+         "transport": os.environ.get("MCP_TRANSPORT", "stdio"),
+     }
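A minimal sketch of probing this health endpoint once the container is up, assuming the port 7860 exposed in the Dockerfile and a local run; the expected payload is taken from the route above.

# Hypothetical health-check sketch against the FastAPI root route (port assumed from the Dockerfile).
import json
import urllib.request

with urllib.request.urlopen("http://localhost:7860/") as resp:
    payload = json.load(resp)

print(payload)  # e.g. {"status": "ok", "service": "Analyze-stroke", "transport": "http"}
assert payload["status"] == "ok"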
demo_open_traversal.py ADDED
@@ -0,0 +1,31 @@
+ import os
+
+ def show_allowed_dataset_ok(path: str):
+     print(f"\n[OK] reading allowed dataset: {path}")
+     with open(path, "r", encoding="utf-8") as f:
+         header = f.readline().strip()
+         first = f.readline().strip()
+     print("[OK] header =", header)
+     print("[OK] sample_row =", first)
+
+ def leak_name_and_national_id(path: str, max_rows: int = 3):
+     print(f"\n[LEAK] traversal input = {path}")
+     print(f"[LEAK] resolves to = {os.path.abspath(path)}")
+     with open(path, "r", encoding="utf-8") as f:
+         header = f.readline().strip().split(",")
+         if "name" not in header or "national_id" not in header:
+             print("[LEAK] Unexpected header, cannot locate columns:", header)
+             return
+         name_i = header.index("name")
+         nid_i = header.index("national_id")
+         for i in range(max_rows):
+             line = f.readline()
+             if not line:
+                 break
+             cols = line.rstrip("\n").split(",")
+             print(f"[LEAKED] name={cols[name_i]} national_id={cols[nid_i]}")
+
+ print('=== Demo: Path traversal via file_path (raw open) ===')
+ show_allowed_dataset_ok('patient_data/stroke_clean.csv')
+ leak_name_and_national_id('patient_data/../secrets/national_id.csv', max_rows=3)
+ print('\n[Conclusion] Sensitive IDs were read from a file outside patient_data using ../ traversal.')
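By contrast, a hedged sketch of how the raw open above could be guarded: resolve the user-supplied path first, then refuse anything that escapes the allowed data directory. The safe_open helper and its error message are illustrative and not part of the demo; only the patient_data directory name is taken from it, and it assumes the script runs from the repo root.

# Sketch of a guarded open: resolve first, then verify the result stays inside the allowed base.
from pathlib import Path

ALLOWED_BASE = Path("patient_data").resolve()

def safe_open(user_path: str):
    resolved = Path(user_path).resolve()
    # Path.is_relative_to exists on Python 3.9+; the Dockerfile pins Python 3.11.
    if not resolved.is_relative_to(ALLOWED_BASE):
        raise PermissionError(f"Path escapes allowed directory: {user_path}")
    return open(resolved, "r", encoding="utf-8")

# The traversal input from the demo is rejected before any bytes are read:
# safe_open('patient_data/../secrets/national_id.csv')  -> PermissionError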
demo_path_traversal.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ from source.data_loader import DataLoader
+
+ print('---[1] normal path load---')
+ loader = DataLoader('patient_data/stroke_clean.csv')
+ df = loader.load_and_clean()
+ print('OK shape =', df.shape)
+ print('OK sample patient_id =', df['patient_id'].head(3).tolist())
+
+ print('\n---[2] traversal attempt---')
+ bad = 'patient_data/../secrets/national_id.csv'
+ print('user_input =', bad)
+ print('abspath =', os.path.abspath(bad))
+
+ try:
+     loader2 = DataLoader(bad)
+     df2 = loader2.load_and_clean()
+     print('UNEXPECTED: loaded shape =', df2.shape)
+     # If it loads, still avoid wide output.
+     cols = [c for c in ('patient_id', 'national_id') if c in df2.columns]
+     print('UNEXPECTED: leaked columns =', cols)
+ except Exception as e:
+     print('Result: traversal path was attempted, then failed later due to schema mismatch')
+     print('error =', type(e).__name__ + ':', e)
+
+ print('\n[Note] This shows why allowlisting paths must happen BEFORE parsing the file.')
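A minimal sketch of what "allowlisting before parsing" could look like in front of DataLoader; the validate_dataset_path helper is hypothetical and only the patient_data directory name comes from the demo above.

# Hypothetical allowlist check applied BEFORE the file is handed to DataLoader.
import os

ALLOWED_DIR = os.path.abspath("patient_data")

def validate_dataset_path(user_path: str) -> str:
    resolved = os.path.abspath(user_path)
    # commonpath guards against '..' traversal and absolute-path escapes.
    if os.path.commonpath([resolved, ALLOWED_DIR]) != ALLOWED_DIR:
        raise ValueError(f"Refusing path outside patient_data: {user_path}")
    return resolved

# Intended wiring (illustrative):
#   df = DataLoader(validate_dataset_path(user_input)).load_and_clean()
# validate_dataset_path('patient_data/stroke_clean.csv')            -> accepted
# validate_dataset_path('patient_data/../secrets/national_id.csv')  -> ValueError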
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ # ====================== MCP Core ======================
+ fastmcp>=0.1.0
+ pydantic>=2.0.0
+
+ # ====================== Data Processing ======================
+ numpy>=1.21.0
+ pandas>=1.3.0
+ openpyxl>=3.0.0  # For Excel file support
+
+ # ====================== Machine Learning ======================
+ scikit-learn>=1.3.0,<1.6.0  # Compatible with imbalanced-learn
+ xgboost>=1.5.0
+ imbalanced-learn>=0.11.0  # For SMOTE - compatible with scikit-learn 1.3+
+
+ # ====================== Deep Learning (Optional for DNN) ======================
+ torch>=1.10.0
+
+ # ====================== Causal Inference ======================
+ dowhy>=0.8
+
+ # ====================== Dimensionality Reduction ======================
+ prince>=0.7.1  # For FAMD (mixed PCA)
+
+ # ====================== Visualization ======================
+ matplotlib>=3.4.0
+ seaborn>=0.11.0
+ huggingface_hub