Commit 225e634 · 0 parent(s)
ChaoqianO committed: init: Analyze-stroke MCP + security demo
- .gitattributes +35 -0
- Analyze-stroke/mcp_output/README_MCP.md +88 -0
- Analyze-stroke/mcp_output/analysis.json +181 -0
- Analyze-stroke/mcp_output/env_info.json +17 -0
- Analyze-stroke/mcp_output/mcp_logs/llm_statistics.json +11 -0
- Analyze-stroke/mcp_output/mcp_logs/run_log.json +74 -0
- Analyze-stroke/mcp_output/mcp_plugin/__init__.py +0 -0
- Analyze-stroke/mcp_output/mcp_plugin/__pycache__/adapter.cpython-310.pyc +0 -0
- Analyze-stroke/mcp_output/mcp_plugin/__pycache__/mcp_service.cpython-310.pyc +0 -0
- Analyze-stroke/mcp_output/mcp_plugin/adapter.py +147 -0
- Analyze-stroke/mcp_output/mcp_plugin/main.py +13 -0
- Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py +734 -0
- Analyze-stroke/mcp_output/requirements.txt +9 -0
- Analyze-stroke/mcp_output/simple_revise_error_analysis.json +6 -0
- Analyze-stroke/mcp_output/start_mcp.py +33 -0
- Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py +49 -0
- Analyze-stroke/mcp_output/tests_smoke/test_smoke.py +12 -0
- Analyze-stroke/patient_data/README_DEMO.txt +12 -0
- Analyze-stroke/patient_data/stroke_clean.csv +23 -0
- Analyze-stroke/secrets/national_id.csv +23 -0
- Analyze-stroke/source/__init__.py +4 -0
- Analyze-stroke/source/__pycache__/__init__.cpython-311.pyc +0 -0
- Analyze-stroke/source/__pycache__/__init__.cpython-39.pyc +0 -0
- Analyze-stroke/source/__pycache__/causal_module.cpython-310.pyc +0 -0
- Analyze-stroke/source/__pycache__/data_loader.cpython-311.pyc +0 -0
- Analyze-stroke/source/__pycache__/data_loader.cpython-39.pyc +0 -0
- Analyze-stroke/source/causal_module.py +108 -0
- Analyze-stroke/source/data_loader.py +55 -0
- Analyze-stroke/source/dim_reduction.py +130 -0
- Analyze-stroke/source/environment.yml +223 -0
- Analyze-stroke/source/feature_selection.py +140 -0
- Analyze-stroke/source/healthcare-dataset-stroke-data.csv +0 -0
- Analyze-stroke/source/main.py +130 -0
- Analyze-stroke/source/models.py +160 -0
- Analyze-stroke/source/plot_utils.py +217 -0
- Analyze-stroke/source/run_all_causal.py +100 -0
- Analyze-stroke/source/run_all_causal_wo_draw.py +37 -0
- Analyze-stroke/source/test_env.py +32 -0
- Dockerfile +16 -0
- README.md +12 -0
- app.py +13 -0
- demo_open_traversal.py +31 -0
- demo_path_traversal.py +26 -0
- requirements.txt +27 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
Analyze-stroke/mcp_output/README_MCP.md
ADDED
@@ -0,0 +1,88 @@
# Analyze Stroke Plugin

## Overview

The Analyze Stroke Plugin is a comprehensive tool designed to facilitate the analysis of stroke data using various machine learning and statistical methods. This plugin provides functionalities for data loading, feature selection, dimensionality reduction, and causal analysis, with options for visualization.

The project is hosted on GitHub and can be accessed via the following repository: [Analyze Stroke Repository](https://github.com/ghh1125/Analyze-stroke).

## Installation

To set up the Analyze Stroke Plugin, follow these steps:

1. **Clone the Repository:**

   ```bash
   git clone https://github.com/ghh1125/Analyze-stroke.git
   cd Analyze-stroke
   ```

2. **Set Up the Environment:**

   The project includes an `environment.yml` file for setting up the necessary Python environment. Use the following command to create the environment:

   ```bash
   conda env create -f environment.yml
   ```

3. **Activate the Environment:**

   ```bash
   conda activate analyze-stroke
   ```

## Usage

The Analyze Stroke Plugin provides several command-line interface (CLI) commands to execute different analysis processes. Below are the primary commands available:

- **Run All Causal Analysis with Visualization:**

  This command executes all causal analysis processes and includes visualization of the results.

  ```bash
  python run_all_causal.py
  ```

- **Run All Causal Analysis without Visualization:**

  This command executes all causal analysis processes without generating visual outputs.

  ```bash
  python run_all_causal_wo_draw.py
  ```

## Tool Endpoints

The following modules are included in the Analyze Stroke Plugin:

- `causal_module.py`: Handles causal analysis processes.
- `data_loader.py`: Manages data loading operations.
- `dim_reduction.py`: Performs dimensionality reduction.
- `feature_selection.py`: Conducts feature selection.
- `main.py`: Main entry point for executing the plugin.
- `models.py`: Contains machine learning models used in the analysis.
- `plot_utils.py`: Provides utilities for plotting and visualization.
- `test_env.py`: Used for testing the environment setup.

## Dependencies

The Analyze Stroke Plugin requires the following dependencies:

- Required: `numpy`, `pandas`, `scikit-learn`, `matplotlib`
- Optional: `seaborn` (for enhanced visualization)

## Notes and Troubleshooting

- Ensure that all dependencies are correctly installed by using the provided `environment.yml` file.
- If you encounter issues with visualization, verify that `matplotlib` and `seaborn` are properly installed.
- For any issues related to data loading or processing, check the data format and ensure compatibility with the plugin's requirements.

## Contributing

Contributions to the Analyze Stroke Plugin are welcome. Please fork the repository and submit a pull request with your changes. Ensure that your code adheres to the project's coding standards and includes appropriate documentation.

## License

This project is licensed under the MIT License. See the [LICENSE](https://github.com/ghh1125/Analyze-stroke/blob/main/LICENSE) file for more details.

For further information, please visit the [GitHub repository](https://github.com/ghh1125/Analyze-stroke).
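The README above lists the source modules but never shows how they are called. The sketch below is illustrative only: the `DataLoader(path).load_and_clean()` call pattern is taken from `mcp_service.py` later in this commit, and the dataset path comes from the file list (`Analyze-stroke/source/healthcare-dataset-stroke-data.csv`).

```python
# Illustrative only: load the stroke dataset through the source modules the
# README describes. The DataLoader call pattern mirrors mcp_service.py below;
# the paths assume the repository layout shown in this commit.
import sys

sys.path.insert(0, "Analyze-stroke/source")

from data_loader import DataLoader  # module listed under "Tool Endpoints"

loader = DataLoader("Analyze-stroke/source/healthcare-dataset-stroke-data.csv")
df = loader.load_and_clean()
print(df.shape)
print(df["stroke"].value_counts())
```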
Analyze-stroke/mcp_output/analysis.json
ADDED
@@ -0,0 +1,181 @@
{
  "summary": {
    "repository_url": "https://github.com/ghh1125/Analyze-stroke",
    "summary": "Imported via zip fallback, file count: 11",
    "file_tree": {
      "causal_module.py": {
        "size": 3829
      },
      "data_loader.py": {
        "size": 1789
      },
      "dim_reduction.py": {
        "size": 5137
      },
      "environment.yml": {
        "size": 7791
      },
      "feature_selection.py": {
        "size": 5748
      },
      "main.py": {
        "size": 5655
      },
      "models.py": {
        "size": 7035
      },
      "plot_utils.py": {
        "size": 7901
      },
      "run_all_causal.py": {
        "size": 3386
      },
      "run_all_causal_wo_draw.py": {
        "size": 901
      },
      "test_env.py": {
        "size": 987
      }
    },
    "processed_by": "zip_fallback",
    "success": true
  },
  "structure": {
    "packages": []
  },
  "dependencies": {
    "has_environment_yml": true,
    "has_requirements_txt": false,
    "pyproject": false,
    "setup_cfg": false,
    "setup_py": false
  },
  "entry_points": {
    "imports": [],
    "cli": [],
    "modules": []
  },
  "llm_analysis": {
    "core_modules": [
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/causal_module.py",
        "module": "causal_module",
        "functions": [
          "function1",
          "function2"
        ],
        "classes": [
          "CausalClass1",
          "CausalClass2"
        ],
        "description": "Handles causal inference and analysis for stroke data."
      },
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/data_loader.py",
        "module": "data_loader",
        "functions": [
          "load_data",
          "preprocess_data"
        ],
        "classes": [
          "DataLoader"
        ],
        "description": "Responsible for loading and preprocessing stroke datasets."
      },
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/dim_reduction.py",
        "module": "dim_reduction",
        "functions": [
          "reduce_dimensions"
        ],
        "classes": [
          "DimReducer"
        ],
        "description": "Implements dimensionality reduction techniques."
      },
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/feature_selection.py",
        "module": "feature_selection",
        "functions": [
          "select_features"
        ],
        "classes": [
          "FeatureSelector"
        ],
        "description": "Contains methods for feature selection in datasets."
      },
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/models.py",
        "module": "models",
        "functions": [
          "train_model",
          "evaluate_model"
        ],
        "classes": [
          "ModelTrainer",
          "ModelEvaluator"
        ],
        "description": "Defines machine learning models and their training/evaluation."
      },
      {
        "package": "https://github.com/ghh1125/Analyze-stroke/plot_utils.py",
        "module": "plot_utils",
        "functions": [
          "plot_results"
        ],
        "classes": [],
        "description": "Utility functions for plotting results and visualizations."
      }
    ],
    "cli_commands": [
      {
        "name": "run_all_causal",
        "module": "https://github.com/ghh1125/Analyze-stroke/run_all_causal.py",
        "description": "Executes all causal analysis processes with visualization."
      },
      {
        "name": "run_all_causal_wo_draw",
        "module": "https://github.com/ghh1125/Analyze-stroke/run_all_causal_wo_draw.py",
        "description": "Executes all causal analysis processes without visualization."
      }
    ],
    "import_strategy": {
      "primary": "import",
      "fallback": "cli",
      "confidence": 0.85
    },
    "dependencies": {
      "required": [
        "numpy",
        "pandas",
        "scikit-learn",
        "matplotlib"
      ],
      "optional": [
        "seaborn"
      ]
    },
    "risk_assessment": {
      "import_feasibility": 0.8,
      "intrusiveness_risk": "medium",
      "complexity": "medium"
    }
  },
  "deepwiki_analysis": {
    "repo_url": "https://github.com/ghh1125/Analyze-stroke",
    "repo_name": "Analyze-stroke",
    "content": null,
    "model": "gpt-4o",
    "source": "selenium",
    "success": true
  },
  "deepwiki_options": {
    "enabled": true,
    "model": "gpt-4o"
  },
  "risk": {
    "import_feasibility": 0.8,
    "intrusiveness_risk": "medium",
    "complexity": "medium"
  }
}
Analyze-stroke/mcp_output/env_info.json
ADDED
@@ -0,0 +1,17 @@
{
  "environment": {
    "type": "conda",
    "name": "Analyze-stroke_366682_env",
    "files": {
      "environment_yml": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/source/environment.yml"
    },
    "python": "3.10",
    "exec_prefix": []
  },
  "original_tests": {
    "passed": false,
    "report_path": null
  },
  "timestamp": 1765367116.7647827,
  "conda_available": true
}
Analyze-stroke/mcp_output/mcp_logs/llm_statistics.json
ADDED
@@ -0,0 +1,11 @@
{
  "total_calls": 7,
  "failed_calls": 0,
  "retry_count": 0,
  "total_prompt_tokens": 6752,
  "total_completion_tokens": 3421,
  "total_tokens": 10173,
  "average_prompt_tokens": 964.5714285714286,
  "average_completion_tokens": 488.7142857142857,
  "average_tokens": 1453.2857142857142
}
Analyze-stroke/mcp_output/mcp_logs/run_log.json
ADDED
@@ -0,0 +1,74 @@
{
  "timestamp": 1765367363.30327,
  "node": "RunNode",
  "test_result": {
    "passed": false,
    "report_path": null,
    "stdout": "",
    "stderr": "ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n"
  },
  "run_result": {
    "success": false,
    "test_passed": false,
    "exit_code": 1,
    "stdout": "",
    "stderr": "ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n",
    "timestamp": 1765367363.3032475,
    "error_type": "ImportError",
    "error": "Module import failed: ERROR conda.cli.main_run:execute(41): `conda run python mcp_output/start_mcp.py` failed. (See above for error)\nTraceback (most recent call last):\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py\", line 17, in <module>\n from mcp_service import create_app\n File \"/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py\", line 5, in <module>\n import pandas as pd\nModuleNotFoundError: No module named 'pandas'\n\n",
    "details": {
      "command": "/home/wshiah/code/miniconda3/bin/conda run -n Analyze-stroke_366682_env --cwd /export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke python mcp_output/start_mcp.py",
      "working_directory": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke",
      "environment_type": "conda"
    }
  },
  "environment": {
    "type": "conda",
    "name": "Analyze-stroke_366682_env",
    "files": {
      "environment_yml": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/source/environment.yml"
    },
    "python": "3.10",
    "exec_prefix": []
  },
  "plugin_info": {
    "files": {
      "mcp_output/start_mcp.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/start_mcp.py",
      "mcp_output/mcp_plugin/__init__.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/__init__.py",
      "mcp_output/mcp_plugin/mcp_service.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py",
      "mcp_output/mcp_plugin/adapter.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/adapter.py",
      "mcp_output/mcp_plugin/main.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin/main.py",
      "mcp_output/requirements.txt": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/requirements.txt",
      "mcp_output/README_MCP.md": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/README_MCP.md",
      "mcp_output/tests_mcp/test_mcp_basic.py": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py"
    },
    "adapter_mode": "import",
    "endpoints": [
      "function1",
      "function2",
      "causalclass1",
      "causalclass2",
      "load_data",
      "preprocess_data",
      "dataloader",
      "reduce_dimensions",
      "dimreducer",
      "select_features",
      "featureselector",
      "train_model",
      "evaluate_model",
      "modeltrainer",
      "modelevaluator",
      "plot_results"
    ],
    "mcp_dir": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/mcp_plugin",
    "tests_dir": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/tests_mcp",
    "main_entry": "start_mcp.py",
    "readme_path": "/export/project/shiweijie/ghh/LLM_MCP_RAG/MCP-agent-github-repo-output/workspace/Analyze-stroke/mcp_output/README_MCP.md",
    "requirements": [
      "fastmcp>=0.1.0",
      "pydantic>=2.0.0"
    ]
  },
  "fastmcp_installed": false
}
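The run log above fails with `ModuleNotFoundError: No module named 'pandas'`: the generated conda environment only carries the MCP wrapper requirements (`fastmcp`, `pydantic`), while `mcp_service.py` also imports the analysis stack. A minimal preflight sketch (illustrative, not part of this commit) that would surface such gaps before launching the service:

```python
# Illustrative preflight check, not part of this commit: verify that the
# packages mcp_service.py imports at module level are present in the active
# environment, so the ImportError recorded in run_log.json is caught up front.
import importlib.util

REQUIRED = ["pandas", "numpy", "fastmcp"]  # top-level imports of mcp_service.py

missing = [name for name in REQUIRED if importlib.util.find_spec(name) is None]
if missing:
    raise SystemExit(f"Missing packages in the current environment: {missing}")
print("All required packages are importable.")
```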
Analyze-stroke/mcp_output/mcp_plugin/__init__.py
ADDED
File without changes
Analyze-stroke/mcp_output/mcp_plugin/__pycache__/adapter.cpython-310.pyc
ADDED
Binary file (4.14 kB).
Analyze-stroke/mcp_output/mcp_plugin/__pycache__/mcp_service.cpython-310.pyc
ADDED
Binary file (2.14 kB).
Analyze-stroke/mcp_output/mcp_plugin/adapter.py
ADDED
@@ -0,0 +1,147 @@
import os
import sys

# Path settings
source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
sys.path.insert(0, source_path)

# Import statements
try:
    from causal_module import CausalModule
    from data_loader import DataLoader
    from dim_reduction import DimReduction
    from feature_selection import FeatureSelection
    from models import Models
    from plot_utils import PlotUtils
    from run_all_causal import run_all_causal
    from run_all_causal_wo_draw import run_all_causal_wo_draw
except ImportError as e:
    print(f"Import error: {e}. Ensure all modules are available in the source directory.")

class Adapter:
    """
    Adapter class for MCP plugin, providing methods to interact with the Analyze-stroke repository.
    """

    def __init__(self):
        self.mode = "import"

    # -------------------- Causal Module Methods --------------------

    def create_causal_module_instance(self):
        """
        Create an instance of the CausalModule class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = CausalModule()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Data Loader Methods --------------------

    def create_data_loader_instance(self):
        """
        Create an instance of the DataLoader class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = DataLoader()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Dimensionality Reduction Methods --------------------

    def create_dim_reduction_instance(self):
        """
        Create an instance of the DimReduction class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = DimReduction()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Feature Selection Methods --------------------

    def create_feature_selection_instance(self):
        """
        Create an instance of the FeatureSelection class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = FeatureSelection()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Models Methods --------------------

    def create_models_instance(self):
        """
        Create an instance of the Models class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = Models()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Plot Utilities Methods --------------------

    def create_plot_utils_instance(self):
        """
        Create an instance of the PlotUtils class.

        Returns:
            dict: Status of the operation and instance if successful.
        """
        try:
            instance = PlotUtils()
            return {"status": "success", "instance": instance}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # -------------------- Run All Causal Methods --------------------

    def call_run_all_causal(self):
        """
        Execute the run_all_causal function.

        Returns:
            dict: Status of the operation.
        """
        try:
            run_all_causal()
            return {"status": "success"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def call_run_all_causal_wo_draw(self):
        """
        Execute the run_all_causal_wo_draw function.

        Returns:
            dict: Status of the operation.
        """
        try:
            run_all_causal_wo_draw()
            return {"status": "success"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # End of Adapter class definition
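A short usage sketch for the Adapter above, illustrative only: it assumes the `source/` modules actually expose the class names the adapter imports (`CausalModule`, `DimReduction`, ...), whereas `mcp_service.py` below imports different names (`CausalAnalyzer`, `DimensionAnalyzer`, ...), so the error branch is the realistic path until those names are reconciled.

```python
# Illustrative use of the Adapter defined above (run from mcp_output/mcp_plugin/).
# adapter.py imports class names that may not match the source modules, so the
# status check below matters: a NameError/ImportError surfaces as an error dict.
from adapter import Adapter

adapter = Adapter()

result = adapter.create_data_loader_instance()
if result["status"] == "success":
    loader = result["instance"]
    print("DataLoader ready:", loader)
else:
    print("Could not create DataLoader:", result["message"])

# Run the full causal pipeline through the adapter (wraps run_all_causal()).
print(adapter.call_run_all_causal())
```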
Analyze-stroke/mcp_output/mcp_plugin/main.py
ADDED
@@ -0,0 +1,13 @@
"""
MCP Service Auto-Wrapper - Auto-generated
"""
from mcp_service import create_app

def main():
    """Main entry point"""
    app = create_app()
    return app

if __name__ == "__main__":
    app = main()
    app.run()
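`main.py` imports `create_app` from `mcp_service`, but the excerpt of `mcp_service.py` below only shows the module-level `mcp = FastMCP("AnalyzeStrokeService")` object and its tools, not `create_app` itself. Purely as an assumption about the unshown part of that file, the wrapper would only need to hand back the FastMCP instance so that `app.run()` in `main.py` starts the server:

```python
# Assumed sketch only; create_app is not shown in the mcp_service.py excerpt below.
from fastmcp import FastMCP

mcp = FastMCP("AnalyzeStrokeService")  # as defined at module level in mcp_service.py

def create_app():
    # Hand the FastMCP instance to main.py, which then calls app.run().
    return mcp
```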
Analyze-stroke/mcp_output/mcp_plugin/mcp_service.py
ADDED
@@ -0,0 +1,734 @@
import os
import sys
import subprocess
import re
import pandas as pd
import numpy as np
import datetime
import time
import warnings
import logging
from typing import Optional, List

# Suppress logging from DoWhy and other libraries
logging.getLogger('dowhy').setLevel(logging.WARNING)
logging.getLogger('pygraphviz').setLevel(logging.WARNING)

# ====================== Path configuration (relative paths) ======================
# Directory of this file: mcp_output/mcp_plugin/
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# mcp_output directory: mcp_output/
MCP_OUTPUT_DIR = os.path.dirname(CURRENT_DIR)
# Project root: Analyze-stroke/
PROJECT_ROOT = os.path.dirname(MCP_OUTPUT_DIR)
# source directory: Analyze-stroke/source/
SOURCE_DIR = os.path.join(PROJECT_ROOT, "source")

# Add the source code path to the Python path
sys.path.insert(0, SOURCE_DIR)

from fastmcp import FastMCP

# Import the source code modules
from data_loader import DataLoader
from dim_reduction import DimensionAnalyzer
from models import ModelManager
from causal_module import CausalAnalyzer
from feature_selection import FeatureSelectionAnalyzer
from plot_utils import plot_from_excel
from run_all_causal import parse_log_output

# Default data file path (relative)
DEFAULT_DATA_FILE = os.path.join(SOURCE_DIR, "healthcare-dataset-stroke-data.csv")
# Default visualization output directory
DEFAULT_VIS_DIR = os.path.join(MCP_OUTPUT_DIR, "visualization")
# Default log directory
DEFAULT_LOG_DIR = os.path.join(MCP_OUTPUT_DIR, "logs")

mcp = FastMCP("AnalyzeStrokeService")

# ====================== Data loading tools ======================

@mcp.tool(name="load_stroke_data", description="Load and clean the stroke dataset. Returns basic statistics about the data.")
def load_stroke_data_tool(file_path: str = None) -> dict:
    """
    Load and clean the stroke dataset.

    Args:
        file_path (str, optional): Path to the CSV data file. Uses default if not provided.

    Returns:
        dict: A dictionary containing success, result (data shape, columns, basic stats), and error fields.
    """
    try:
        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        result = {
            "shape": df.shape,
            "columns": list(df.columns),
            "stroke_distribution": df['stroke'].value_counts().to_dict(),
            "missing_values": df.isnull().sum().to_dict(),
            "numeric_stats": df.describe().to_dict()
        }
        return {"success": True, "result": result, "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Dimensionality reduction tools ======================

@mcp.tool(name="perform_pca_famd", description="Perform PCA/FAMD dimensionality reduction analysis on the stroke data. FAMD handles mixed numeric and categorical data.")
def perform_pca_famd_tool(file_path: str = None, save_dir: str = None) -> dict:
    """
    Perform PCA/FAMD dimensionality reduction analysis.

    Args:
        file_path (str, optional): Path to the CSV data file.
        save_dir (str, optional): Directory to save visualization plots.

    Returns:
        dict: A dictionary containing success, result, and error fields.
    """
    try:
        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
        os.makedirs(vis_dir, exist_ok=True)

        analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='pca')
        analyzer.perform_pca_famd()

        return {"success": True, "result": f"PCA/FAMD analysis completed. Plots saved to {vis_dir}", "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

@mcp.tool(name="perform_tsne", description="Perform t-SNE dimensionality reduction analysis on the stroke data for visualization.")
def perform_tsne_tool(file_path: str = None, save_dir: str = None) -> dict:
    """
    Perform t-SNE dimensionality reduction analysis.

    Args:
        file_path (str, optional): Path to the CSV data file.
        save_dir (str, optional): Directory to save visualization plots.

    Returns:
        dict: A dictionary containing success, result, and error fields.
    """
    try:
        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
        os.makedirs(vis_dir, exist_ok=True)

        analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='tsne')
        analyzer.perform_tsne()

        return {"success": True, "result": f"t-SNE analysis completed. Plots saved to {vis_dir}", "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Prediction model tools ======================

@mcp.tool(name="run_prediction_model", description="Run a prediction model for stroke classification. User MUST specify model_type. Options: nb (Naive Bayes), lr (Logistic Regression), svm (SVM), knn (KNN), dt (Decision Tree), rf (Random Forest), xgboost (XGBoost), dnn (Deep Neural Network).")
def run_prediction_model_tool(model_type: str, file_path: str = None, save_dir: str = None) -> dict:
    """
    Run a prediction model for stroke classification.

    Args:
        model_type (str): Model type. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'.
        file_path (str, optional): Path to the CSV data file.
        save_dir (str, optional): Directory to save visualization plots.

    Returns:
        dict: A dictionary containing success, result (model metrics), and error fields.
    """
    try:
        valid_models = ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn']
        if model_type not in valid_models:
            return {"success": False, "result": None, "error": f"Invalid model type. Choose from: {valid_models}"}

        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()
        X_train, X_test, y_train, y_test = loader.get_split_data()

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
        os.makedirs(vis_dir, exist_ok=True)

        manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=vis_dir, timestamp=timestamp)

        if model_type == 'dnn':
            manager.run_dnn()
            return {"success": True, "result": f"DNN model training completed. Plots saved to {vis_dir}", "error": None}
        else:
            manager.run_sklearn_model(model_type)
            return {"success": True, "result": f"{model_type.upper()} model training completed. Plots saved to {vis_dir}", "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Feature selection tools ======================

@mcp.tool(name="run_feature_selection", description="Run feature selection analysis. User MUST specify method. Options: kbest (Mutual Information), rfecv (Recursive Feature Elimination with CV), chi2 (Chi-Square Test).")
def run_feature_selection_tool(method: str, file_path: str = None, save_dir: str = None) -> dict:
    """
    Run feature selection analysis.

    Args:
        method (str): Feature selection method. Options: 'kbest', 'rfecv', 'chi2'.
        file_path (str, optional): Path to the CSV data file.
        save_dir (str, optional): Directory to save visualization plots.

    Returns:
        dict: A dictionary containing success, result, and error fields.
    """
    try:
        valid_methods = ['kbest', 'rfecv', 'chi2']
        if method not in valid_methods:
            return {"success": False, "result": None, "error": f"Invalid method. Choose from: {valid_methods}"}

        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        vis_dir = save_dir if save_dir else DEFAULT_VIS_DIR
        os.makedirs(vis_dir, exist_ok=True)

        analyzer = FeatureSelectionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp)

        if method == 'kbest':
            analyzer.run_select_kbest()
        elif method == 'rfecv':
            analyzer.run_rfecv()
        elif method == 'chi2':
            analyzer.run_chi2()

        return {"success": True, "result": f"Feature selection ({method}) completed. Plots saved to {vis_dir}", "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Causal inference tools ======================

@mcp.tool(name="run_causal_analysis", description="Run causal inference analysis for a single treatment variable on stroke outcome using DoWhy framework. User MUST specify treatment variable. Options: gender, age, hypertension, heart_disease, ever_married, work_type, Residence_type, avg_glucose_level, bmi, smoking_status.")
def run_causal_analysis_tool(treatment: str, file_path: str = None) -> dict:
    """
    Run causal inference analysis for a single treatment variable.

    Args:
        treatment (str): Treatment variable. Options: 'gender', 'age', 'hypertension', 'heart_disease',
            'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'.
        file_path (str, optional): Path to the CSV data file.

    Returns:
        dict: A dictionary containing success, result (ATE value and interpretation), and error fields.
    """
    try:
        valid_treatments = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
                            'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status']
        if treatment not in valid_treatments:
            return {"success": False, "result": None, "error": f"Invalid treatment. Choose from: {valid_treatments}"}

        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        # Domain-aware encoding (kept consistent with main.py)
        if 'gender' in df.columns:
            df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
        if 'ever_married' in df.columns:
            df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
        if 'Residence_type' in df.columns:
            df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
        if 'smoking_status' in df.columns:
            smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
            df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
        if 'work_type' in df.columns:
            work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
            df['work_type'] = df['work_type'].map(work_map).fillna(0)

        analyzer = CausalAnalyzer(df)
        ate_value = analyzer.run_analysis(treatment_col=treatment)

        result = {
            "treatment": treatment,
            "outcome": "stroke",
            "ATE": ate_value,
            "interpretation": f"If '{treatment}' increases by 1 unit, the probability of stroke changes by {ate_value:.4f}"
        }
        return {"success": True, "result": result, "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

@mcp.tool(name="run_batch_causal_analysis", description="Run batch causal analysis for all treatment factors and generate an Excel report with visualizations. Analyzes all 10 factors: gender, age, hypertension, heart_disease, ever_married, work_type, Residence_type, avg_glucose_level, bmi, smoking_status.")
def run_batch_causal_analysis_tool(file_path: str = None, save_dir: str = None) -> dict:
    """
    Run batch causal analysis for all factors and generate summary report.

    This function directly performs causal analysis for all factors without
    relying on subprocess calls, making it cross-platform compatible.

    Args:
        file_path (str, optional): Path to the CSV data file.
        save_dir (str, optional): Directory to save results and visualizations.

    Returns:
        dict: A dictionary containing success, result (report path and summary), and error fields.
    """
    try:
        # Define all factors
        FACTORS = [
            'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
            'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'
        ]

        # Set up the save directory
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal")
        output_dir = save_dir if save_dir else os.path.join(DEFAULT_LOG_DIR, timestamp)
        os.makedirs(output_dir, exist_ok=True)

        # Load and preprocess the data
        data_file = file_path if file_path else DEFAULT_DATA_FILE
        loader = DataLoader(data_file)
        df = loader.load_and_clean()

        # Domain-aware encoding
        if 'gender' in df.columns:
            df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
        if 'ever_married' in df.columns:
            df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
        if 'Residence_type' in df.columns:
            df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
        if 'smoking_status' in df.columns:
            smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
            df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
        if 'work_type' in df.columns:
            work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
            df['work_type'] = df['work_type'].map(work_map).fillna(0)

        results_list = []

        # Run causal analysis for each factor
        for i, factor in enumerate(FACTORS):
            print(f"\n>>> [{i+1}/{len(FACTORS)}] Analyzing factor: {factor} ...")

            try:
                analyzer = CausalAnalyzer(df)

                # Capture output to extract metrics
                import io
                from contextlib import redirect_stdout

                f = io.StringIO()
                with redirect_stdout(f):
                    ate_value = analyzer.run_analysis(treatment_col=factor)

                output_text = f.getvalue()

                # Extract metrics
                metrics = parse_log_output(output_text)
                metrics['Factor'] = factor
                metrics['ATE'] = ate_value if not np.isnan(metrics['ATE']) else ate_value
                results_list.append(metrics)

                print(f" -> [Success] ATE={metrics['ATE']:.4f}, P-Val={metrics['P-Value']}")

            except Exception as e:
                print(f" -> [Error] {str(e)}")
                results_list.append({
                    'Factor': factor,
                    'ATE': np.nan,
                    'New Effect': np.nan,
                    'P-Value': np.nan
                })

        # Generate the Excel report
        if results_list:
            results_df = pd.DataFrame(results_list)
            cols = ['Factor', 'ATE', 'New Effect', 'P-Value']
            results_df = results_df[cols]

            # Round to 4 decimal places
            numeric_cols = ['ATE', 'New Effect', 'P-Value']
            results_df[numeric_cols] = results_df[numeric_cols].astype(float).round(4)

            excel_name = f"Causal_Summary_{timestamp}.xlsx"
            excel_path = os.path.join(output_dir, excel_name)
            results_df.to_excel(excel_path, index=False)
            print(f"\n[Output] Excel report saved: {excel_path}")

            # Generate visualization charts
            try:
                plot_from_excel(excel_path, output_dir)
                print("[Output] Visualization charts generated.")
            except Exception as e:
                print(f"[Warning] Plot generation failed: {e}")

            # Return the summary
            summary = results_df.to_dict('records')
            return {
                "success": True,
                "result": {
                    "excel_path": excel_path,
                    "output_dir": output_dir,
                    "summary": summary,
                    "total_factors": len(FACTORS),
                    "successful_analyses": len([r for r in results_list if not np.isnan(r.get('ATE', np.nan))])
                },
                "error": None
            }
        else:
            return {"success": False, "result": None, "error": "No results generated."}

    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

@mcp.tool(name="parse_log_output", description="Extracts metrics from log text including ATE, New Effect, and P-Value.")
def parse_log_output_tool(output_text: str) -> dict:
    """
    Extracts metrics from log text including ATE, New Effect, and P-Value.

    Args:
        output_text (str): The log text containing the metrics.

    Returns:
        dict: A dictionary containing success, result (extracted metrics), and error fields.
    """
    try:
        result = parse_log_output(output_text)
        return {"success": True, "result": result, "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Plotting tools ======================

@mcp.tool(name="plot_causal_results", description="Generate visualizations (radar chart, bar chart) from causal analysis Excel results.")
def plot_causal_results_tool(excel_path: str) -> dict:
    """
    Generate visualizations from causal analysis Excel results.

    Args:
        excel_path (str): Path to the Excel file containing causal analysis results.

    Returns:
        dict: A dictionary containing success, result, and error fields.
    """
    try:
        if not os.path.exists(excel_path):
            return {"success": False, "result": None, "error": f"Excel file not found: {excel_path}"}

        plot_from_excel(excel_path)
        return {"success": True, "result": f"Plots generated from {excel_path}", "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}

# ====================== Comprehensive analysis tools ======================

@mcp.tool(name="run_stroke_analysis", description="Run comprehensive stroke analysis. Unified entry point for all analysis tasks. User MUST specify task. For prediction task, user MUST also specify model. For feature_selection task, user MUST also specify method. For causal task, user MUST also specify treatment.")
def run_stroke_analysis_tool(task: str, model: str = None, method: str = None, treatment: str = None, file_path: str = None) -> dict:
    """
    Run comprehensive stroke analysis - unified entry point.

    Args:
        task (str): Analysis task. Options: 'pca', 'tsne', 'prediction', 'causal', 'feature_selection'.
        model (str, optional): Model type for prediction task. Required when task='prediction'. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'.
        method (str, optional): Method for feature selection. Required when task='feature_selection'. Options: 'kbest', 'rfecv', 'chi2'.
        treatment (str, optional): Treatment variable for causal analysis. Required when task='causal'. Options: 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'.
        file_path (str, optional): Path to the CSV data file.

    Returns:
        dict: A dictionary containing success, result, and error fields.
    """
    try:
        valid_tasks = ['pca', 'tsne', 'prediction', 'causal', 'feature_selection']
        if task not in valid_tasks:
            return {"success": False, "result": None, "error": f"Invalid task. Choose from: {valid_tasks}"}

        # Execute the task logic directly to avoid calling the decorated functions
        data_file = file_path if file_path else DEFAULT_DATA_FILE

        if task == 'pca':
            loader = DataLoader(data_file)
            df = loader.load_and_clean()
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            vis_dir = DEFAULT_VIS_DIR
            os.makedirs(vis_dir, exist_ok=True)
            analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='pca')
            analyzer.perform_pca_famd()
            return {"success": True, "result": f"PCA/FAMD analysis completed. Plots saved to {vis_dir}", "error": None}

        elif task == 'tsne':
            loader = DataLoader(data_file)
            df = loader.load_and_clean()
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            vis_dir = DEFAULT_VIS_DIR
            os.makedirs(vis_dir, exist_ok=True)
            analyzer = DimensionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp, task_name='tsne')
            analyzer.perform_tsne()
            return {"success": True, "result": f"t-SNE analysis completed. Plots saved to {vis_dir}", "error": None}

        elif task == 'prediction':
            if not model:
                return {"success": False, "result": None, "error": "For prediction task, 'model' parameter is required. Options: 'nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'"}

            valid_models = ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn']
            if model not in valid_models:
                return {"success": False, "result": None, "error": f"Invalid model type. Choose from: {valid_models}"}

            loader = DataLoader(data_file)
            df = loader.load_and_clean()
            X_train, X_test, y_train, y_test = loader.get_split_data()
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            vis_dir = DEFAULT_VIS_DIR
            os.makedirs(vis_dir, exist_ok=True)
            manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=vis_dir, timestamp=timestamp)

            if model == 'dnn':
                manager.run_dnn()
            else:
                manager.run_sklearn_model(model)
            return {"success": True, "result": f"{model.upper()} model training completed. Plots saved to {vis_dir}", "error": None}

        elif task == 'causal':
            if not treatment:
                return {"success": False, "result": None, "error": "For causal task, 'treatment' parameter is required. Options: 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'"}

            valid_treatments = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
                                'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status']
            if treatment not in valid_treatments:
                return {"success": False, "result": None, "error": f"Invalid treatment. Choose from: {valid_treatments}"}
|
| 507 |
+
|
| 508 |
+
loader = DataLoader(data_file)
|
| 509 |
+
df = loader.load_and_clean()
|
| 510 |
+
|
| 511 |
+
# 智能领域编码
|
| 512 |
+
if 'gender' in df.columns:
|
| 513 |
+
df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
|
| 514 |
+
if 'ever_married' in df.columns:
|
| 515 |
+
df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
|
| 516 |
+
if 'Residence_type' in df.columns:
|
| 517 |
+
df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
|
| 518 |
+
if 'smoking_status' in df.columns:
|
| 519 |
+
smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
|
| 520 |
+
df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
|
| 521 |
+
if 'work_type' in df.columns:
|
| 522 |
+
work_map = {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}
|
| 523 |
+
df['work_type'] = df['work_type'].map(work_map).fillna(0)
|
| 524 |
+
|
| 525 |
+
analyzer = CausalAnalyzer(df)
|
| 526 |
+
ate_value = analyzer.run_analysis(treatment_col=treatment)
|
| 527 |
+
result = {
|
| 528 |
+
"treatment": treatment,
|
| 529 |
+
"outcome": "stroke",
|
| 530 |
+
"ATE": ate_value,
|
| 531 |
+
"interpretation": f"If '{treatment}' increases by 1 unit, the probability of stroke changes by {ate_value:.4f}"
|
| 532 |
+
}
|
| 533 |
+
return {"success": True, "result": result, "error": None}
|
| 534 |
+
|
| 535 |
+
elif task == 'feature_selection':
|
| 536 |
+
if not method:
|
| 537 |
+
return {"success": False, "result": None, "error": "For feature_selection task, 'method' parameter is required. Options: 'kbest', 'rfecv', 'chi2'"}
|
| 538 |
+
|
| 539 |
+
valid_methods = ['kbest', 'rfecv', 'chi2']
|
| 540 |
+
if method not in valid_methods:
|
| 541 |
+
return {"success": False, "result": None, "error": f"Invalid method. Choose from: {valid_methods}"}
|
| 542 |
+
|
| 543 |
+
loader = DataLoader(data_file)
|
| 544 |
+
df = loader.load_and_clean()
|
| 545 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 546 |
+
vis_dir = DEFAULT_VIS_DIR
|
| 547 |
+
os.makedirs(vis_dir, exist_ok=True)
|
| 548 |
+
analyzer = FeatureSelectionAnalyzer(df, save_dir=vis_dir, timestamp=timestamp)
|
| 549 |
+
|
| 550 |
+
if method == 'kbest':
|
| 551 |
+
analyzer.run_select_kbest()
|
| 552 |
+
elif method == 'rfecv':
|
| 553 |
+
analyzer.run_rfecv()
|
| 554 |
+
elif method == 'chi2':
|
| 555 |
+
analyzer.run_chi2()
|
| 556 |
+
|
| 557 |
+
return {"success": True, "result": f"Feature selection ({method}) completed. Plots saved to {vis_dir}", "error": None}
|
| 558 |
+
|
| 559 |
+
except Exception as e:
|
| 560 |
+
return {"success": False, "result": None, "error": str(e)}
|
| 561 |
+
|
| 562 |
+
@mcp.tool(name="get_available_options", description="Get all available options for stroke analysis tasks, models, methods, and treatment variables.")
|
| 563 |
+
def get_available_options_tool() -> dict:
|
| 564 |
+
"""
|
| 565 |
+
Get all available options for stroke analysis.
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
dict: A dictionary containing all available options.
|
| 569 |
+
"""
|
| 570 |
+
return {
|
| 571 |
+
"success": True,
|
| 572 |
+
"result": {
|
| 573 |
+
"tasks": ['pca', 'tsne', 'prediction', 'causal', 'feature_selection'],
|
| 574 |
+
"prediction_models": ['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'],
|
| 575 |
+
"feature_selection_methods": ['kbest', 'rfecv', 'chi2'],
|
| 576 |
+
"causal_treatments": ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
|
| 577 |
+
'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'],
|
| 578 |
+
"description": {
|
| 579 |
+
"pca": "PCA/FAMD dimensionality reduction analysis",
|
| 580 |
+
"tsne": "t-SNE visualization",
|
| 581 |
+
"prediction": "Stroke prediction using ML models",
|
| 582 |
+
"causal": "Causal inference using DoWhy framework",
|
| 583 |
+
"feature_selection": "Feature importance analysis"
|
| 584 |
+
}
|
| 585 |
+
},
|
| 586 |
+
"error": None
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
@mcp.tool(name="upload_to_hf_dataset", description="Upload generated file to HuggingFace Dataset for easy download")
|
| 590 |
+
def upload_to_hf_dataset(file_path: str, dataset_repo: str = None, hf_token: str = None) -> dict:
|
| 591 |
+
"""
|
| 592 |
+
Upload generated file to HuggingFace Dataset repository.
|
| 593 |
+
Users can then download files from the Dataset page.
|
| 594 |
+
|
| 595 |
+
Parameters:
|
| 596 |
+
file_path (str): Absolute path to the file on HF Space
|
| 597 |
+
dataset_repo (str): Dataset repo name (e.g., 'username/obspy-outputs').
|
| 598 |
+
If None, uses HF_DATASET_REPO environment variable
|
| 599 |
+
hf_token (str): HuggingFace token. If None, uses HF_TOKEN environment variable
|
| 600 |
+
|
| 601 |
+
Returns:
|
| 602 |
+
dict: Success status and download URL
|
| 603 |
+
"""
|
| 604 |
+
try:
|
| 605 |
+
from huggingface_hub import HfApi, create_repo
|
| 606 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 607 |
+
|
| 608 |
+
if dataset_repo is None:
|
| 609 |
+
dataset_repo = os.environ.get("HF_DATASET_REPO")
|
| 610 |
+
if not dataset_repo:
|
| 611 |
+
return {
|
| 612 |
+
"success": False,
|
| 613 |
+
"error": "Dataset repo not specified. Set HF_DATASET_REPO environment variable or pass dataset_repo parameter"
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
if hf_token is None:
|
| 617 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 618 |
+
if not hf_token:
|
| 619 |
+
return {
|
| 620 |
+
"success": False,
|
| 621 |
+
"error": "HF token not found. Set HF_TOKEN environment variable or pass hf_token parameter"
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
# Check if file exists and readable
|
| 625 |
+
if not os.path.exists(file_path):
|
| 626 |
+
return {
|
| 627 |
+
"success": False,
|
| 628 |
+
"error": f"File not found: {file_path}"
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
if not os.path.isfile(file_path):
|
| 632 |
+
return {
|
| 633 |
+
"success": False,
|
| 634 |
+
"error": f"Path is not a file: {file_path}"
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
if not os.access(file_path, os.R_OK):
|
| 638 |
+
return {
|
| 639 |
+
"success": False,
|
| 640 |
+
"error": f"File not readable (permission denied): {file_path}",
|
| 641 |
+
"hint": "Check file permissions in Docker container"
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
# Get file info for debugging
|
| 645 |
+
file_size = os.path.getsize(file_path)
|
| 646 |
+
file_stat = os.stat(file_path)
|
| 647 |
+
|
| 648 |
+
# Initialize HF API
|
| 649 |
+
api = HfApi()
|
| 650 |
+
|
| 651 |
+
# Try to create dataset repo if it doesn't exist
|
| 652 |
+
try:
|
| 653 |
+
repo_info = create_repo(
|
| 654 |
+
repo_id=dataset_repo,
|
| 655 |
+
repo_type="dataset",
|
| 656 |
+
token=hf_token,
|
| 657 |
+
exist_ok=True,
|
| 658 |
+
private=False
|
| 659 |
+
)
|
| 660 |
+
except Exception as e:
|
| 661 |
+
return {
|
| 662 |
+
"success": False,
|
| 663 |
+
"error": f"Failed to create/access dataset repo: {str(e)}",
|
| 664 |
+
"dataset_repo": dataset_repo,
|
| 665 |
+
"hint": "Check if HF_TOKEN has write permission and dataset_repo format is correct (username/repo-name)"
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
# Get filename and determine path in dataset
|
| 669 |
+
filename = os.path.basename(file_path)
|
| 670 |
+
|
| 671 |
+
# Determine subdirectory based on file location
|
| 672 |
+
if "/output/" in file_path:
|
| 673 |
+
path_in_repo = f"output/{filename}"
|
| 674 |
+
elif "/plots/" in file_path:
|
| 675 |
+
path_in_repo = f"plots/{filename}"
|
| 676 |
+
elif "/wave_data/" in file_path:
|
| 677 |
+
path_in_repo = f"wave_data/{filename}"
|
| 678 |
+
else:
|
| 679 |
+
path_in_repo = filename
|
| 680 |
+
|
| 681 |
+
try:
|
| 682 |
+
with open(file_path, 'rb') as f:
|
| 683 |
+
file_content = f.read()
|
| 684 |
+
except Exception as e:
|
| 685 |
+
return {
|
| 686 |
+
"success": False,
|
| 687 |
+
"error": f"Failed to read file: {str(e)}",
|
| 688 |
+
"file_path": file_path,
|
| 689 |
+
"hint": "File may exist but not readable due to permission issues"
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
# Upload file from memory instead of path
|
| 693 |
+
import io
|
| 694 |
+
upload_result = api.upload_file(
|
| 695 |
+
path_or_fileobj=io.BytesIO(file_content),
|
| 696 |
+
path_in_repo=path_in_repo,
|
| 697 |
+
repo_id=dataset_repo,
|
| 698 |
+
repo_type="dataset",
|
| 699 |
+
token=hf_token
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# Construct download URL
|
| 703 |
+
download_url = f"https://huggingface.co/datasets/{dataset_repo}/resolve/main/{path_in_repo}"
|
| 704 |
+
viewer_url = f"https://huggingface.co/datasets/{dataset_repo}/viewer/default/train?f%5Bfile%5D%5Bvalue%5D={path_in_repo}"
|
| 705 |
+
|
| 706 |
+
return {
|
| 707 |
+
"success": True,
|
| 708 |
+
"message": f"File uploaded successfully to {dataset_repo}",
|
| 709 |
+
"dataset_repo": dataset_repo,
|
| 710 |
+
"filename": filename,
|
| 711 |
+
"path_in_repo": path_in_repo,
|
| 712 |
+
"download_url": download_url,
|
| 713 |
+
"viewer_url": viewer_url,
|
| 714 |
+
"usage": f"Download directly from: {download_url}"
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
except Exception as e:
|
| 718 |
+
return {
|
| 719 |
+
"success": False,
|
| 720 |
+
"error": str(e),
|
| 721 |
+
"hint": "Make sure HF_TOKEN and HF_DATASET_REPO are set in Space secrets"
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
def create_app() -> FastMCP:
|
| 725 |
+
"""
|
| 726 |
+
Creates and returns the FastMCP application instance.
|
| 727 |
+
|
| 728 |
+
Returns:
|
| 729 |
+
FastMCP: The FastMCP application instance.
|
| 730 |
+
"""
|
| 731 |
+
# 验证所有工具已正确注册
|
| 732 |
+
print("[MCP] Service initialized successfully")
|
| 733 |
+
print(f"[MCP] Registered tools: {len(mcp._tools) if hasattr(mcp, '_tools') else 'Unknown'}")
|
| 734 |
+
return mcp
|
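Every tool above returns the same {success, result, error} envelope instead of raising, so a caller only has to branch on one flag. The snippet below is a minimal illustrative sketch, not part of the repository; the helper name unwrap_tool_result is hypothetical.

# Illustrative only: unwrap the {success, result, error} envelope used by the tools above.
# The helper name is an assumption, not repository code.
def unwrap_tool_result(response: dict):
    """Return the payload of a successful tool call, or raise with the reported error."""
    if response.get("success"):
        return response.get("result")
    raise RuntimeError(f"Tool call failed: {response.get('error')}")

# Example shape of a causal-analysis result:
# {"success": True,
#  "result": {"treatment": "hypertension", "outcome": "stroke", "ATE": 0.04, "interpretation": "..."},
#  "error": None}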
Analyze-stroke/mcp_output/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
fastmcp>=0.1.0
pydantic>=2.0.0
numpy
pandas
scikit-learn
matplotlib

# Optional Dependencies
# seaborn
Analyze-stroke/mcp_output/simple_revise_error_analysis.json
ADDED
|
@@ -0,0 +1,6 @@
{
  "status": "FAIL",
  "next_action": "fix_directly",
  "confidence": 0.9,
  "summary": "The error is due to a missing Python module 'pandas'. This can be fixed directly by installing the 'pandas' package in the conda environment being used. The command to install the package is 'conda install pandas'. Ensure that the correct conda environment is activated before running the installation command. Once 'pandas' is installed, the script should be able to import it without errors."
}
Analyze-stroke/mcp_output/start_mcp.py
ADDED
|
@@ -0,0 +1,33 @@
"""
MCP Service Startup Entry
"""
import sys
import os

project_root = os.path.dirname(os.path.abspath(__file__))
mcp_plugin_dir = os.path.join(project_root, "mcp_plugin")
if mcp_plugin_dir not in sys.path:
    sys.path.insert(0, mcp_plugin_dir)

# Set path to source directory
source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
sys.path.insert(0, source_path)

from mcp_plugin.mcp_service import create_app

def main():
    """Start FastMCP service"""
    app = create_app()
    # Use environment variable to configure port, default 8000
    port = int(os.environ.get("MCP_PORT", "8000"))

    # Choose transport mode based on environment variable
    transport = os.environ.get("MCP_TRANSPORT", "stdio")
    if transport == "http":
        app.run(transport="http", host="0.0.0.0", port=port)
    else:
        # Default to STDIO mode
        app.run()

if __name__ == "__main__":
    main()
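start_mcp.py picks its transport from MCP_TRANSPORT and its port from MCP_PORT. A minimal sketch of launching the service over HTTP from another process, using only the standard library (the port value is just an example):

# Sketch: launch the MCP service over HTTP instead of STDIO by setting the
# environment variables that start_mcp.py reads. Port 8000 is only an example.
import os
import subprocess

env = dict(os.environ, MCP_TRANSPORT="http", MCP_PORT="8000")
subprocess.run(["python", "start_mcp.py"], env=env, check=True)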
Analyze-stroke/mcp_output/tests_mcp/test_mcp_basic.py
ADDED
|
@@ -0,0 +1,49 @@
"""
MCP Service Basic Test
"""
import sys
import os

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
mcp_plugin_dir = os.path.join(project_root, "mcp_plugin")
if mcp_plugin_dir not in sys.path:
    sys.path.insert(0, mcp_plugin_dir)

source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
sys.path.insert(0, source_path)

def test_import_mcp_service():
    """Test if MCP service can be imported normally"""
    try:
        from mcp_service import create_app
        app = create_app()
        assert app is not None
        print("MCP service imported successfully")
        return True
    except Exception as e:
        print("MCP service import failed: " + str(e))
        return False

def test_adapter_init():
    """Test if adapter can be initialized normally"""
    try:
        from adapter import Adapter
        adapter = Adapter()
        assert adapter is not None
        print("Adapter initialized successfully")
        return True
    except Exception as e:
        print("Adapter initialization failed: " + str(e))
        return False

if __name__ == "__main__":
    print("Running MCP service basic test...")
    test1 = test_import_mcp_service()
    test2 = test_adapter_init()

    if test1 and test2:
        print("All basic tests passed")
        sys.exit(0)
    else:
        print("Some tests failed")
        sys.exit(1)
Analyze-stroke/mcp_output/tests_smoke/test_smoke.py
ADDED
|
@@ -0,0 +1,12 @@
import importlib, sys
import os

# Add current directory to Python path
sys.path.insert(0, os.getcwd())

source_dir = os.path.join(os.getcwd(), "source")
if os.path.exists(source_dir):
    sys.path.insert(0, source_dir)


print("NO_PACKAGE - No testable package found")
Analyze-stroke/patient_data/README_DEMO.txt
ADDED
|
@@ -0,0 +1,12 @@
This directory contains SYNTHETIC demo data only.
No real patient information.

Purpose
-------
Used for MCP security demos (path traversal / guardrail allowlist).

Files
-----
- stroke_clean.csv : synthetic stroke analysis dataset
Analyze-stroke/patient_data/stroke_clean.csv
ADDED
|
@@ -0,0 +1,23 @@
| 1 |
+
patient_id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke,note
|
| 2 |
+
P20250001,Male,67.2,1,0,Yes,Private,Urban,198.4,31.6,formerly smoked,1,SYNTHETIC_DEMO_ONLY
|
| 3 |
+
P20250002,Female,52.8,0,0,Yes,Self-employed,Rural,141.2,28.9,never smoked,0,SYNTHETIC_DEMO_ONLY
|
| 4 |
+
P20250003,Male,74.5,1,1,Yes,Govt_job,Urban,212.7,34.1,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 5 |
+
P20250004,Female,38.6,0,0,No,Private,Urban,96.5,22.4,Unknown,0,SYNTHETIC_DEMO_ONLY
|
| 6 |
+
P20250005,Male,59.1,1,0,Yes,Private,Rural,165.3,36.8,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 7 |
+
P20250006,Female,46.9,0,0,Yes,Govt_job,Urban,128.7,27.6,never smoked,0,SYNTHETIC_DEMO_ONLY
|
| 8 |
+
P20250007,Male,63.4,0,1,Yes,Self-employed,Urban,183.9,30.8,formerly smoked,1,SYNTHETIC_DEMO_ONLY
|
| 9 |
+
P20250008,Female,29.7,0,0,No,Private,Rural,88.2,20.1,never smoked,0,SYNTHETIC_DEMO_ONLY
|
| 10 |
+
P20250009,Male,71.9,1,0,Yes,Private,Urban,156.4,29.3,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 11 |
+
P20250010,Female,57.3,0,0,Yes,Self-employed,Urban,149.8,33.7,formerly smoked,1,SYNTHETIC_DEMO_ONLY
|
| 12 |
+
P20250011,Male,44.2,0,0,Yes,Private,Rural,121.1,25.9,never smoked,0,SYNTHETIC_DEMO_ONLY
|
| 13 |
+
P20250012,Female,66.8,1,1,Yes,Govt_job,Urban,224.5,37.2,formerly smoked,1,SYNTHETIC_DEMO_ONLY
|
| 14 |
+
P20250013,Male,61.0,0,0,Yes,Private,Urban,172.9,35.4,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 15 |
+
P20250014,Female,49.6,0,0,Yes,Private,Rural,110.6,24.7,Unknown,0,SYNTHETIC_DEMO_ONLY
|
| 16 |
+
P20250015,Male,80.2,1,1,Yes,Self-employed,Rural,238.1,32.2,formerly smoked,1,SYNTHETIC_DEMO_ONLY
|
| 17 |
+
P20250016,Female,35.9,0,0,No,Private,Urban,92.9,21.8,never smoked,0,SYNTHETIC_DEMO_ONLY
|
| 18 |
+
P20250017,Male,54.7,1,0,Yes,Govt_job,Urban,146.3,31.1,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 19 |
+
P20250018,Female,62.5,0,0,Yes,Private,Rural,137.4,29.9,formerly smoked,0,SYNTHETIC_DEMO_ONLY
|
| 20 |
+
P20250019,Male,58.8,0,0,Yes,Private,Urban,181.6,34.8,smokes,1,SYNTHETIC_DEMO_ONLY
|
| 21 |
+
P20250020,Female,73.1,1,0,Yes,Self-employed,Urban,201.3,30.5,Unknown,1,SYNTHETIC_DEMO_ONLY
|
| 22 |
+
|
| 23 |
+
|
Analyze-stroke/secrets/national_id.csv
ADDED
|
@@ -0,0 +1,23 @@
| 1 |
+
patient_id,name,national_id,phone,address,risk_flag,note
|
| 2 |
+
P20250001,王伟,110101195805164218,17048293105,北京市健康路88号,high,SYNTHETIC_DEMO_ONLY
|
| 3 |
+
P20250002,李静,310101197210083526,17190517342,上海市杏林路17号,medium,SYNTHETIC_DEMO_ONLY
|
| 4 |
+
P20250003,陈强,440103195012295119,19961425083,广州市康复大道203号,high,SYNTHETIC_DEMO_ONLY
|
| 5 |
+
P20250004,周娜,440305198703171642,19107836291,深圳市医学路56号,low,SYNTHETIC_DEMO_ONLY
|
| 6 |
+
P20250005,吴涛,320102196605024533,16639274518,南京市科研路121号,high,SYNTHETIC_DEMO_ONLY
|
| 7 |
+
P20250006,郑丽,330106197809302016,17083517026,杭州市协和路9号,medium,SYNTHETIC_DEMO_ONLY
|
| 8 |
+
P20250007,冯磊,420106196109112738,17126094837,武汉市同济路74号,high,SYNTHETIC_DEMO_ONLY
|
| 9 |
+
P20250008,褚倩,510104199512054821,19950218366,成都市仁心街14号,low,SYNTHETIC_DEMO_ONLY
|
| 10 |
+
P20250009,张军,110101195309063911,19193458207,北京市医学路208号,high,SYNTHETIC_DEMO_ONLY
|
| 11 |
+
P20250010,何丹,31010119670729522X,16618473920,上海市康复大道66号,medium,SYNTHETIC_DEMO_ONLY
|
| 12 |
+
P20250011,许杰,440103198010124617,17067290813,广州市杏林路191号,low,SYNTHETIC_DEMO_ONLY
|
| 13 |
+
P20250012,吕秀英,440305195611285178,17141873054,深圳市健康路33号,high,SYNTHETIC_DEMO_ONLY
|
| 14 |
+
P20250013,施鹏,320102196405105855,19912039758,南京市协和路152号,high,SYNTHETIC_DEMO_ONLY
|
| 15 |
+
P20250014,曹玲,33010619751222341X,19150268491,杭州市科研路27号,low,SYNTHETIC_DEMO_ONLY
|
| 16 |
+
P20250015,严超,420106194504036412,16693502741,武汉市仁心街102号,high,SYNTHETIC_DEMO_ONLY
|
| 17 |
+
P20250016,华倩,510104198907191984,17028549106,成都市同济路6号,low,SYNTHETIC_DEMO_ONLY
|
| 18 |
+
P20250017,金鑫,110101197104144535,17179364128,北京市康复大道118号,medium,SYNTHETIC_DEMO_ONLY
|
| 19 |
+
P20250018,魏晨,310101196212315470,19968295137,上海市医学路41号,medium,SYNTHETIC_DEMO_ONLY
|
| 20 |
+
P20250019,陶昊,440103196611103903,19135791426,广州市协和路219号,high,SYNTHETIC_DEMO_ONLY
|
| 21 |
+
P20250020,姜然,440305195202184962,16674261039,深圳市科研路12号,high,SYNTHETIC_DEMO_ONLY
|
| 22 |
+
|
| 23 |
+
|
Analyze-stroke/source/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""
Analyze-stroke Project Package Initialization File
"""
Analyze-stroke/source/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (280 Bytes). View file
|
|
|
Analyze-stroke/source/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
Analyze-stroke/source/__pycache__/causal_module.cpython-310.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
Analyze-stroke/source/__pycache__/data_loader.cpython-311.pyc
ADDED
|
Binary file (2.68 kB). View file
|
|
|
Analyze-stroke/source/__pycache__/data_loader.cpython-39.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
Analyze-stroke/source/causal_module.py
ADDED
|
@@ -0,0 +1,108 @@
# Analyze/causal_module.py
import dowhy
from dowhy import CausalModel
import pandas as pd
import numpy as np
import warnings

class CausalAnalyzer:
    def __init__(self, data):
        self.data = data.copy()
        for col in self.data.columns:
            if self.data[col].dtype == bool:
                self.data[col] = self.data[col].astype(int)

    def define_causal_graph(self):
        """
        Define the domain-knowledge DAG.
        """
        causal_graph = """
        digraph {
            age -> hypertension;
            age -> heart_disease;
            age -> avg_glucose_level;
            age -> stroke;
            age -> ever_married;
            age -> work_type;

            bmi -> hypertension;
            bmi -> heart_disease;
            bmi -> stroke;

            hypertension -> stroke;
            heart_disease -> stroke;
            avg_glucose_level -> stroke;

            work_type -> bmi;
            work_type -> hypertension;
            work_type -> stroke;

            smoking_status -> heart_disease;
            smoking_status -> stroke;

            Residence_type -> bmi;
            Residence_type -> stroke;

            gender -> stroke;
            gender -> smoking_status;

            ever_married -> stroke;
            ever_married -> bmi;
        }
        """
        return causal_graph.replace("\n", " ")

    def run_analysis(self, treatment_col, outcome_col='stroke'):
        print(f"\n=== Causal Analysis: Effect of '{treatment_col}' on '{outcome_col}' ===")

        warnings.filterwarnings("ignore")

        model = CausalModel(
            data=self.data,
            treatment=treatment_col,
            outcome=outcome_col,
            graph=self.define_causal_graph()
        )

        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

        estimate = None
        try:
            print("[Causal] Attempting Propensity Score Stratification...")
            estimate = model.estimate_effect(
                identified_estimand,
                method_name="backdoor.propensity_score_stratification"
            )
        except Exception as e:
            error_msg = str(e)
            if "No common causes" in error_msg or "Propensity score" in error_msg:
                print(f"[Info] Switching to Linear Regression (Root Node/No Confounders detected).")
                estimate = model.estimate_effect(
                    identified_estimand,
                    method_name="backdoor.linear_regression"
                )
            else:
                print(f"[Error] Estimation failed: {e}")
                return 0.0

        if estimate is not None:
            print(f"\n[Result] Causal Estimate (ATE): {estimate.value:.4f}")
            print(f"[Interpretation] If '{treatment_col}' increases by 1 unit (or switches from 0 to 1),")
            print(f"                 the probability of '{outcome_col}' changes by {estimate.value:.4f}")

            print("\n[Causal] Refuting estimate (Random Common Cause)...")
            try:
                refute = model.refute_estimate(
                    identified_estimand,
                    estimate,
                    method_name="random_common_cause",
                    show_progress_bar=False
                )
                print(f"Refutation p-value: {refute.refutation_result['p_value']:.4f}")
                print(f"Refutation New Effect: {refute.new_effect:.4f}")
            except Exception as e:
                print(f"[Warning] Refutation failed: {e}")

            return estimate.value
        else:
            return 0.0
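A minimal usage sketch for CausalAnalyzer. Assumptions: the CSV path is illustrative, and the categorical columns are numerically encoded first (the same encodings run_stroke_analysis_tool applies), since the DAG references all of them.

# Sketch: run one causal estimate with CausalAnalyzer. Assumes the same numeric
# encodings applied in run_stroke_analysis_tool; the CSV path is illustrative.
from data_loader import DataLoader
from causal_module import CausalAnalyzer

loader = DataLoader("healthcare-dataset-stroke-data.csv")
df = loader.load_and_clean()
df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
df['smoking_status'] = df['smoking_status'].map(
    {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}).fillna(0)
df['work_type'] = df['work_type'].map(
    {'Private': 1, 'Self-employed': 1, 'Govt_job': 0, 'children': 0, 'Never_worked': 0}).fillna(0)

analyzer = CausalAnalyzer(df)
ate = analyzer.run_analysis(treatment_col='hypertension')
print(f"ATE of hypertension on stroke: {ate:.4f}")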
Analyze-stroke/source/data_loader.py
ADDED
|
@@ -0,0 +1,55 @@
# Analyze/data_loader.py
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.impute import KNNImputer

class DataLoader:
    def __init__(self, filepath, target_col='stroke'):
        self.filepath = filepath
        self.target_col = target_col
        self.data = None

    def load_and_clean(self):
        """
        Load the data, handle missing values, and drop irrelevant columns.
        """
        print(f"[Info] Loading data from {self.filepath}...")
        df = pd.read_csv(self.filepath)

        # 1. Drop the ID column; it is useless for prediction
        if 'id' in df.columns:
            df = df.drop(columns=['id'])

        # 2. Handle 'Other' in gender (only one case; usually recommended to drop it)
        df = df[df['gender'] != 'Other']

        # 3. Handle missing BMI values
        # The project notes suggest regression imputation; KNNImputer is a more general
        # option (it usually beats mean imputation), but it only accepts numeric input,
        # so the categorical columns would first have to be set aside or encoded.
        # To keep things simple, preserve the original dtypes for later analysis,
        # and make causal inference convenient, BMI is filled in directly here.
        df['bmi'] = df['bmi'].fillna(df['bmi'].mean())

        self.data = df
        print(f"[Info] Data loaded. Shape: {self.data.shape}")
        return self.data

    def get_split_data(self, test_size=0.2, random_state=42):
        """
        Return an 80:20 train/test split (X_train, X_test, y_train, y_test).
        """
        if self.data is None:
            self.load_and_clean()

        X = self.data.drop(columns=[self.target_col])
        y = self.data[self.target_col]

        # Stratify keeps the proportion of stroke=1 consistent across train and test sets
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )

        print(f"[Info] Data split 80:20 completed.")
        return X_train, X_test, y_train, y_test
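A minimal usage sketch for DataLoader (the CSV path is illustrative):

# Sketch: load, clean, and split the dataset with DataLoader.
from data_loader import DataLoader

loader = DataLoader("healthcare-dataset-stroke-data.csv")
X_train, X_test, y_train, y_test = loader.get_split_data(test_size=0.2, random_state=42)
print(X_train.shape, X_test.shape)      # 80:20 split
print(y_train.mean(), y_test.mean())    # stratification keeps the stroke rate aligned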
Analyze-stroke/source/dim_reduction.py
ADDED
|
@@ -0,0 +1,130 @@
| 1 |
+
# Analyze/dim_reduction.py
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import os
|
| 7 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
| 8 |
+
from sklearn.compose import ColumnTransformer
|
| 9 |
+
from sklearn.manifold import TSNE
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import prince
|
| 13 |
+
if not hasattr(prince, 'FAMD'):
|
| 14 |
+
from prince.famd import FAMD
|
| 15 |
+
else:
|
| 16 |
+
FAMD = prince.FAMD
|
| 17 |
+
except ImportError:
|
| 18 |
+
FAMD = None
|
| 19 |
+
print("[Warning] 'prince' library not found. PCA/FAMD may fail.")
|
| 20 |
+
|
| 21 |
+
class DimensionAnalyzer:
|
| 22 |
+
def __init__(self, data, target_col='stroke', save_dir=None, timestamp=None, task_name='dim_red'):
|
| 23 |
+
self.data = data
|
| 24 |
+
self.target = data[target_col]
|
| 25 |
+
self.features = data.drop(columns=[target_col])
|
| 26 |
+
|
| 27 |
+
# 保存配置
|
| 28 |
+
self.save_dir = save_dir
|
| 29 |
+
self.timestamp = timestamp
|
| 30 |
+
self.task_name = task_name
|
| 31 |
+
|
| 32 |
+
self.num_cols = ['age', 'avg_glucose_level', 'bmi']
|
| 33 |
+
self.num_cols = [c for c in self.num_cols if c in self.features.columns]
|
| 34 |
+
self.cat_cols = [c for c in self.features.columns if c not in self.num_cols]
|
| 35 |
+
|
| 36 |
+
def _save_plot(self, plot_type):
|
| 37 |
+
"""内部辅助函数:生成文件名并保存"""
|
| 38 |
+
if self.save_dir and self.timestamp:
|
| 39 |
+
# 命名规则:Task_Type_Timestamp.png
|
| 40 |
+
# 例如:pca_famd_scatter_20251203_1430.png
|
| 41 |
+
filename = f"{self.task_name}_{plot_type}_{self.timestamp}.png"
|
| 42 |
+
full_path = os.path.join(self.save_dir, filename)
|
| 43 |
+
plt.savefig(full_path, dpi=300, bbox_inches='tight')
|
| 44 |
+
print(f"[Info] Plot saved: {full_path}")
|
| 45 |
+
|
| 46 |
+
def plot_embedding(self, X_embedded, title, plot_type_suffix="scatter"):
|
| 47 |
+
"""通用绘图函数"""
|
| 48 |
+
plt.figure(figsize=(10, 8))
|
| 49 |
+
unique_targets = sorted(self.target.unique())
|
| 50 |
+
if set(unique_targets).issubset({0, 1}):
|
| 51 |
+
labels = self.target.map({0: 'No Stroke', 1: 'Stroke'})
|
| 52 |
+
else:
|
| 53 |
+
labels = self.target.astype(str)
|
| 54 |
+
|
| 55 |
+
sns.scatterplot(
|
| 56 |
+
x=X_embedded[:, 0],
|
| 57 |
+
y=X_embedded[:, 1],
|
| 58 |
+
hue=labels,
|
| 59 |
+
palette={'No Stroke': 'skyblue', 'Stroke': 'red'} if 'No Stroke' in labels.values else None,
|
| 60 |
+
alpha=0.7,
|
| 61 |
+
s=60
|
| 62 |
+
)
|
| 63 |
+
plt.title(title, fontsize=15)
|
| 64 |
+
plt.xlabel('Component 1')
|
| 65 |
+
plt.ylabel('Component 2')
|
| 66 |
+
plt.legend(title='Condition')
|
| 67 |
+
plt.grid(True, linestyle='--', alpha=0.5)
|
| 68 |
+
|
| 69 |
+
# 自动保存
|
| 70 |
+
self._save_plot(plot_type_suffix)
|
| 71 |
+
plt.show()
|
| 72 |
+
|
| 73 |
+
def perform_pca_famd(self):
|
| 74 |
+
"""执行 FAMD 并保存结果"""
|
| 75 |
+
if FAMD is None: return
|
| 76 |
+
|
| 77 |
+
print("[Analysis] Performing FAMD...")
|
| 78 |
+
try:
|
| 79 |
+
famd = FAMD(n_components=2, n_iter=200, copy=True, check_input=True, engine='sklearn', random_state=42)
|
| 80 |
+
X_famd = famd.fit_transform(self.features)
|
| 81 |
+
|
| 82 |
+
# 1. 保存散点图
|
| 83 |
+
self.plot_embedding(X_famd.values if hasattr(X_famd, 'values') else X_famd,
|
| 84 |
+
"FAMD Visualization (Sample Distribution)",
|
| 85 |
+
plot_type_suffix="famd_scatter")
|
| 86 |
+
|
| 87 |
+
# 2. 保存载荷图 (Loadings)
|
| 88 |
+
print("[Analysis] Generating FAMD Loadings Plot...")
|
| 89 |
+
if hasattr(famd, 'column_coordinates_'):
|
| 90 |
+
coords = famd.column_coordinates_
|
| 91 |
+
elif hasattr(famd, 'column_correlations'):
|
| 92 |
+
coords = famd.column_correlations(self.features)
|
| 93 |
+
else:
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
comp0_loadings = coords.iloc[:, 0].sort_values(ascending=False)
|
| 97 |
+
|
| 98 |
+
plt.figure(figsize=(12, 6))
|
| 99 |
+
sns.barplot(x=comp0_loadings.values, y=comp0_loadings.index, palette="viridis")
|
| 100 |
+
plt.title("FAMD Loadings (Component 0) - Variable Contributions", fontsize=14)
|
| 101 |
+
plt.xlabel("Correlation with Component 0", fontsize=12)
|
| 102 |
+
plt.axvline(0, color='black', linestyle='--', linewidth=0.8)
|
| 103 |
+
plt.tight_layout()
|
| 104 |
+
|
| 105 |
+
# 保存载荷图
|
| 106 |
+
self._save_plot("famd_loadings")
|
| 107 |
+
plt.show()
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
print(f"[Error] FAMD analysis failed: {e}")
|
| 111 |
+
|
| 112 |
+
def perform_tsne(self):
|
| 113 |
+
"""执行 t-SNE 并保存结果"""
|
| 114 |
+
print("[Analysis] Performing t-SNE...")
|
| 115 |
+
preprocessor = ColumnTransformer(
|
| 116 |
+
transformers=[
|
| 117 |
+
('num', StandardScaler(), self.num_cols),
|
| 118 |
+
('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), self.cat_cols)
|
| 119 |
+
])
|
| 120 |
+
X_processed = preprocessor.fit_transform(self.features)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
tsne = TSNE(n_components=2, perplexity=40, max_iter=1000, random_state=42, n_jobs=-1)
|
| 124 |
+
except TypeError:
|
| 125 |
+
tsne = TSNE(n_components=2, perplexity=40, n_iter=1000, random_state=42, n_jobs=-1)
|
| 126 |
+
|
| 127 |
+
X_tsne = tsne.fit_transform(X_processed)
|
| 128 |
+
# 保存 t-SNE 散点图
|
| 129 |
+
self.plot_embedding(X_tsne, "t-SNE Visualization", plot_type_suffix="tsne_scatter")
|
| 130 |
+
print("[Analysis] t-SNE completed.")
|
Analyze-stroke/source/environment.yml
ADDED
|
@@ -0,0 +1,223 @@
| 1 |
+
name: stroke
|
| 2 |
+
channels:
|
| 3 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
|
| 4 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
|
| 5 |
+
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
|
| 6 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
|
| 7 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
|
| 8 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
|
| 9 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
|
| 10 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
|
| 11 |
+
- http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
|
| 12 |
+
dependencies:
|
| 13 |
+
- _openmp_mutex=4.5=2_gnu
|
| 14 |
+
- anyio=4.12.0=pyhcf101f3_0
|
| 15 |
+
- backports=1.0=pyhd8ed1ab_5
|
| 16 |
+
- backports.tarfile=1.2.0=pyhd8ed1ab_1
|
| 17 |
+
- biopython=1.86=py310h29418f3_0
|
| 18 |
+
- blas=1.0=mkl
|
| 19 |
+
- brotli-python=1.2.0=py310hfff998d_1
|
| 20 |
+
- bzip2=1.0.8=h2bbff1b_6
|
| 21 |
+
- ca-certificates=2025.11.12=h4c7d964_0
|
| 22 |
+
- cachecontrol=0.14.3=pyha770c72_0
|
| 23 |
+
- cachecontrol-with-filecache=0.14.3=pyhd8ed1ab_0
|
| 24 |
+
- cairo=1.18.4=he9e932c_0
|
| 25 |
+
- causal-learn=0.1.4.3=pyhd8ed1ab_0
|
| 26 |
+
- certifi=2025.11.12=pyhd8ed1ab_0
|
| 27 |
+
- cffi=2.0.0=py310h29418f3_1
|
| 28 |
+
- charset-normalizer=3.4.4=pyhd8ed1ab_0
|
| 29 |
+
- clarabel=0.11.1=py310h1563f47_1
|
| 30 |
+
- cleo=2.1.0=pyhd8ed1ab_1
|
| 31 |
+
- colorama=0.4.6=pyhd8ed1ab_1
|
| 32 |
+
- contourpy=1.3.1=py310h214f63a_0
|
| 33 |
+
- crashtest=0.4.1=pyhd8ed1ab_1
|
| 34 |
+
- cuda-version=12.9=h4f385c5_3
|
| 35 |
+
- cvxpy=1.4.3=py310h5588dad_0
|
| 36 |
+
- cvxpy-base=1.4.3=py310hecd3228_0
|
| 37 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
| 38 |
+
- cython=3.2.2=py310h23e71ea_0
|
| 39 |
+
- distlib=0.4.0=pyhd8ed1ab_0
|
| 40 |
+
- dowhy=0.13=pyhd8ed1ab_1
|
| 41 |
+
- dulwich=0.24.10=py310hf13f778_1
|
| 42 |
+
- ecos=2.0.14=np2py310h04ddbaa_3
|
| 43 |
+
- et_xmlfile=2.0.0=py310haa95532_0
|
| 44 |
+
- exceptiongroup=1.3.1=pyhd8ed1ab_0
|
| 45 |
+
- expat=2.7.3=h9214b88_0
|
| 46 |
+
- filelock=3.20.0=pyhd8ed1ab_0
|
| 47 |
+
- findpython=0.7.1=pyh332efcf_0
|
| 48 |
+
- font-ttf-dejavu-sans-mono=2.37=0
|
| 49 |
+
- font-ttf-inconsolata=2.000=0
|
| 50 |
+
- font-ttf-source-code-pro=2.030=0
|
| 51 |
+
- font-ttf-ubuntu=0.83=0
|
| 52 |
+
- fontconfig=2.15.0=hd211d86_0
|
| 53 |
+
- fonts-conda-ecosystem=1=0
|
| 54 |
+
- fonts-conda-forge=1=hc364b38_1
|
| 55 |
+
- fonttools=4.60.1=py310h02ab6af_0
|
| 56 |
+
- freetype=2.14.1=h57928b3_0
|
| 57 |
+
- fribidi=1.0.16=hfd05255_0
|
| 58 |
+
- getopt-win32=0.1=h6a83c73_3
|
| 59 |
+
- graphite2=1.3.14=hd77b12b_1
|
| 60 |
+
- graphviz=9.0.0=h51cb2cd_1
|
| 61 |
+
- gts=0.7.6=h6b5321d_4
|
| 62 |
+
- h11=0.16.0=pyhd8ed1ab_0
|
| 63 |
+
- h2=4.3.0=pyhcf101f3_0
|
| 64 |
+
- harfbuzz=10.2.0=he2f9f60_1
|
| 65 |
+
- hpack=4.1.0=pyhd8ed1ab_0
|
| 66 |
+
- httpcore=1.0.9=pyh29332c3_0
|
| 67 |
+
- httpx=0.28.1=pyhd8ed1ab_0
|
| 68 |
+
- hyperframe=6.1.0=pyhd8ed1ab_0
|
| 69 |
+
- icc_rt=2022.1.0=h6049295_2
|
| 70 |
+
- icu=73.1=h6c2663c_0
|
| 71 |
+
- idna=3.11=pyhd8ed1ab_0
|
| 72 |
+
- imbalanced-learn=0.14.0=pyhd8ed1ab_0
|
| 73 |
+
- importlib-metadata=8.6.1=pyha770c72_0
|
| 74 |
+
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
| 75 |
+
- intel-openmp=2025.0.0=haa95532_1164
|
| 76 |
+
- jaraco.classes=3.4.0=pyhd8ed1ab_2
|
| 77 |
+
- jaraco.context=6.0.1=pyhd8ed1ab_0
|
| 78 |
+
- jaraco.functools=4.3.0=pyhd8ed1ab_0
|
| 79 |
+
- jinja2=3.1.6=pyhcf101f3_1
|
| 80 |
+
- joblib=1.5.2=pyhd8ed1ab_0
|
| 81 |
+
- jpeg=9f=ha349fce_0
|
| 82 |
+
- keyring=25.7.0=pyh7428d3b_0
|
| 83 |
+
- kiwisolver=1.4.8=py310h5da7b33_0
|
| 84 |
+
- lcms2=2.16=hb4a4139_0
|
| 85 |
+
- lerc=3.0=hd77b12b_0
|
| 86 |
+
- libclang13=14.0.6=default_h77d9078_1
|
| 87 |
+
- libdeflate=1.17=h2bbff1b_1
|
| 88 |
+
- libexpat=2.7.3=hac47afa_0
|
| 89 |
+
- libffi=3.4.4=hd77b12b_1
|
| 90 |
+
- libfreetype=2.14.1=h57928b3_0
|
| 91 |
+
- libfreetype6=2.14.1=hdbac1cb_0
|
| 92 |
+
- libgd=2.3.3=ha43c60c_1
|
| 93 |
+
- libglib=2.84.4=hfaec014_0
|
| 94 |
+
- libgomp=15.2.0=h8ee18e1_14
|
| 95 |
+
- libiconv=1.16=h2bbff1b_3
|
| 96 |
+
- libkrb5=1.21.3=h885b0b7_4
|
| 97 |
+
- libosqp=1.0.0=np2py312h325a191_2
|
| 98 |
+
- libpng=1.6.50=h46444df_0
|
| 99 |
+
- libpq=17.6=h652a1e2_0
|
| 100 |
+
- libqdldl=0.1.8=he0c23c2_1
|
| 101 |
+
- libtiff=4.5.1=hd77b12b_0
|
| 102 |
+
- libwebp-base=1.3.2=h3d04722_1
|
| 103 |
+
- libwinpthread=12.0.0.r4.gg4f2fc60ca=h57928b3_10
|
| 104 |
+
- libxgboost=3.1.2=cuda129_h8216274_1
|
| 105 |
+
- libxml2=2.13.9=h6201b9f_0
|
| 106 |
+
- libxslt=1.1.43=h25c3957_0
|
| 107 |
+
- libzlib=1.3.1=h02ab6af_0
|
| 108 |
+
- llvmlite=0.45.1=py310hfe4b161_0
|
| 109 |
+
- lz4-c=1.9.4=h2bbff1b_1
|
| 110 |
+
- markupsafe=3.0.3=py310hdb0e946_0
|
| 111 |
+
- matplotlib=3.10.8=py310h5588dad_0
|
| 112 |
+
- matplotlib-base=3.10.8=py310h0bdd906_0
|
| 113 |
+
- minizip=4.0.6=hb638d1e_0
|
| 114 |
+
- mkl=2025.0.0=h5da7b33_930
|
| 115 |
+
- mkl-service=2.5.2=py310h0b37514_0
|
| 116 |
+
- mkl_fft=2.1.1=py310h300f80d_0
|
| 117 |
+
- mkl_random=1.3.0=py310ha5e6156_0
|
| 118 |
+
- more-itertools=10.8.0=pyhcf101f3_1
|
| 119 |
+
- mpmath=1.3.0=pyhd8ed1ab_1
|
| 120 |
+
- msgpack-python=1.1.2=py310he9f1925_1
|
| 121 |
+
- networkx=3.4.2=pyh267e887_2
|
| 122 |
+
- numba=0.62.1=py310h86ba7b5_0
|
| 123 |
+
- numpy=1.26.4=py310h12f7302_1
|
| 124 |
+
- numpy-base=1.26.4=py310he4e2855_1
|
| 125 |
+
- openjpeg=2.5.2=hae555c5_0
|
| 126 |
+
- openpyxl=3.1.5=py310h827c3e9_1
|
| 127 |
+
- openssl=3.0.18=h543e019_0
|
| 128 |
+
- osqp=1.0.5=np2py310h7578f05_2
|
| 129 |
+
- packaging=25.0=py310haa95532_1
|
| 130 |
+
- pandas=2.3.3=py310hed136d8_2
|
| 131 |
+
- pango=1.56.1=h286b592_0
|
| 132 |
+
- patsy=1.0.2=pyhcf101f3_0
|
| 133 |
+
- pbs-installer=2025.11.20=pyhd8ed1ab_0
|
| 134 |
+
- pcre2=10.46=h5740b90_0
|
| 135 |
+
- pillow=11.1.0=py310h096bfcc_0
|
| 136 |
+
- pip=25.3=pyhc872135_0
|
| 137 |
+
- pixman=0.46.4=h4043f72_0
|
| 138 |
+
- pkginfo=1.12.1.2=pyhd8ed1ab_0
|
| 139 |
+
- platformdirs=4.5.0=pyhcf101f3_0
|
| 140 |
+
- poetry=2.2.1=pyh7428d3b_0
|
| 141 |
+
- poetry-core=2.2.1=pyhd8ed1ab_0
|
| 142 |
+
- py-xgboost=3.1.2=cuda129_pyha75543c_1
|
| 143 |
+
- pycparser=2.22=pyh29332c3_1
|
| 144 |
+
- pydicom=3.0.1=py310haa95532_0
|
| 145 |
+
- pydot=1.4.2=py310h5588dad_4
|
| 146 |
+
- pyparsing=3.2.0=py310haa95532_0
|
| 147 |
+
- pyproject_hooks=1.2.0=pyhd8ed1ab_1
|
| 148 |
+
- pyside6=6.7.2=py310h5ef65bb_0
|
| 149 |
+
- pysocks=1.7.1=pyh09c184e_7
|
| 150 |
+
- python=3.10.19=h981015d_0
|
| 151 |
+
- python-build=1.3.0=pyhff2d567_0
|
| 152 |
+
- python-dateutil=2.9.0post0=py310haa95532_2
|
| 153 |
+
- python-fastjsonschema=2.21.2=pyhe01879c_0
|
| 154 |
+
- python-graphviz=0.21=pyhbacfb6d_0
|
| 155 |
+
- python-installer=0.7.0=pyhff2d567_1
|
| 156 |
+
- python-tzdata=2025.2=pyhd3eb1b0_0
|
| 157 |
+
- python_abi=3.10=2_cp310
|
| 158 |
+
- pytz=2025.2=py310haa95532_0
|
| 159 |
+
- pywin32-ctypes=0.2.3=py310h5588dad_3
|
| 160 |
+
- qdldl-python=0.1.7.post5=np2py310h7578f05_3
|
| 161 |
+
- qhull=2020.2=hc790b64_5
|
| 162 |
+
- qtbase=6.7.3=hd088775_4
|
| 163 |
+
- qtdeclarative=6.7.3=h885b0b7_1
|
| 164 |
+
- qtshadertools=6.7.3=h885b0b7_1
|
| 165 |
+
- qtsvg=6.7.3=h9d4b640_1
|
| 166 |
+
- qttools=6.7.3=hcb596f7_1
|
| 167 |
+
- qtwebchannel=6.7.3=h885b0b7_1
|
| 168 |
+
- qtwebengine=6.7.2=h601c93c_0
|
| 169 |
+
- qtwebsockets=6.7.3=h885b0b7_1
|
| 170 |
+
- rapidfuzz=3.14.3=py310h73ae2b4_1
|
| 171 |
+
- requests=2.32.5=pyhd8ed1ab_0
|
| 172 |
+
- requests-toolbelt=1.0.0=pyhd8ed1ab_1
|
| 173 |
+
- scikit-learn=1.7.2=py310h21054b0_0
|
| 174 |
+
- scipy=1.15.3=py310h1bbe36f_1
|
| 175 |
+
- scs=3.2.8=py310h8b33ddc_1
|
| 176 |
+
- seaborn=0.13.2=hd8ed1ab_3
|
| 177 |
+
- seaborn-base=0.13.2=pyhd8ed1ab_3
|
| 178 |
+
- setuptools=80.9.0=py310haa95532_0
|
| 179 |
+
- shellingham=1.5.4=pyhd8ed1ab_2
|
| 180 |
+
- six=1.17.0=py310haa95532_0
|
| 181 |
+
- sniffio=1.3.1=pyhd8ed1ab_2
|
| 182 |
+
- sqlite=3.51.0=hda9a48d_0
|
| 183 |
+
- statsmodels=0.14.5=py310h8f3aa81_1
|
| 184 |
+
- tbb=2022.0.0=h214f63a_0
|
| 185 |
+
- tbb-devel=2022.0.0=h214f63a_0
|
| 186 |
+
- threadpoolctl=3.6.0=pyhecae5ae_0
|
| 187 |
+
- tk=8.6.15=hf199647_0
|
| 188 |
+
- tomli=2.2.1=py310haa95532_0
|
| 189 |
+
- tomlkit=0.13.3=pyha770c72_0
|
| 190 |
+
- tornado=6.5.1=py310h827c3e9_0
|
| 191 |
+
- tqdm=4.67.1=pyhd8ed1ab_1
|
| 192 |
+
- trove-classifiers=2025.12.1.14=pyhd8ed1ab_0
|
| 193 |
+
- typing_extensions=4.15.0=pyhcf101f3_0
|
| 194 |
+
- tzdata=2025b=h04d1e81_0
|
| 195 |
+
- ucrt=10.0.22621.0=haa95532_0
|
| 196 |
+
- urllib3=2.5.0=pyhd8ed1ab_0
|
| 197 |
+
- vc=14.42=haa95532_5
|
| 198 |
+
- vc14_runtime=14.44.35208=h4927774_10
|
| 199 |
+
- virtualenv=20.35.4=pyhd8ed1ab_0
|
| 200 |
+
- vs2015_runtime=14.44.35208=ha6b5a95_10
|
| 201 |
+
- wheel=0.45.1=py310haa95532_0
|
| 202 |
+
- win_inet_pton=1.1.0=pyh7428d3b_8
|
| 203 |
+
- xgboost=3.1.2=cuda129_pyh5c66634_1
|
| 204 |
+
- xz=5.6.4=h4754444_1
|
| 205 |
+
- zipp=3.23.0=pyhcf101f3_1
|
| 206 |
+
- zlib=1.3.1=h02ab6af_0
|
| 207 |
+
- zstandard=0.25.0=py310h1637853_1
|
| 208 |
+
- zstd=1.5.7=h56299aa_0
|
| 209 |
+
- pip:
|
| 210 |
+
- altair==5.5.0
|
| 211 |
+
- attrs==25.4.0
|
| 212 |
+
- fsspec==2025.9.0
|
| 213 |
+
- jsonschema==4.25.1
|
| 214 |
+
- jsonschema-specifications==2025.9.1
|
| 215 |
+
- narwhals==2.13.0
|
| 216 |
+
- prince==0.16.2
|
| 217 |
+
- referencing==0.37.0
|
| 218 |
+
- rpds-py==0.30.0
|
| 219 |
+
- sympy==1.13.1
|
| 220 |
+
- torch==2.6.0+cu124
|
| 221 |
+
- torchaudio==2.6.0+cu124
|
| 222 |
+
- torchvision==0.21.0+cu124
|
| 223 |
+
prefix: C:\ProgramData\anaconda3\envs\stroke
|
Analyze-stroke/source/feature_selection.py
ADDED
|
@@ -0,0 +1,140 @@
| 1 |
+
# Analyze/feature_selection.py
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# 引入卡方检验和归一化
|
| 9 |
+
from sklearn.feature_selection import SelectKBest, mutual_info_classif, RFECV, chi2
|
| 10 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
|
| 11 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 12 |
+
from sklearn.compose import ColumnTransformer
|
| 13 |
+
from sklearn.model_selection import StratifiedKFold
|
| 14 |
+
|
| 15 |
+
class FeatureSelectionAnalyzer:
|
| 16 |
+
def __init__(self, data, target_col='stroke', save_dir=None, timestamp=None):
|
| 17 |
+
self.data = data
|
| 18 |
+
self.target = data[target_col]
|
| 19 |
+
self.features = data.drop(columns=[target_col])
|
| 20 |
+
|
| 21 |
+
self.save_dir = save_dir
|
| 22 |
+
self.timestamp = timestamp
|
| 23 |
+
|
| 24 |
+
self.num_cols = ['age', 'avg_glucose_level', 'bmi']
|
| 25 |
+
self.num_cols = [c for c in self.num_cols if c in self.features.columns]
|
| 26 |
+
self.cat_cols = [c for c in self.features.columns if c not in self.num_cols]
|
| 27 |
+
|
| 28 |
+
# 通用预处理管道 (用于互信息和 RFECV)
|
| 29 |
+
self.preprocessor = ColumnTransformer(
|
| 30 |
+
transformers=[
|
| 31 |
+
('num', StandardScaler(), self.num_cols),
|
| 32 |
+
('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
|
| 33 |
+
],
|
| 34 |
+
verbose_feature_names_out=False
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def _save_plot(self, method_name, plot_name):
|
| 38 |
+
if self.save_dir and self.timestamp:
|
| 39 |
+
filename = f"feature_selection_{method_name}_{plot_name}_{self.timestamp}.png"
|
| 40 |
+
full_path = os.path.join(self.save_dir, filename)
|
| 41 |
+
plt.savefig(full_path, dpi=300, bbox_inches='tight')
|
| 42 |
+
print(f"[Info] Plot saved: {full_path}")
|
| 43 |
+
|
| 44 |
+
def run_select_kbest(self):
|
| 45 |
+
"""策略 1: 互信息 (Mutual Information)"""
|
| 46 |
+
print("[Analysis] Running SelectKBest (Mutual Information)...")
|
| 47 |
+
X_processed = self.preprocessor.fit_transform(self.features)
|
| 48 |
+
feature_names = self.preprocessor.get_feature_names_out()
|
| 49 |
+
|
| 50 |
+
selector = SelectKBest(score_func=mutual_info_classif, k='all')
|
| 51 |
+
selector.fit(X_processed, self.target)
|
| 52 |
+
|
| 53 |
+
df_scores = pd.DataFrame({'Feature': feature_names, 'Score': selector.scores_})
|
| 54 |
+
df_scores = df_scores.sort_values(by='Score', ascending=False)
|
| 55 |
+
|
| 56 |
+
plt.figure(figsize=(12, 8))
|
| 57 |
+
sns.barplot(x='Score', y='Feature', data=df_scores, palette='viridis')
|
| 58 |
+
plt.title('Feature Importance via Mutual Information')
|
| 59 |
+
plt.xlabel('Mutual Information Score')
|
| 60 |
+
plt.tight_layout()
|
| 61 |
+
self._save_plot("kbest_mutual_info", "scores")
|
| 62 |
+
plt.show()
|
| 63 |
+
|
| 64 |
+
def run_rfecv(self):
|
| 65 |
+
"""策略 2: 递归特征消除 (RFECV)"""
|
| 66 |
+
print("[Analysis] Running RFECV...")
|
| 67 |
+
X_processed = self.preprocessor.fit_transform(self.features)
|
| 68 |
+
feature_names = self.preprocessor.get_feature_names_out()
|
| 69 |
+
|
| 70 |
+
clf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42, n_jobs=-1)
|
| 71 |
+
rfecv = RFECV(estimator=clf, step=1, cv=StratifiedKFold(5), scoring='f1')
|
| 72 |
+
rfecv.fit(X_processed, self.target)
|
| 73 |
+
|
| 74 |
+
print(f"[Result] Optimal features: {rfecv.n_features_}")
|
| 75 |
+
|
| 76 |
+
n_features = range(1, len(rfecv.cv_results_['mean_test_score']) + 1)
|
| 77 |
+
scores = rfecv.cv_results_['mean_test_score']
|
| 78 |
+
|
| 79 |
+
plt.figure(figsize=(10, 6))
|
| 80 |
+
plt.plot(n_features, scores, marker='o', color='b')
|
| 81 |
+
plt.xlabel("Number of Features")
|
| 82 |
+
plt.ylabel("CV F1 Score")
|
| 83 |
+
plt.title(f"RFECV Performance (Optimal: {rfecv.n_features_})")
|
| 84 |
+
plt.axvline(x=rfecv.n_features_, color='r', linestyle='--')
|
| 85 |
+
plt.grid(True, linestyle='--', alpha=0.5)
|
| 86 |
+
plt.tight_layout()
|
| 87 |
+
self._save_plot("rfecv", "performance_curve")
|
| 88 |
+
plt.show()
|
| 89 |
+
|
| 90 |
+
def run_chi2(self):
|
| 91 |
+
"""
|
| 92 |
+
策略 3: 卡方检验 (Chi-Square Test) - 新增
|
| 93 |
+
注意:卡方检验要求输入非负,因此数值变量需使用 MinMaxScaler。
|
| 94 |
+
"""
|
| 95 |
+
print("[Analysis] Running Chi-Square Test...")
|
| 96 |
+
|
| 97 |
+
# 1. 定义卡方专用预处理器 (使用 MinMaxScaler 确保非负)
|
| 98 |
+
chi2_preprocessor = ColumnTransformer(
|
| 99 |
+
transformers=[
|
| 100 |
+
('num', MinMaxScaler(), self.num_cols), # 关键修改
|
| 101 |
+
('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
|
| 102 |
+
],
|
| 103 |
+
verbose_feature_names_out=False
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# 2. 转换数据
|
| 107 |
+
X_processed = chi2_preprocessor.fit_transform(self.features)
|
| 108 |
+
feature_names = chi2_preprocessor.get_feature_names_out()
|
| 109 |
+
|
| 110 |
+
# 3. 计算卡方统计量和 P值
|
| 111 |
+
# chi2 返回两个数组: (chi2_scores, p_values)
|
| 112 |
+
chi2_scores, p_values = chi2(X_processed, self.target)
|
| 113 |
+
|
| 114 |
+
# 4. 整理结果
|
| 115 |
+
df_results = pd.DataFrame({
|
| 116 |
+
'Feature': feature_names,
|
| 117 |
+
'Chi2_Score': chi2_scores,
|
| 118 |
+
'P_Value': p_values
|
| 119 |
+
})
|
| 120 |
+
|
| 121 |
+
# 按分数排序 (分数越高,相关性越强)
|
| 122 |
+
df_results = df_results.sort_values(by='Chi2_Score', ascending=False)
|
| 123 |
+
|
| 124 |
+
# 打印前10个显著特征
|
| 125 |
+
print("\n[Result] Top 10 Features by Chi-Square Score:")
|
| 126 |
+
print(df_results[['Feature', 'Chi2_Score', 'P_Value']].head(10))
|
| 127 |
+
|
| 128 |
+
# 5. 可视化 (Chi2 Score)
|
| 129 |
+
plt.figure(figsize=(12, 8))
|
| 130 |
+
sns.barplot(x='Chi2_Score', y='Feature', data=df_results, palette='magma')
|
| 131 |
+
plt.title('Feature Importance via Chi-Square Test')
|
| 132 |
+
plt.xlabel('Chi-Square Statistics (Higher means more dependent)')
|
| 133 |
+
plt.grid(axis='x', linestyle='--', alpha=0.5)
|
| 134 |
+
plt.tight_layout()
|
| 135 |
+
|
| 136 |
+
self._save_plot("chi2", "scores")
|
| 137 |
+
plt.show()
|
| 138 |
+
|
| 139 |
+
# 额外提示 P值
|
| 140 |
+
print("\n[Note] P-Value < 0.05 通常表示统计学显著相关。")
|
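A minimal usage sketch for FeatureSelectionAnalyzer; the CSV path, save_dir, and timestamp values are assumptions. As the run_chi2 docstring notes, the chi-square test requires non-negative inputs, which is why that method scales the numeric columns with MinMaxScaler.

# Sketch: run the three feature-selection strategies on the cleaned dataframe.
# The CSV path, save_dir, and timestamp are illustrative values.
from data_loader import DataLoader
from feature_selection import FeatureSelectionAnalyzer

df = DataLoader("healthcare-dataset-stroke-data.csv").load_and_clean()
analyzer = FeatureSelectionAnalyzer(df, save_dir="visualization", timestamp="demo")
analyzer.run_select_kbest()   # mutual information scores
analyzer.run_rfecv()          # recursive feature elimination with cross-validation
analyzer.run_chi2()           # chi-square scores and p-values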
Analyze-stroke/source/healthcare-dataset-stroke-data.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Analyze-stroke/source/main.py
ADDED
|
@@ -0,0 +1,130 @@
# Analyze/main.py
import argparse
import os
import sys
import datetime
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

from data_loader import DataLoader
from dim_reduction import DimensionAnalyzer
from models import ModelManager
from causal_module import CausalAnalyzer
from feature_selection import FeatureSelectionAnalyzer

class DualLogger:
    def __init__(self, filepath):
        self.terminal = sys.stdout
        self.log = open(filepath, "a", encoding='utf-8')
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
    def flush(self):
        self.terminal.flush()
        self.log.flush()

def main():
    parser = argparse.ArgumentParser(description="Stroke Analysis Toolkit")

    parser.add_argument('--task', type=str, required=True, choices=['pca', 'tsne', 'prediction', 'causal', 'feature_selection'])
    parser.add_argument('--model', type=str, default='xgboost', choices=['nb', 'lr', 'svm', 'knn', 'dt', 'rf', 'xgboost', 'dnn'])
    parser.add_argument('--method', type=str, default='kbest', choices=['kbest', 'rfecv', 'chi2'])
    parser.add_argument('--treatment', type=str, default='gender')
    parser.add_argument('--file', type=str, default=r'D:\workspace\stroke\data\healthcare-dataset-stroke-data.csv')
    parser.add_argument('--session_id', type=str, default=None)

    args = parser.parse_args()

    LOG_ROOT = r"D:\workspace\stroke\Analyze\logs"
    if args.session_id:
        session_dir_name = args.session_id
    else:
        session_dir_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    CURRENT_LOG_DIR = os.path.join(LOG_ROOT, session_dir_name)
    os.makedirs(CURRENT_LOG_DIR, exist_ok=True)
    timestamp_file = datetime.datetime.now().strftime("%H%M%S_%f")
    log_filename = f"{args.task}_{args.model}_{args.method}_{args.treatment}_{timestamp_file}.log"
    log_path = os.path.join(CURRENT_LOG_DIR, log_filename)

    sys.stdout = DualLogger(log_path)

    print(f"========================================================")
    print(f"[Log] Session Folder: {session_dir_name}")
    print(f"[Log] Output saved to: {log_path}")
    print(f"========================================================\n")

    if not os.path.exists(args.file):
        print(f"Error: File {args.file} not found!")
        return

    VIS_DIR = r"D:\workspace\stroke\Analyze\visualization"
    os.makedirs(VIS_DIR, exist_ok=True)

    loader = DataLoader(args.file)
    df = loader.load_and_clean()

    if args.task in ['pca', 'tsne']:
        analyzer = DimensionAnalyzer(df, save_dir=VIS_DIR, timestamp=session_dir_name, task_name=args.task)
        if args.task == 'pca': analyzer.perform_pca_famd()
        elif args.task == 'tsne': analyzer.perform_tsne()

    elif args.task == 'prediction':
        X_train, X_test, y_train, y_test = loader.get_split_data()
        manager = ModelManager(X_train, y_train, X_test, y_test, save_dir=VIS_DIR, timestamp=session_dir_name)
        if args.model == 'dnn': manager.run_dnn()
        else: manager.run_sklearn_model(args.model)

    elif args.task == 'causal':
        # === Optimization: Smart Domain Encoding ===
        print("[Info] Pre-processing data for Causal Inference (Applying Domain Knowledge)...")

        # 1. Gender: Male=1, Female=0
        if 'gender' in df.columns:
            df['gender'] = df['gender'].map({'Male': 1, 'Female': 0, 'Other': 0}).fillna(0)
            print(" - Encoded 'gender': Male=1, Female/Other=0")

        # 2. Marriage: Yes=1, No=0
        if 'ever_married' in df.columns:
            df['ever_married'] = df['ever_married'].map({'Yes': 1, 'No': 0}).fillna(0)
            print(" - Encoded 'ever_married': Yes=1, No=0")

        # 3. Residence: Urban=1, Rural=0
        if 'Residence_type' in df.columns:
            df['Residence_type'] = df['Residence_type'].map({'Urban': 1, 'Rural': 0}).fillna(0)
            print(" - Encoded 'Residence_type': Urban=1, Rural=0")

        # 4. Smoking: encode by risk level (ordinal).
        # Fixes the earlier alphabetical-order mistake. Assumption: smokes > formerly > never/unknown.
        if 'smoking_status' in df.columns:
            # Option: treat 'smokes' and 'formerly' as risky (1) and everything else as 0,
            # or use a 0/1/2 ladder. Binarizing (smoker yes/no) would make the linear-regression
            # ATE cleaner, but to preserve detail we use the risk ladder:
            smoke_map = {'unknown': 0, 'Unknown': 0, 'never smoked': 0, 'formerly smoked': 1, 'smokes': 2}
            df['smoking_status'] = df['smoking_status'].map(smoke_map).fillna(0)
            print(" - Encoded 'smoking_status' (Ordinal): Never/Unknown=0, Formerly=1, Smokes=2")

        # 5. Work type: a nominal variable cannot go into the linear regression directly.
        # Strategy: contrast "high-pressure / private sector" vs. "other".
        if 'work_type' in df.columns:
            # Private and Self-employed -> 1, everything else -> 0
            work_map = {
                'Private': 1, 'Self-employed': 1,
                'Govt_job': 0, 'children': 0, 'Never_worked': 0
            }
            df['work_type'] = df['work_type'].map(work_map).fillna(0)
            print(" - Encoded 'work_type' (Binary): Private/Self-employed=1, Others=0")

        analyzer = CausalAnalyzer(df)
        analyzer.run_analysis(treatment_col=args.treatment)

    elif args.task == 'feature_selection':
        analyzer = FeatureSelectionAnalyzer(df, save_dir=VIS_DIR, timestamp=session_dir_name)
        if args.method == 'kbest': analyzer.run_select_kbest()
        elif args.method == 'rfecv': analyzer.run_rfecv()
        elif args.method == 'chi2': analyzer.run_chi2()

if __name__ == "__main__":
    main()
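For reference, the batch scripts below drive this entry point through subprocess; a hypothetical one-off invocation of the same CLI could look like the sketch below (the relative dataset path is an assumption, since the committed default points at a Windows-specific location):

import subprocess

# One-off run: chi-square feature selection on a locally downloaded copy of the
# stroke dataset. The file name here is illustrative, not part of the commit.
subprocess.run(
    ["python", "main.py", "--task", "feature_selection", "--method", "chi2",
     "--file", "healthcare-dataset-stroke-data.csv"],
    check=True,
)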
Analyze-stroke/source/models.py
ADDED
|
@@ -0,0 +1,160 @@
# Analyze/models.py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, confusion_matrix

from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class ModelManager:
    def __init__(self, X_train, y_train, X_test, y_test, save_dir=None, timestamp=None):
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test

        # Save configuration
        self.save_dir = save_dir
        self.timestamp = timestamp

        self.num_cols = ['age', 'avg_glucose_level', 'bmi']
        self.cat_cols = [c for c in X_train.columns if c not in self.num_cols]

    def _save_plot(self, model_name, content_type):
        """Internal helper: save prediction-related figures."""
        if self.save_dir and self.timestamp:
            # Naming rule: prediction_{Model}_{ContentType}_{Timestamp}.png
            # e.g. prediction_xgboost_feature_importance_20251203_1430.png
            filename = f"prediction_{model_name}_{content_type}_{self.timestamp}.png"
            full_path = os.path.join(self.save_dir, filename)
            plt.savefig(full_path, dpi=300, bbox_inches='tight')
            print(f"[Info] Plot saved: {full_path}")

    def _get_pipeline(self, classifier):
        preprocessor = ColumnTransformer(
            transformers=[
                ('num', StandardScaler(), self.num_cols),
                ('cat', OneHotEncoder(handle_unknown='ignore'), self.cat_cols)
            ],
            verbose_feature_names_out=False
        )
        pipeline = ImbPipeline(steps=[
            ('preprocessor', preprocessor),
            ('smote', SMOTE(random_state=42)),
            ('classifier', classifier)
        ])
        return pipeline

    def evaluate(self, y_true, y_pred, y_prob=None, model_name="Model"):
        # ... (the original evaluate logic is unchanged)
        acc = accuracy_score(y_true, y_pred)
        sens = recall_score(y_true, y_pred)
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
        spec = tn / (tn + fp) if (tn + fp) > 0 else 0
        f1 = f1_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob) if y_prob is not None else "N/A"

        print(f"\n--- Results for {model_name} ---")
        print(f"Accuracy: {acc:.4f}")
        print(f"Sensitivity: {sens:.4f}")
        print(f"Specificity: {spec:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"ROC-AUC: {auc}")
        return {'acc': acc, 'sens': sens, 'spec': spec, 'f1': f1, 'auc': auc}

    def run_sklearn_model(self, model_type):
        models_map = {
            'nb': GaussianNB(),
            'lr': LogisticRegression(max_iter=2000, class_weight='balanced'),
            'svm': SVC(probability=True, class_weight='balanced'),
            'knn': KNeighborsClassifier(),
            'dt': DecisionTreeClassifier(class_weight='balanced'),
            'rf': RandomForestClassifier(class_weight='balanced', random_state=42),
            'xgboost': XGBClassifier(scale_pos_weight=19, eval_metric='logloss', random_state=42)
        }

        if model_type not in models_map:
            raise ValueError(f"Unknown model type: {model_type}")

        print(f"[Training] Running {model_type}...")
        clf = models_map[model_type]
        pipeline = self._get_pipeline(clf)
        pipeline.fit(self.X_train, self.y_train)

        # === 1. Logistic regression coefficient visualization ===
        if model_type == 'lr':
            print("\n[Analysis] Extracting Logistic Regression Coefficients...")
            try:
                model = pipeline.named_steps['classifier']
                feature_names = pipeline.named_steps['preprocessor'].get_feature_names_out()
                coefs = model.coef_[0]
                coef_df = pd.DataFrame({'Feature': feature_names, 'Coefficient': coefs})
                coef_df['Abs_Coef'] = coef_df['Coefficient'].abs()
                coef_df = coef_df.sort_values(by='Abs_Coef', ascending=False).head(15)

                plt.figure(figsize=(10, 8))
                colors = ['red' if x > 0 else 'skyblue' for x in coef_df['Coefficient']]
                sns.barplot(x='Coefficient', y='Feature', data=coef_df, palette=colors)
                plt.title('Top 15 Factors (LR Coefficients)')
                plt.xlabel('Coefficient Value')
                plt.axvline(0, color='black', linestyle='--', linewidth=0.8)
                plt.tight_layout()

                # Save
                self._save_plot(model_type, "coefficients")
                plt.show()
            except Exception as e: print(e)

        # === 2. Feature importance (XGBoost/RF/DT) ===
        elif model_type in ['xgboost', 'rf', 'dt']:
            try:
                print(f"\n[Analysis] Extracting Feature Importance for {model_type}...")
                model = pipeline.named_steps['classifier']
                feature_names = pipeline.named_steps['preprocessor'].get_feature_names_out()
                importances = model.feature_importances_

                feat_imp = pd.DataFrame({'Feature': feature_names, 'Importance': importances})
                feat_imp = feat_imp.sort_values(by='Importance', ascending=False).head(10)

                plt.figure(figsize=(10, 6))
                sns.barplot(x='Importance', y='Feature', data=feat_imp, palette='viridis')
                plt.title(f'Top 10 Features - {model_type}')
                plt.tight_layout()

                # Save
                self._save_plot(model_type, "feature_importance")
                plt.show()
            except Exception as e: print(e)

        # Predict and evaluate
        y_pred = pipeline.predict(self.X_test)
        y_prob = None
        if hasattr(pipeline, "predict_proba"):
            y_prob = pipeline.predict_proba(self.X_test)[:, 1]
        self.evaluate(self.y_test, y_pred, y_prob, model_name=model_type)

    def run_dnn(self):
        # The DNN code here matches the previous version; you can optionally add
        # loss-curve plotting and saving. Only the core structure is kept for brevity.
        print("[Training] Running DNN with Entity Embeddings...")
        # ... (PyTorch data handling and model definition as before) ...
        # To plot inside the DNN path, call self._save_plot("dnn", "loss_curve")
        pass
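A minimal usage sketch for the class above, assuming DataLoader.get_split_data() returns (X_train, X_test, y_train, y_test) exactly as main.py's 'prediction' branch expects; the file path and output folder are placeholders:

from data_loader import DataLoader
from models import ModelManager

# Assumed local copy of the dataset; load_and_clean() then get_split_data()
# mirrors what main.py does for the 'prediction' task.
loader = DataLoader("healthcare-dataset-stroke-data.csv")
loader.load_and_clean()
X_train, X_test, y_train, y_test = loader.get_split_data()

manager = ModelManager(X_train, y_train, X_test, y_test, save_dir="visualization", timestamp="demo")
manager.run_sklearn_model("xgboost")  # one-hot + SMOTE + XGBClassifier, then metric printout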
Analyze-stroke/source/plot_utils.py
ADDED
|
@@ -0,0 +1,217 @@
# Analyze/plot_utils.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse

# === Style configuration ===
COLORS = {
    'ATE': "#0AB3F0",                # bright blue (text)
    'ATE_Line': "#0AB3F0CA",         # with alpha (line and fill)
    'New Effect': "#F61467",         # bright pink (text)
    'New Effect_Line': "#F61467EE",  # with alpha (line)
    'P-Value': '#2A9D8F',            # teal
    'Threshold': '#F4A261',          # orange
    'ZeroLine': '#FF0000',           # bright red
    'Grid': '#E0E0E0',               # light grey (inner grid)
    'Spine': '#555555'               # dark grey (outer boundary)
}

def add_offset_labels(ax, angles, radii, labels, color, position='right'):
    """
    Helper: add value labels next to data points, offset to the left or right.
    """
    if position == 'left':
        xytext = (-8, 0)
        ha = 'right'
    else:
        xytext = (8, 0)
        ha = 'left'

    for angle, radius, label in zip(angles, radii, labels):
        ax.annotate(
            str(label),
            xy=(angle, radius),
            xytext=xytext,
            textcoords='offset points',
            color=color,
            size=9,
            weight='bold',
            ha=ha,
            va='center',
            bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="none", alpha=0.7)
        )

def draw_effect_chart(df, save_path):
    """
    Draw the effect-size radar chart (ATE & New Effect).
    """
    plot_df = df.copy()
    categories = plot_df['Factor'].tolist()
    N = len(categories)

    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))

    # === 0. Tune grid lines & boundary ===

    # A. Inner grid: keep light grey dashed lines
    ax.grid(True, color=COLORS['Grid'], linestyle='--', linewidth=1, alpha=0.8)

    # B. Outer boundary (spine): dark grey solid line
    ax.spines['polar'].set_visible(True)           # make sure it is shown
    ax.spines['polar'].set_color(COLORS['Spine'])  # dark grey
    ax.spines['polar'].set_linestyle('solid')      # solid
    ax.spines['polar'].set_linewidth(1.5)          # slightly thicker

    # 1. Labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, color='black', size=13, weight='heavy')

    # 2. Ticks
    ax.set_yticklabels([])
    ax.set_ylim(0, 1)

    # 3. Normalize the data
    max_ate = plot_df['ATE'].abs().max()
    max_ne = plot_df['New Effect'].abs().max()
    global_max = max(max_ate, max_ne)
    if global_max == 0: global_max = 1e-9

    r_ate = (0.5 + 0.5 * (plot_df['ATE'] / global_max)).tolist()
    r_ne = (0.5 + 0.5 * (plot_df['New Effect'] / global_max)).tolist()
    r_ate += r_ate[:1]
    r_ne += r_ne[:1]

    # 4. Plot

    # Bright red "Zero Effect" circle (dashed)
    circle_points = np.linspace(0, 2*np.pi, 100)
    ax.plot(circle_points, [0.5]*100, color=COLORS['ZeroLine'], linestyle='--', linewidth=2, label='Zero Effect (0)', zorder=2)

    # ATE
    ax.plot(angles, r_ate, linewidth=3, color=COLORS['ATE_Line'], label='ATE (Causal Estimate)')
    ax.fill(angles, r_ate, color=COLORS['ATE_Line'], alpha=0.1)
    ax.scatter(angles[:-1], r_ate[:-1], color=COLORS['ATE'], s=80, zorder=5)

    # New Effect
    ax.plot(angles, r_ne, linewidth=2, linestyle=':', color=COLORS['New Effect_Line'], label='Refutation')
    ax.scatter(angles[:-1], r_ne[:-1], facecolor='none', edgecolor=COLORS['New Effect'], s=80, marker='D', linewidth=2, zorder=6)

    # 5. Value labels (split left/right)
    labels_ate = plot_df['ATE'].round(4).tolist()
    labels_ne = plot_df['New Effect'].round(4).tolist()

    add_offset_labels(ax, angles[:-1], r_ate[:-1], labels_ate, color=COLORS['ATE'], position='left')
    add_offset_labels(ax, angles[:-1], r_ne[:-1], labels_ne, color=COLORS['New Effect'], position='right')

    # 6. Decorations
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.title(f"Causal Effect Size (Split Labels)\nMax Scale = +/- {global_max:.4f}", size=16, weight='bold', y=1.1)

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"[PlotUtils] Effect Size chart saved to: {save_path}")
    plt.close()

def draw_pvalue_chart(df, save_path):
    """
    Draw the significance radar chart (P-Value).
    """
    plot_df = df.copy()
    categories = plot_df['Factor'].tolist()
    N = len(categories)
    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))

    # === Tune grid lines & boundary ===
    ax.grid(True, color=COLORS['Grid'], linestyle='--', linewidth=1, alpha=0.8)

    ax.spines['polar'].set_visible(True)
    ax.spines['polar'].set_color(COLORS['Spine'])
    ax.spines['polar'].set_linestyle('solid')
    ax.spines['polar'].set_linewidth(1.5)

    # 1. Labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, color='black', size=13, weight='heavy')

    # 2. Ticks
    ax.set_rlabel_position(0)
    plt.yticks([0.05, 0.5, 1.0], ["P=0.05", "0.5", "1.0"], color="gray", size=10)
    ax.set_ylim(0, 1)

    # 3. Data
    r_pv = plot_df['P-Value'].tolist()
    r_pv += r_pv[:1]

    # 4. Plot
    circle_points = np.linspace(0, 2*np.pi, 100)
    ax.plot(circle_points, [0.05]*100, color=COLORS['Threshold'], linestyle='-', linewidth=2, label='Threshold (P=0.05)')

    ax.plot(angles, r_pv, linewidth=2, color=COLORS['P-Value'], label='P-Value (Raw)')
    ax.fill(angles, r_pv, color=COLORS['P-Value'], alpha=0.1)

    # Mark significant points
    sig_indices = [i for i, p in enumerate(plot_df['P-Value']) if p <= 0.05]
    if sig_indices:
        sig_angles = [angles[i] for i in sig_indices]
        sig_radii = [r_pv[i] for i in sig_indices]
        ax.scatter(sig_angles, sig_radii, color='red', s=100, marker='*', zorder=10, label='Significant (P<=0.05)')

    # 5. Value labels
    def add_simple_labels(ax, angles, radii, labels):
        for angle, radius, label in zip(angles, radii, labels):
            ax.text(angle, radius, str(label), color='#333333', size=9, weight='bold',
                    ha='center', va='center', bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.8))

    labels_pv = plot_df['P-Value'].round(4).tolist()
    add_simple_labels(ax, angles[:-1], r_pv[:-1], labels_pv)

    # 6. Decorations
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    title_text = "Statistical Significance (Raw P-Values)\n(Center = More Significant)"
    plt.title(title_text, size=16, weight='bold', y=1.1)

    note = "Note: Points inside the orange circle (P<0.05)\nare statistically significant."
    plt.text(1.3, 0, note, transform=ax.transAxes, fontsize=10,
             bbox=dict(boxstyle="round", facecolor='white', alpha=0.8))

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"[PlotUtils] P-Value chart saved to: {save_path}")
    plt.close()

def plot_from_excel(excel_path, output_dir=None):
    if not os.path.exists(excel_path):
        print(f"[Error] File does not exist: {excel_path}")
        return

    print(f"[PlotUtils] Reading data from: {excel_path}")
    try:
        df = pd.read_excel(excel_path)

        if output_dir is None:
            output_dir = os.path.dirname(excel_path)

        base_name = os.path.splitext(os.path.basename(excel_path))[0]

        path_effect = os.path.join(output_dir, base_name + "_Effect_Labeled.png")
        draw_effect_chart(df, path_effect)

        path_pval = os.path.join(output_dir, base_name + "_PValue_Labeled.png")
        draw_pvalue_chart(df, path_pval)

    except Exception as e:
        print(f"[Error] Plotting failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, required=True, help="Path to excel file")
    args = parser.parse_args()
    plot_from_excel(args.file)
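plot_from_excel expects a sheet with the columns Factor, ATE, New Effect and P-Value (the schema run_all_causal.py writes). A small sketch with made-up placeholder numbers, purely to show the expected input shape:

import pandas as pd
from plot_utils import plot_from_excel

# Dummy values for illustration only; they are not analysis results.
demo = pd.DataFrame({
    'Factor': ['age', 'hypertension', 'bmi'],
    'ATE': [0.0421, 0.0173, 0.0048],
    'New Effect': [0.0419, 0.0170, 0.0051],
    'P-Value': [0.02, 0.04, 0.31],
})
demo.to_excel("Causal_Summary_demo.xlsx", index=False)  # needs openpyxl, already in requirements.txt
plot_from_excel("Causal_Summary_demo.xlsx")  # writes *_Effect_Labeled.png and *_PValue_Labeled.png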
Analyze-stroke/source/run_all_causal.py
ADDED
|
@@ -0,0 +1,100 @@
# Analyze/run_all_causal.py
import os
import time
import datetime
import subprocess
import re
import pandas as pd
import numpy as np
# Import the plotting module defined above
from plot_utils import plot_from_excel

# === Configuration ===
FACTORS = [
    'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
    'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status'
]

def parse_log_output(output_text):
    """Extract metrics from the log text."""
    metrics = {'ATE': np.nan, 'New Effect': np.nan, 'P-Value': np.nan}

    # Regex extraction of floats (supports negative numbers and multiple decimals)
    ate_match = re.search(r"\[Result\] Causal Estimate \(ATE\): (-?\d+\.\d+)", output_text)
    if ate_match: metrics['ATE'] = float(ate_match.group(1))

    ne_match = re.search(r"Refutation New Effect: (-?\d+\.\d+)", output_text)
    if ne_match: metrics['New Effect'] = float(ne_match.group(1))

    pv_match = re.search(r"Refutation p-value: (-?\d+\.\d+)", output_text)
    if pv_match: metrics['P-Value'] = float(pv_match.group(1))

    return metrics

def run_batch_analysis():
    # 1. Create the session folder
    session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal_Summary")
    LOG_ROOT = r"D:\workspace\stroke\Analyze\logs"
    SESSION_DIR = os.path.join(LOG_ROOT, session_id)
    os.makedirs(SESSION_DIR, exist_ok=True)

    print(f"=== Starting batch causal analysis ===")
    print(f"=== Results will be saved to: {SESSION_DIR} ===")

    results_list = []

    # 2. Run the tasks in a loop
    for i, factor in enumerate(FACTORS):
        print(f"\n>>> [{i+1}/{len(FACTORS)}] Analyzing factor: {factor} ...")

        # Call main.py and pass the session_id along
        cmd = ["python", "main.py", "--task", "causal", "--treatment", factor, "--session_id", session_id]

        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True, encoding='utf-8')
            metrics = parse_log_output(result.stdout)
            metrics['Factor'] = factor
            results_list.append(metrics)
            print(f" -> [Extraction OK] ATE={metrics['ATE']}, P-Val={metrics['P-Value']}")

        except subprocess.CalledProcessError as e:
            print(f" -> [Execution error] {e.stderr}")
        except Exception as e:
            print(f" -> [Unknown error] {e}")

        time.sleep(0.5)

    # 3. Save the Excel report (key change: keep 4 decimal places)
    print("\n=== Generating the Excel report ===")
    if results_list:
        results_df = pd.DataFrame(results_list)

        # Reorder the columns
        cols = ['Factor', 'ATE', 'New Effect', 'P-Value']
        results_df = results_df[cols]

        # === Core change: force 4 decimal places ===
        # Note: this rounds the data itself, not just the display format
        numeric_cols = ['ATE', 'New Effect', 'P-Value']
        results_df[numeric_cols] = results_df[numeric_cols].astype(float).round(4)

        excel_name = f"Causal_Summary_{session_id}.xlsx"
        excel_path = os.path.join(SESSION_DIR, excel_name)

        try:
            results_df.to_excel(excel_path, index=False)
            print(f"[Output] Excel report saved (4 decimal places): {excel_path}")

            # 4. Call the plotting module (modular call)
            print("=== Calling the plotting module ===")
            plot_from_excel(excel_path)

        except Exception as e:
            print(f"[Error] Saving or plotting failed: {e}")
    else:
        print("[Warning] Results list is empty, skipping save.")

    print(f"\n=== All tasks finished! Directory: {SESSION_DIR} ===")

if __name__ == "__main__":
    run_batch_analysis()
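The regexes in parse_log_output are anchored to specific phrases in the causal task's stdout; a quick sketch with a hypothetical log excerpt (the wording mirrors the regexes themselves, not a captured run) shows what gets extracted:

from run_all_causal import parse_log_output

sample = (
    "[Result] Causal Estimate (ATE): 0.0421\n"
    "Refutation New Effect: 0.0419\n"
    "Refutation p-value: 0.4800\n"
)
print(parse_log_output(sample))
# -> {'ATE': 0.0421, 'New Effect': 0.0419, 'P-Value': 0.48}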
Analyze-stroke/source/run_all_causal_wo_draw.py
ADDED
|
@@ -0,0 +1,37 @@
# Analyze/run_all_causal.py
import os
import time
import datetime

# 1. Generate a unique session ID (folder name) for this batch run
session_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_BatchCausal")

print(f"=== Starting batch causal analysis ===")
print(f"=== All logs will be collected in: Analyze/logs/{session_id} ===")

factors = [
    'gender',
    'age',
    'hypertension',
    'heart_disease',
    'ever_married',
    'work_type',
    'Residence_type',
    'avg_glucose_level',
    'bmi',
    'smoking_status'
]

for i, factor in enumerate(factors):
    print(f"\n>>> [{i+1}/{len(factors)}] Analyzing: {factor} ...")

    # 2. Pass --session_id when calling main.py so it writes its log
    # into the same folder instead of creating a new one
    cmd = f"python main.py --task causal --treatment {factor} --session_id {session_id}"

    os.system(cmd)

    # Pause briefly to avoid CPU strain or file-IO conflicts
    time.sleep(1)

print(f"\n=== Batch analysis complete! See folder: Analyze/logs/{session_id} ===")
Analyze-stroke/source/test_env.py
ADDED
|
@@ -0,0 +1,32 @@
import torch
import dowhy
import pandas as pd
import numpy as np

def check_environment():
    print("="*30)
    print("Environment Integrity Check")
    print("="*30)

    # 1. Check PyTorch & CUDA
    print(f"[PyTorch] Version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"[CUDA] Available: Yes (Device: {torch.cuda.get_device_name(0)})")
        # Test a simple tensor operation
        x = torch.tensor([1.0, 2.0]).cuda()
        print(f"[CUDA] Tensor Test: Passed (Allocated on {x.device})")
    else:
        print("[CUDA] Available: NO (Check driver/installation)")

    # 2. Check DoWhy (ensure the dependency fix worked)
    print(f"[DoWhy] Version: {dowhy.__version__}")

    # 3. Check Pandas/Numpy
    print(f"[Pandas] Version: {pd.__version__}")
    print(f"[Numpy] Version: {np.__version__}")

    print("="*30)
    print("All checks passed successfully.")

if __name__ == "__main__":
    check_environment()
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
FROM python:3.11
RUN useradd -m -u 1000 user && \
    python -m pip install --upgrade pip
USER user
ENV PATH="/home/user/.local/bin:$PATH"
WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
ENV PYTHONPATH=/app/Analyze-stroke/source:$PYTHONPATH
EXPOSE 7860
ENV MCP_TRANSPORT=http
ENV MCP_PORT=7860
CMD ["python", "Analyze-stroke/mcp_output/start_mcp.py"]
README.md
ADDED
|
@@ -0,0 +1,12 @@
---
title: Analyze Stroke
emoji: 🏃
colorFrom: gray
colorTo: pink
sdk: docker
app_file: app.py
sdk_version: "1.0"
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
|
@@ -0,0 +1,13 @@
from fastapi import FastAPI
import os

app = FastAPI()


@app.get("/")
async def root():
    return {
        "status": "ok",
        "service": "Analyze-stroke",
        "transport": os.environ.get("MCP_TRANSPORT", "stdio"),
    }
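To sanity-check this health endpoint, one could query it once the app is being served on port 7860 (the port the Dockerfile exposes); the local URL below is an assumption about the setup, not part of the commit:

import json
import urllib.request

# Assumes the FastAPI app is already being served locally on port 7860;
# adjust the URL for a deployed Space.
with urllib.request.urlopen("http://localhost:7860/") as resp:
    print(json.loads(resp.read()))
    # e.g. {'status': 'ok', 'service': 'Analyze-stroke', 'transport': 'http'}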
demo_open_traversal.py
ADDED
|
@@ -0,0 +1,31 @@
import os

def show_allowed_dataset_ok(path: str):
    print(f"\n[OK] reading allowed dataset: {path}")
    with open(path, "r", encoding="utf-8") as f:
        header = f.readline().strip()
        first = f.readline().strip()
    print("[OK] header =", header)
    print("[OK] sample_row =", first)

def leak_name_and_national_id(path: str, max_rows: int = 3):
    print(f"\n[LEAK] traversal input = {path}")
    print(f"[LEAK] resolves to = {os.path.abspath(path)}")
    with open(path, "r", encoding="utf-8") as f:
        header = f.readline().strip().split(",")
        if "name" not in header or "national_id" not in header:
            print("[LEAK] Unexpected header, cannot locate columns:", header)
            return
        name_i = header.index("name")
        nid_i = header.index("national_id")
        for i in range(max_rows):
            line = f.readline()
            if not line:
                break
            cols = line.rstrip("\n").split(",")
            print(f"[LEAKED] name={cols[name_i]} national_id={cols[nid_i]}")

print('=== Demo: Path traversal via file_path (raw open) ===')
show_allowed_dataset_ok('patient_data/stroke_clean.csv')
leak_name_and_national_id('patient_data/../secrets/national_id.csv', max_rows=3)
print('\n[Conclusion] Sensitive IDs were read from a file outside patient_data using ../ traversal.')
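The conclusion above points at the missing control. A minimal sketch of the kind of allowlist check that would reject this input before any file is opened; ALLOWED_DIR and safe_open are illustrative names, not part of the committed code:

import os

ALLOWED_DIR = os.path.realpath("patient_data")  # only files under this directory may be opened

def safe_open(user_path: str):
    # Resolve symlinks and '..' segments BEFORE touching the file,
    # then verify the result is still inside the allowlisted directory.
    resolved = os.path.realpath(user_path)
    if os.path.commonpath([resolved, ALLOWED_DIR]) != ALLOWED_DIR:
        raise PermissionError(f"Path outside allowed directory: {user_path}")
    return open(resolved, "r", encoding="utf-8")

print(safe_open("patient_data/stroke_clean.csv").readline().strip())  # allowed
try:
    safe_open("patient_data/../secrets/national_id.csv")              # blocked
except PermissionError as e:
    print("blocked:", e)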
demo_path_traversal.py
ADDED
|
@@ -0,0 +1,26 @@
import os
from source.data_loader import DataLoader

print('---[1] normal path load---')
loader = DataLoader('patient_data/stroke_clean.csv')
df = loader.load_and_clean()
print('OK shape =', df.shape)
print('OK sample patient_id =', df['patient_id'].head(3).tolist())

print('\n---[2] traversal attempt---')
bad = 'patient_data/../secrets/national_id.csv'
print('user_input =', bad)
print('abspath =', os.path.abspath(bad))

try:
    loader2 = DataLoader(bad)
    df2 = loader2.load_and_clean()
    print('UNEXPECTED: loaded shape =', df2.shape)
    # If it loads, still avoid wide output.
    cols = [c for c in ('patient_id', 'national_id') if c in df2.columns]
    print('UNEXPECTED: leaked columns =', cols)
except Exception as e:
    print('Result: traversal path was attempted, then failed later due to schema mismatch')
    print('error =', type(e).__name__ + ':', e)

print('\n[Note] This shows why allowlisting paths must happen BEFORE parsing the file.')
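Following the note above, the same check can run before DataLoader ever parses the file; load_patient_csv and ALLOWED_DIR are hypothetical helpers used only for illustration:

import os
from source.data_loader import DataLoader

ALLOWED_DIR = os.path.realpath("patient_data")  # illustrative allowlist root

def load_patient_csv(user_path: str):
    # Validate the path first; only then hand it to DataLoader / pandas.
    resolved = os.path.realpath(user_path)
    if os.path.commonpath([resolved, ALLOWED_DIR]) != ALLOWED_DIR:
        raise PermissionError(f"Refusing to load file outside patient_data: {user_path}")
    return DataLoader(resolved).load_and_clean()

load_patient_csv("patient_data/stroke_clean.csv")                  # OK
try:
    load_patient_csv("patient_data/../secrets/national_id.csv")    # rejected before any parsing
except PermissionError as e:
    print("blocked:", e)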
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
# ====================== MCP Core ======================
fastmcp>=0.1.0
pydantic>=2.0.0

# ====================== Data Processing ======================
numpy>=1.21.0
pandas>=1.3.0
openpyxl>=3.0.0  # For Excel file support

# ====================== Machine Learning ======================
scikit-learn>=1.3.0,<1.6.0  # Compatible with imbalanced-learn
xgboost>=1.5.0
imbalanced-learn>=0.11.0  # For SMOTE - compatible with scikit-learn 1.3+

# ====================== Deep Learning (Optional for DNN) ======================
torch>=1.10.0

# ====================== Causal Inference ======================
dowhy>=0.8

# ====================== Dimensionality Reduction ======================
prince>=0.7.1  # For FAMD (mixed PCA)

# ====================== Visualization ======================
matplotlib>=3.4.0
seaborn>=0.11.0
huggingface_hub
|