Spaces: Running on Zero
initial commit
- .gitignore +3 -0
- DEPLOYMENT.md +139 -0
- README.md +76 -6
- README_SPACE.md +84 -0
- app.py +251 -0
- deploy_to_hf.sh +109 -0
- requirements.txt +22 -0
- src/backbone/vit_wrapper.py +180 -0
- utils/training.py +231 -0
- utils/visualization.py +190 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
__pycache__
*.pyc
.env
DEPLOYMENT.md
ADDED
@@ -0,0 +1,139 @@
# Deploying NAF Demo to Hugging Face Spaces

## Quick Setup

### 1. Create a Hugging Face Space

1. Go to [https://huggingface.co/spaces](https://huggingface.co/spaces)
2. Click "Create new Space"
3. Configure your Space:
   - **Space name**: `naf-feature-upsampling` (or your choice)
   - **License**: Apache 2.0 (or your choice)
   - **Select the SDK**: Gradio
   - **Space hardware**: CPU Basic (free) or GPU (T4 small recommended for faster inference)
   - **Visibility**: Public or Private

### 2. Required Files

Upload these files to your Hugging Face Space:

```
your-space/
├── app.py                  # Main application
├── requirements.txt        # Python dependencies
├── README.md               # Space documentation
├── src/
│   └── backbone/
│       └── vit_wrapper.py  # Backbone wrapper
└── utils/
    ├── visualization.py    # Visualization utilities
    └── training.py         # Training utilities (for round_to_nearest_multiple)
```

### 3. Clone Your Space Repository

```bash
# Install git-lfs if not already installed
git lfs install

# Clone your space
git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
cd YOUR_SPACE_NAME

# Copy files from your local NAF project
cp /home/lchambon/workspace/NAF_off/app.py .
cp /home/lchambon/workspace/NAF_off/requirements_demo.txt ./requirements.txt

# Copy source files
mkdir -p src/backbone utils
cp /home/lchambon/workspace/NAF_off/src/backbone/vit_wrapper.py src/backbone/
cp /home/lchambon/workspace/NAF_off/utils/visualization.py utils/
cp /home/lchambon/workspace/NAF_off/utils/training.py utils/
cp /home/lchambon/workspace/NAF_off/utils/img.py utils/  # If needed

# Copy sample images (optional)
mkdir -p asset
cp /home/lchambon/workspace/NAF_off/asset/*.png asset/
cp /home/lchambon/workspace/NAF_off/asset/*.jpg asset/

# Add all files
git add .
git commit -m "Initial commit: NAF Feature Upsampling Demo"
git push
```

### 4. Alternative: Use the Hugging Face Web Interface

1. Navigate to your Space's "Files" tab
2. Click "Add file" → "Upload files"
3. Upload all required files, maintaining the directory structure
4. Commit the changes

### 5. Monitor Deployment

- Once pushed, Hugging Face will automatically build your Space
- Check the "Logs" tab to monitor the build process
- The Space will be available at: `https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME`

## Hardware Recommendations

- **CPU Basic (free)**: Works but slower inference (~10-30s per image)
- **T4 Small GPU**: Recommended for better performance (~2-5s per image)
- **T4 Medium/Large GPU**: For handling multiple concurrent users

## Important Notes

### Model Loading
The NAF model is loaded from torch.hub, which will download it on first run:
```python
model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
```
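Once loaded, the model is called with an ImageNet-normalized image, the backbone's low-resolution features, and a target size, as done in `app.py`. The snippet below is a minimal end-to-end sketch of that pipeline under the demo's defaults; the backbone name, the image path, and the 448×448 target size are placeholder choices.

```python
import PIL.Image
import torch
import torchvision.transforms as T

from src.backbone.vit_wrapper import PretrainedViTWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"

# NAF upsampler (downloaded from torch.hub on first run)
naf = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
naf.eval()

# Backbone wrapper; this name is one of the demo's dropdown choices
backbone = PretrainedViTWrapper("vit_base_patch14_reg4_dinov2", norm=True).to(device).eval()

# Load an image (placeholder path) and build a batch of shape [1, 3, 448, 448]
img = PIL.Image.open("asset/Natural.png").convert("RGB")
img_tensor = T.Compose([T.Resize((448, 448)), T.ToTensor()])(img).unsqueeze(0).to(device)

# Two normalizations, as in app.py: the backbone's own statistics for feature
# extraction, ImageNet statistics for the image passed to NAF
back_norm = T.Normalize(mean=backbone.config["mean"], std=backbone.config["std"])
ups_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

with torch.no_grad():
    lr_feats = backbone(back_norm(img_tensor))                  # [B, C, h, w] low-res features
    hr_feats = naf(ups_norm(img_tensor), lr_feats, (448, 448))  # upsampled to 448x448
```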

### Memory Considerations
- Backbone models are loaded on demand for each request (see the caching sketch after this list)
- Consider caching popular models if you upgrade to persistent storage
- GPU Spaces have more memory for handling larger images
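If you move to hardware that can keep models resident in memory, one option is to cache loaded backbones across requests instead of re-creating them inside `process_image`. The snippet below is only a sketch: `load_backbone` is a hypothetical helper and is not part of the current `app.py`.

```python
from functools import lru_cache

import torch

from src.backbone.vit_wrapper import PretrainedViTWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache(maxsize=2)  # bound memory use by keeping at most two backbones resident
def load_backbone(model_name: str) -> PretrainedViTWrapper:
    """Hypothetical helper: build a backbone once and reuse it for later requests."""
    backbone = PretrainedViTWrapper(model_name, norm=True).to(device)
    backbone.eval()
    return backbone
```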

### Sample Images
- Upload sample images to the `asset/` folder
- Update the `SAMPLE_IMAGES` list in `app.py` to match the available images (see the sketch below)
- Or remove the examples section if no images are available
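For reference, `app.py` keeps the sample list as plain relative paths and only exposes the files that actually exist. A trimmed sketch of that pattern, using a subset of the shipped file names:

```python
from pathlib import Path

# Edit this list to match the files you actually upload to asset/
SAMPLE_IMAGES = [
    "asset/Cartoon.png",
    "asset/Natural.png",
    "asset/Driving.jpg",
]

# Only existing files are offered as Gradio examples (as done in app.py)
available_samples = [[p] for p in SAMPLE_IMAGES if Path(p).exists()]
```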

## Troubleshooting

### Build Failures
- Check "Logs" tab for error messages
- Ensure all dependencies are in `requirements.txt`
- Verify Python version compatibility (3.8-3.10 recommended)

### Import Errors
- Make sure `src/` and `utils/` directories are included
- Check that `__init__.py` files exist if needed
- Verify relative imports are correct

### Memory Issues
- Reduce max resolution if needed
- Consider using CPU-only mode for free tier
- Upgrade to GPU hardware if processing large images

## Updating Your Space

```bash
# Make changes to your files
git add .
git commit -m "Update: description of changes"
git push
```

Hugging Face will automatically rebuild and redeploy your Space.

## Making Your Space Public

1. Go to Space Settings
2. Change visibility to "Public"
3. Add a good README.md with description and usage instructions
4. Consider adding a thumbnail image

## Example Space README

See `README_SPACE.md` for a template README to use on Hugging Face Spaces.
README.md
CHANGED
@@ -1,14 +1,84 @@
 ---
-title: NAF
-emoji:
+title: NAF Zero-Shot Feature Upsampling
+emoji: 🎯
 colorFrom: blue
-colorTo:
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: 3.50.0
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: 'NAF: Zero-Shot Feature Upsampling via Neighborhood Attention'
 ---
 
-
+# 🎯 NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering
+
+This Space demonstrates **NAF (Neighborhood Attention Filtering)**, a method for upsampling features from Vision Foundation Models to any resolution without model-specific training.
+
+## 🚀 Features
+
+- **Universal Upsampling**: Works with any Vision Foundation Model (DINOv2, DINOv3, RADIO, DINO, SigLIP, etc.)
+- **Arbitrary Resolutions**: Upsample features to any target resolution while maintaining aspect ratio
+- **Zero-Shot**: No model-specific training or fine-tuning required
+- **Interactive Demo**: Upload your own images or try sample images from various domains
+
+## 🎨 How to Use
+
+1. **Upload an Image**: Click "Upload Your Image" or select from sample images
+2. **Choose a Model**: Select a Vision Foundation Model from the dropdown
+3. **Set Resolution**: Choose the target resolution for upsampled features (64-512)
+4. **Click "Upsample Features"**: See the comparison between low- and high-resolution features
+
+## 📊 Visualization
+
+The output shows three panels:
+- **Left**: Your input image
+- **Center**: Low-resolution features from the backbone (PCA visualization)
+- **Right**: High-resolution features upsampled by NAF
+
+Features are visualized with PCA, mapping the first 3 principal components to RGB channels.
+
+## 🔬 Supported Models
+
+- **DINOv3**: Latest self-supervised vision models
+- **RADIO v2.5**: High-performance vision backbones
+- **DINOv2**: Self-supervised learning with registers
+- **DINO**: Original self-supervised ViT
+- **SigLIP**: Contrastive vision-language models
+
+## 📖 Learn More
+
+- **Paper**: [NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering](https://arxiv.org/abs/2501.01535)
+- **Code**: [GitHub Repository](https://github.com/valeoai/NAF)
+- **Organization**: [Valeo.ai](https://www.valeo.com/en/valeo-ai/)
+
+## 💡 Use Cases
+
+NAF enables better feature representations for:
+- Dense prediction tasks (segmentation, depth estimation)
+- High-resolution visual understanding
+- Feature matching and correspondence
+- Vision-language alignment
+
+## ⚙️ Technical Details
+
+- **Input**: Images up to 512px (maintains aspect ratio)
+- **Processing**: Backbone feature extraction → NAF upsampling
+- **Output**: High-resolution features at target resolution
+- **Device**: Runs on CPU (free tier) or GPU (faster inference)
+
+## 🤝 Citation
+
+If you use NAF in your research, please cite:
+
+```bibtex
+@article{chambon2025naf,
+  title={NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering},
+  author={Chambon, Lucas and others},
+  journal={arXiv preprint arXiv:2501.01535},
+  year={2025}
+}
+```
+
+## 📜 License
+
+This demo is released under the Apache 2.0 license.
README_SPACE.md
ADDED
@@ -0,0 +1,84 @@
---
title: NAF Zero-Shot Feature Upsampling
emoji: 🎯
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 3.50.0
app_file: app.py
pinned: false
license: apache-2.0
---

# 🎯 NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering

This Space demonstrates **NAF (Neighborhood Attention Filtering)**, a method for upsampling features from Vision Foundation Models to any resolution without model-specific training.

## 🚀 Features

- **Universal Upsampling**: Works with any Vision Foundation Model (DINOv2, DINOv3, RADIO, DINO, SigLIP, etc.)
- **Arbitrary Resolutions**: Upsample features to any target resolution while maintaining aspect ratio
- **Zero-Shot**: No model-specific training or fine-tuning required
- **Interactive Demo**: Upload your own images or try sample images from various domains

## 🎨 How to Use

1. **Upload an Image**: Click "Upload Your Image" or select from sample images
2. **Choose a Model**: Select a Vision Foundation Model from the dropdown
3. **Set Resolution**: Choose the target resolution for upsampled features (64-512)
4. **Click "Upsample Features"**: See the comparison between low- and high-resolution features

## 📊 Visualization

The output shows three panels:
- **Left**: Your input image
- **Center**: Low-resolution features from the backbone (PCA visualization)
- **Right**: High-resolution features upsampled by NAF

Features are visualized with PCA, mapping the first 3 principal components to RGB channels.

## 🔬 Supported Models

- **DINOv3**: Latest self-supervised vision models
- **RADIO v2.5**: High-performance vision backbones
- **DINOv2**: Self-supervised learning with registers
- **DINO**: Original self-supervised ViT
- **SigLIP**: Contrastive vision-language models

## 📖 Learn More

- **Paper**: [NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering](https://arxiv.org/abs/2501.01535)
- **Code**: [GitHub Repository](https://github.com/valeoai/NAF)
- **Organization**: [Valeo.ai](https://www.valeo.com/en/valeo-ai/)

## 💡 Use Cases

NAF enables better feature representations for:
- Dense prediction tasks (segmentation, depth estimation)
- High-resolution visual understanding
- Feature matching and correspondence
- Vision-language alignment

## ⚙️ Technical Details

- **Input**: Images up to 512px (maintains aspect ratio)
- **Processing**: Backbone feature extraction → NAF upsampling
- **Output**: High-resolution features at target resolution
- **Device**: Runs on CPU (free tier) or GPU (faster inference)

## 🤝 Citation

If you use NAF in your research, please cite:

```bibtex
@article{chambon2025naf,
  title={NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering},
  author={Chambon, Lucas and others},
  journal={arXiv preprint arXiv:2501.01535},
  year={2025}
}
```

## 📜 License

This demo is released under the Apache 2.0 license.
app.py
ADDED
@@ -0,0 +1,251 @@
import io
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms as T

# Add project root to path
sys.path.append(str(Path(__file__).parent))
from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.training import round_to_nearest_multiple
from utils.visualization import plot_feats

# Load NAF model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
model.eval()

# Normalization for upsampling
ups_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Sample images
SAMPLE_IMAGES = [
    "asset/Cartoon.png",
    "asset/Natural.png",
    "asset/Satellite.png",
    "asset/Medical.png",
    "asset/Ecosystems.png",
    "asset/Driving.jpg",
    "asset/Manufacturing.png",
]


def resize_with_aspect_ratio(img, max_size, patch_size):
    """Resize image maintaining aspect ratio with max dimension and patch size constraints"""
    w, h = img.size

    # Calculate scaling factor to fit within max_size
    scale = min(max_size / w, max_size / h)
    new_w = int(w * scale)
    new_h = int(h * scale)

    # Round to nearest patch size multiple
    new_w = round_to_nearest_multiple(new_w, patch_size)
    new_h = round_to_nearest_multiple(new_h, patch_size)

    # Ensure minimum size
    new_w = max(new_w, patch_size)
    new_h = max(new_h, patch_size)

    return new_w, new_h


@torch.no_grad()
def process_image(image, model_name, output_resolution):
    """Process image with selected model and resolution"""
    try:
        # Load the backbone using vit_wrapper
        backbone = PretrainedViTWrapper(model_name, norm=True).to(device)
        backbone.eval()

        # Get model config for normalization and input size
        mean = backbone.config["mean"]
        std = backbone.config["std"]
        patch_size = backbone.patch_size
        back_norm = T.Normalize(mean=mean, std=std)

        # Prepare image at model's expected resolution
        img = PIL.Image.fromarray(image).convert("RGB")
        new_w, new_h = resize_with_aspect_ratio(img, max_size=512, patch_size=patch_size)

        transform = T.Compose(
            [
                T.Resize((new_h, new_w)),
                T.ToTensor(),
            ]
        )
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Normalize for backbone
        img_back = back_norm(img_tensor)
        lr_feats = backbone(img_back)

        # vit_wrapper already returns features in [B, C, H, W] format
        if not isinstance(lr_feats, torch.Tensor):
            raise ValueError(f"Unexpected feature type: {type(lr_feats)}")

        if len(lr_feats.shape) != 4:
            raise ValueError(f"Unexpected feature shape: {lr_feats.shape}. Expected [B, C, H, W].")

        # Normalize for upsampling
        img_ups = ups_norm(img_tensor)

        # Calculate output resolution maintaining aspect ratio
        _, _, h, w = lr_feats.shape
        aspect_ratio = w / h
        if aspect_ratio > 1:  # Width > Height
            out_h = round_to_nearest_multiple(int(output_resolution / aspect_ratio), patch_size)
            out_w = output_resolution
        else:  # Height >= Width
            out_h = output_resolution
            out_w = round_to_nearest_multiple(int(output_resolution * aspect_ratio), patch_size)

        upsampled_feats = model(img_ups, lr_feats, (out_h, out_w))

        # Create visualization using plot_feats
        plot_feats(
            img_tensor[0],
            lr_feats[0],
            [upsampled_feats[0]],
            legend=["Image", f"Low-Res: {h}x{w}", f"High-Res: {out_h}x{out_w}"],
            font_size=14,
        )

        # Convert matplotlib figure to PIL Image
        fig = plt.gcf()  # Get current figure
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
        buf.seek(0)
        result_img = PIL.Image.open(buf)
        plt.close(fig)

        return result_img

    except Exception as e:
        print(f"Error processing image: {e}")
        import traceback

        traceback.print_exc()
        return None


# Popular vision models for the dropdown (from vit_wrapper.py)
POPULAR_MODELS = [
    "vit_base_patch16_dinov3.lvd1689m",
    "radio_v2.5-b",
    "vit_base_patch14_reg4_dinov2",
    "vit_base_patch14_dinov2.lvd142m",
    "vit_base_patch16_224.dino",
    "vit_base_patch16_siglip_512.v2_webli",
]

# Create Gradio interface
with gr.Blocks(title="NAF: Zero-Shot Feature Upsampling") as demo:
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 2rem;">
            <h1 class="title-text" style="font-size: 3rem; margin-bottom: 0.5rem;">
                🎯 NAF: Zero-Shot Feature Upsampling
            </h1>
            <p style="font-size: 1.2rem; color: #666; margin-bottom: 1rem;">
                via Neighborhood Attention Filtering
            </p>
            <div class="info-box" style="max-width: 900px; margin: 0 auto;">
                <p style="font-size: 1.1rem; margin-bottom: 0.8rem;">
                    🚀 <strong>Upsample features from any Vision Foundation Model to any resolution!</strong>
                </p>
                <p style="font-size: 0.95rem; margin: 0;">
                    Upload an image, select a model, choose your target resolution, and see NAF in action.
                </p>
            </div>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input Configuration")

            image_input = gr.Image(label="Upload Your Image", type="numpy")

            # Sample images
            if any(Path(p).exists() for p in SAMPLE_IMAGES):
                gr.Examples(
                    examples=[[p] for p in SAMPLE_IMAGES if Path(p).exists()],
                    inputs=image_input,
                    label="🖼️ Try Sample Images",
                    examples_per_page=4,
                )

            gr.Markdown("### ⚙️ Model Settings")

            model_dropdown = gr.Dropdown(
                choices=POPULAR_MODELS,
                value=POPULAR_MODELS[0],
                label="🤖 Vision Foundation Model",
            )

            resolution_slider = gr.Slider(
                minimum=64,
                maximum=512,
                step=64,
                value=448,
                label="📏 Output Resolution (max dimension)",
            )

            process_btn = gr.Button("✨ Upsample Features", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 🎨 Visualization Results")
            output_image = gr.Image(label="Feature Comparison", type="pil")

            gr.Markdown(
                """
                <div style="background: #f0f7ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #667eea;">
                    <strong>📊 Visualization Guide:</strong>
                    <ul style="margin: 0.5rem 0;">
                        <li><strong>Left:</strong> Original input image</li>
                        <li><strong>Center:</strong> Low-resolution features (PCA visualization)</li>
                        <li><strong>Right:</strong> High-resolution features upsampled by NAF</li>
                    </ul>
                    <p style="margin-top: 0.5rem; font-size: 0.9rem; color: #555;">
                        <em>Note: Output features maintain the aspect ratio of the input image.</em>
                    </p>
                </div>
                """
            )

    process_btn.click(fn=process_image, inputs=[image_input, model_dropdown, resolution_slider], outputs=output_image)

    gr.Markdown(
        """
        ---
        <div style="text-align: center; padding: 2rem 0;">
            <h3 style="color: #667eea;">💡 About NAF</h3>
            <p style="max-width: 800px; margin: 1rem auto; font-size: 1.05rem; color: #555;">
                NAF enables <strong>zero-shot feature upsampling</strong> from any Vision Foundation Model
                to any resolution. It learns to filter and combine features using neighborhood attention,
                without requiring model-specific training.
            </p>
            <div style="margin-top: 1.5rem;">
                <a href="https://github.com/valeoai/NAF" target="_blank"
                   style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📦 GitHub Repository
                </a>
                <a href="https://arxiv.org/abs/2501.01535" target="_blank"
                   style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📄 Research Paper
                </a>
            </div>
        </div>
        """
    )

if __name__ == "__main__":
    demo.launch()
deploy_to_hf.sh
ADDED
@@ -0,0 +1,109 @@
#!/bin/bash

# Deploy NAF Demo to Hugging Face Spaces
# Usage: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME

set -e

if [ "$#" -ne 2 ]; then
    echo "Usage: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME"
    echo "Example: ./deploy_to_hf.sh myusername naf-demo"
    exit 1
fi

USERNAME=$1
SPACE_NAME=$2
SPACE_URL="https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}"

echo "🚀 Deploying NAF Demo to Hugging Face Spaces"
echo "Space URL will be: ${SPACE_URL}"
echo ""

# Check if git-lfs is installed
if ! command -v git-lfs &> /dev/null; then
    echo "⚠️  git-lfs is not installed. Installing..."
    git lfs install
fi

# Create temporary directory
TEMP_DIR=$(mktemp -d)
echo "📁 Created temporary directory: ${TEMP_DIR}"

# Clone the space
echo "📥 Cloning space repository..."
git clone https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME} ${TEMP_DIR}
cd ${TEMP_DIR}

# Copy main files
echo "📋 Copying files..."
cp /home/lchambon/workspace/NAF_off/app.py .
cp /home/lchambon/workspace/NAF_off/requirements_demo.txt ./requirements.txt
cp /home/lchambon/workspace/NAF_off/README_SPACE.md ./README.md

# Create directory structure
mkdir -p src/backbone utils asset

# Copy source files
cp /home/lchambon/workspace/NAF_off/src/backbone/vit_wrapper.py src/backbone/
cp /home/lchambon/workspace/NAF_off/utils/visualization.py utils/
cp /home/lchambon/workspace/NAF_off/utils/training.py utils/

# Copy utils/__init__.py if it exists
if [ -f /home/lchambon/workspace/NAF_off/utils/__init__.py ]; then
    cp /home/lchambon/workspace/NAF_off/utils/__init__.py utils/
else
    touch utils/__init__.py
fi

# Copy src/__init__.py if it exists
if [ -f /home/lchambon/workspace/NAF_off/src/__init__.py ]; then
    cp /home/lchambon/workspace/NAF_off/src/__init__.py src/
else
    touch src/__init__.py
fi

# Copy src/backbone/__init__.py if it exists
if [ -f /home/lchambon/workspace/NAF_off/src/backbone/__init__.py ]; then
    cp /home/lchambon/workspace/NAF_off/src/backbone/__init__.py src/backbone/
else
    touch src/backbone/__init__.py
fi

# Copy sample images if they exist
echo "🖼️  Copying sample images..."
if [ -d /home/lchambon/workspace/NAF_off/asset ]; then
    cp /home/lchambon/workspace/NAF_off/asset/*.png asset/ 2>/dev/null || true
    cp /home/lchambon/workspace/NAF_off/asset/*.jpg asset/ 2>/dev/null || true
fi

# Add .gitignore (quoted delimiter so patterns like *$py.class are kept literally)
cat > .gitignore << 'EOF'
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
.env
.venv
*.egg-info/
.DS_Store
EOF

# Git operations
echo "📤 Pushing to Hugging Face..."
git add .
git commit -m "Deploy NAF Feature Upsampling Demo"
git push

echo ""
echo "✅ Deployment complete!"
echo "🌐 Your Space will be available at: ${SPACE_URL}"
echo "⏳ It may take a few minutes to build..."
echo ""
echo "To check build status:"
echo "  - Visit: ${SPACE_URL}"
echo "  - Click on 'Logs' tab"

# Cleanup
cd -
rm -rf ${TEMP_DIR}
requirements.txt
ADDED
@@ -0,0 +1,22 @@
einops==0.8.0
numpy==1.24.4
timm==1.0.22
plotly==6.0.0
tensorboard==2.20.0
hydra-core==1.3.2
matplotlib==3.7.0
rich==14.2.0
torchmetrics==1.6.2
scipy==1.15.2
kornia==0.8.2
ipykernel
ipympl
pytest

# Torch + CUDA 11.8 (HuggingFace compatible install)
torch==2.4.0
torchvision==0.19.0
--extra-index-url https://download.pytorch.org/whl/cu118

# NATTEN wheels
-f https://shi-labs.com/natten/wheels/
natten==0.17.4+torch240cu118
src/backbone/vit_wrapper.py
ADDED
@@ -0,0 +1,180 @@
import re
import types
from typing import List, Tuple, Union

import timm
import timm.data
import torch
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torch import nn
from torchvision import transforms

# We provide a list of timm model names, more are available on their official repo.
MODEL_LIST = [
    # DINO
    "vit_base_patch16_224.dino",
    # DINOv2
    "vit_base_patch14_dinov2.lvd142m",
    # DINOv2-R
    "vit_base_patch14_reg4_dinov2",
    # Franca
    "franca_vitb14",
    # DINOv3-ViT
    "vit_base_patch16_dinov3.lvd1689m",
    "vit_large_patch16_dinov3.lvd1689m",
    "vit_7b_patch16_dinov3.lvd1689m",
    # SigLIP2
    "vit_base_patch16_siglip_512.v2_webli",
    # PE Core
    "vit_pe_core_small_patch16_384.fb",
    # PE Spatial
    "vit_pe_spatial_tiny_patch16_512.fb",
    # RADIO
    "radio_v2.5-b",
    # CAPI
    "capi_vitl14_lvd",
    # MAE
    "vit_large_patch16_224.mae",
]

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class PretrainedViTWrapper(nn.Module):

    def __init__(
        self,
        name,
        norm: bool = True,
        dynamic_img_size: bool = True,
        dynamic_img_pad: bool = False,
        **kwargs,
    ):
        super().__init__()
        # comment out the following line to test the models not in the list
        self.name = name

        load_weights = False
        if "dvt_" == name[:4]:
            load_weights = True
            load_tag = "dvt"
            name = name.replace("dvt_", "")
        if "fit3d_" == name[:6]:
            load_weights = True
            load_tag = "fit3d"
            name = name.replace("fit3d_", "")

        # Set patch size
        try:
            self.patch_size = int(re.search(r"patch(\d+)", name).group(1))
        except AttributeError:
            self.patch_size = 16
        if "franca" in name or "capi" in name:
            self.patch_size = 14
        if "convnext" in name:
            self.patch_size = 32

        self.dynamic_img_size = dynamic_img_size
        self.dynamic_img_pad = dynamic_img_pad
        self.model, self.config = self.create_model(name, **kwargs)
        self.config["ps"] = self.patch_size
        self.embed_dim = self.model.embed_dim
        self.norm = norm

        if load_weights:
            ckpt = torch.load(f"/home/lchambon/workspace/JAFAR/ckpts/{load_tag}_{name}.pth", map_location="cpu")
            if load_tag == "dvt":
                self.load_state_dict(ckpt["model"], strict=True)
            elif load_tag == "fit3d":
                self.model.load_state_dict(ckpt, strict=True)

    def create_model(self, name: str, **kwargs) -> Tuple[VisionTransformer, transforms.Compose]:
        if "radio" in self.name:
            model = torch.hub.load(
                "NVlabs/RADIO",
                "radio_model",
                version=name,
                progress=True,
                skip_validation=True,
            )
            data_config = {
                "mean": torch.tensor([0.0, 0.0, 0.0]),
                "std": torch.tensor([1.0, 1.0, 1.0]),
                "input_size": (3, 512, 512),
            }

        elif "franca" in self.name:
            model = torch.hub.load("valeoai/Franca", name, use_rasa_head=True)
            data_config = {"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "input_size": (3, 448, 448)}

        elif "capi" in self.name:
            model = torch.hub.load("facebookresearch/capi:main", name, force_reload=False)

            data_config = {"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "input_size": (3, 448, 448)}

        else:
            timm_kwargs = dict(
                pretrained=True,
                num_classes=0,
                patch_size=self.patch_size,
            )

            if "sam" not in self.name and "convnext" not in self.name:
                timm_kwargs["dynamic_img_size"] = self.dynamic_img_size
                timm_kwargs["dynamic_img_pad"] = self.dynamic_img_pad

            timm_kwargs.update(kwargs)
            model = timm.create_model(name, **timm_kwargs)
            data_config = timm.data.resolve_model_data_config(model=model)

        model = model.eval()

        return model, data_config

    def forward(
        self,
        x: torch.Tensor,
        n: Union[int, List[int], Tuple[int]] = 1,
        return_prefix_tokens: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Intermediate layer accessor inspired by DINO / DINOv2 interface.
        Args:
            x: Input tensor.
            n: Take last n blocks if int, all if None, select matching indices if sequence
            return_prefix_tokens: Whether to also request prefix tokens from timm models.
        """

        common_kwargs = dict(
            norm=self.norm,
            output_fmt="NCHW",
            intermediates_only=True,
        )

        if "sam" not in self.name and return_prefix_tokens:
            common_kwargs["return_prefix_tokens"] = return_prefix_tokens

        if "franca" in self.name:
            B, C, H, W = x.shape
            feats = self.model.forward_features(x, use_rasa_head=True)
            out = feats["patch_token_rasa"]
            out = rearrange(out, "b (h w) c -> b c h w", h=H // self.patch_size, w=W // self.patch_size)

        elif "capi" in self.name:
            *_, out = self.model(x)
            out = out.permute(0, 3, 1, 2)

        else:
            out = self.model.forward_intermediates(x, n, **common_kwargs)

        # "sam" models return feats only, others may return (feats, prefix)
        if not isinstance(out, list) and not isinstance(out, tuple):
            out = [out]
            return out[0]
        else:
            assert len(out) == 1, f"Out contains {len(out)} elements, expected 1."
            return out[0]
utils/training.py
ADDED
@@ -0,0 +1,231 @@
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision.transforms as T
from hydra.utils import instantiate
from omegaconf import ListConfig
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import InterpolationMode

from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.img import PILToTensor


def seed_worker():
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def round_to_nearest_multiple(value, multiple=14):
    return multiple * round(value / multiple)


def compute_feats(cfg, backbone, image_batch, min_rescale=0.60, max_rescale=0.25):
    _, _, H, W = image_batch.shape  # Get original height and width

    with torch.no_grad():
        hr_feats = backbone(image_batch)

        if cfg.get("lr_img_size", None) is not None:
            size = (cfg.lr_img_size, cfg.lr_img_size)
        else:
            # Downscale
            if cfg.down_factor == "random":
                downscale_factor = np.random.uniform(min_rescale, max_rescale)

            elif cfg.down_factor == "fixed":
                downscale_factor = 0.5

            new_H = round_to_nearest_multiple(H * downscale_factor, backbone.patch_size)
            new_W = round_to_nearest_multiple(W * downscale_factor, backbone.patch_size)
            size = (new_H, new_W)
        low_res_batch = F.interpolate(image_batch, size=size, mode="bilinear")
        lr_feats = backbone(low_res_batch)

    return hr_feats, lr_feats


def logger(args, base_log_dir):
    os.makedirs(base_log_dir, exist_ok=True)
    existing_versions = [
        int(d.split("_")[-1])
        for d in os.listdir(base_log_dir)
        if os.path.isdir(os.path.join(base_log_dir, d)) and d.startswith("version_")
    ]
    new_version = max(existing_versions, default=-1) + 1
    new_log_dir = os.path.join(base_log_dir, f"version_{new_version}")

    # Create the SummaryWriter with the new log directory
    writer = SummaryWriter(log_dir=new_log_dir)
    return writer, new_version, new_log_dir


def get_dataloaders(cfg, shuffle=True):
    """Get dataloaders for either training or evaluation.

    Args:
        cfg: Configuration object
        backbone: Backbone model for normalization parameters
    """
    # Default ImageNet normalization values
    transforms = {
        "image": T.Compose(
            [
                T.Resize(cfg.img_size, interpolation=InterpolationMode.BILINEAR),
                T.CenterCrop((cfg.img_size, cfg.img_size)),
                T.ToTensor(),
            ]
        )
    }

    transforms["label"] = T.Compose(
        [
            # T.ToTensor(),
            T.Resize(cfg.target_size, interpolation=InterpolationMode.NEAREST_EXACT),
            T.CenterCrop((cfg.target_size, cfg.target_size)),
            PILToTensor(),
        ]
    )
    train_dataset = cfg.dataset
    val_dataset = cfg.dataset.copy()
    if hasattr(val_dataset, "split"):
        val_dataset.split = "val"

    train_dataset = instantiate(
        train_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )
    val_dataset = instantiate(
        val_dataset,
        transform=transforms["image"],
        target_transform=transforms["label"],
    )

    # Create generator for reproducibility
    if not shuffle:
        g = torch.Generator()
        g.manual_seed(0)
    else:
        g = None

    # Prepare dataloader configs - set worker_init_fn to None when shuffling for randomness
    train_dataloader_cfg = cfg.train_dataloader.copy()
    val_dataloader_cfg = cfg.val_dataloader.copy()

    if shuffle:
        # Set worker_init_fn to None to allow true randomness when shuffling
        if "worker_init_fn" in train_dataloader_cfg:
            train_dataloader_cfg["worker_init_fn"] = None
        if "worker_init_fn" in val_dataloader_cfg:
            val_dataloader_cfg["worker_init_fn"] = None

    return (
        instantiate(train_dataloader_cfg, dataset=train_dataset, generator=g),
        instantiate(val_dataloader_cfg, dataset=val_dataset, generator=g),
    )


def get_batch(batch, device):
    """Process batch and return required tensors."""
    batch["image"] = batch["image"].to(device)
    return batch


def setup_training_optimizations(model, cfg):
    """
    Setup training optimizations based on configuration

    Args:
        model: The model to apply optimizations to
        cfg: Configuration object with use_bf16 and use_checkpointing flags

    Returns:
        tuple: (scaler, use_bf16, use_checkpointing) for use in training loop
    """
    # Get configuration values with defaults
    use_bf16 = getattr(cfg, "use_bf16", False)
    use_checkpointing = getattr(cfg, "use_checkpointing", False)

    # Initialize gradient scaler for mixed precision
    scaler = torch.amp.GradScaler("cuda", enabled=use_bf16)

    # Enable gradient checkpointing if requested
    if use_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
            print(" ✓ Using built-in gradient checkpointing")
        else:
            # For custom models, wrap forward methods
            def checkpoint_wrapper(module):
                if hasattr(module, "forward"):
                    original_forward = module.forward

                    def checkpointed_forward(*args, **kwargs):
                        return checkpoint.checkpoint(original_forward, *args, **kwargs)

                    module.forward = checkpointed_forward

            # Apply to key modules (adjust based on your model structure)
            checkpointed_modules = []
            for name, module in model.named_modules():
                if any(key in name for key in ["cross_decode", "encoder", "sft"]):
                    checkpoint_wrapper(module)
                    checkpointed_modules.append(name)

            if checkpointed_modules:
                print(f" ✓ Applied custom gradient checkpointing to: {checkpointed_modules}")
            else:
                print(" ⚠ No modules found for gradient checkpointing")

    print(f"Training optimizations:")
    print(f"  Mixed precision (bfloat16): {use_bf16}")
    print(f"  Gradient checkpointing: {use_checkpointing}")

    return scaler, use_bf16, use_checkpointing


def load_multiple_backbones(cfg, backbone_configs, device):
    """
    Load multiple backbone models based on configuration.

    Args:
        cfg: Hydra configuration object
        device: PyTorch device to load models on

    Returns:
        tuple: (backbones, backbone_names, primary_backbone)
            - backbones: List of loaded backbone models
            - backbone_names: List of backbone names
    """
    backbones = []
    backbone_names = []
    backbone_img_sizes = []

    if not isinstance(backbone_configs, list) and not isinstance(backbone_configs, ListConfig):
        backbone_configs = [backbone_configs]
    print(f"Loading {len(backbone_configs)} backbone(s)...")

    for i, backbone_config in enumerate(backbone_configs):
        name = backbone_config["name"]
        if name == "rgb":
            backbone = instantiate(cfg.backbone)
        else:
            backbone = PretrainedViTWrapper(name=name)
        print(f"  [{i}] Loaded {backbone_config['name']}")

        # Move to device and set to eval mode
        backbone = backbone.to(device)
        backbone.eval()  # Set to eval mode for feature extraction

        # Store backbone and name
        backbones.append(backbone)
        backbone_names.append(backbone_config["name"])
        backbone_img_sizes.append(backbone.config["input_size"][1:])

    return backbones, backbone_names, backbone_img_sizes
utils/visualization.py
ADDED
@@ -0,0 +1,190 @@
# Visualization code from https://github.com/Tsingularity/dift/blob/main/src/utils/visualization.py

import io
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image

FONT_SIZE = 40


@torch.no_grad()
def plot_feats(
    image,
    target,
    pred,
    legend=["Image", "HR Features", "Pred Features"],
    save_path=None,
    return_array=False,
    show_legend=True,
    font_size=FONT_SIZE,
):
    """
    Create a plot_feats visualization.
    """
    # Ensure hr_or_seg is a list
    if not isinstance(pred, list):
        pred = [pred]

    # Prepare inputs for PCA
    feats_for_pca = [target.unsqueeze(0)] + [_.unsqueeze(0) for _ in pred]
    reduced_feats, _ = pca(feats_for_pca)  # pca outputs a list of reduced tensors

    target_imgs = reduced_feats[0]
    pred_imgs = reduced_feats[1:]

    # --- Plot ---
    # Determine number of columns based on whether image is provided
    n_cols = (1 if image is not None else 0) + 1 + len(pred_imgs)
    fig, ax = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))
    # Reduce space between images
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    # Handle single subplot case
    if n_cols == 1:
        ax = [ax]

    # Current axis index
    ax_idx = 0

    # Plot original image if provided
    if image is not None:
        if image.dim() == 3:
            ax[ax_idx].imshow(image.permute(1, 2, 0).detach().cpu())
        elif image.dim() == 2:
            ax[ax_idx].imshow(image.detach().cpu(), cmap="inferno")
        if show_legend:
            ax[ax_idx].set_title(legend[0], fontsize=font_size)
        ax_idx += 1

    # Plot the low-resolution features or segmentation mask
    ax[ax_idx].imshow(target_imgs[0].permute(1, 2, 0).detach().cpu())
    if show_legend:
        legend_idx = 1 if image is not None else 0
        ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
    ax_idx += 1

    # Plot HR features or segmentation masks
    for idx, pred_img in enumerate(pred_imgs):
        ax[ax_idx].imshow(pred_img[0].permute(1, 2, 0).detach().cpu())
        if show_legend:
            legend_idx = (2 if image is not None else 1) + idx
            if len(legend) > legend_idx:
                ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
            else:
                ax[ax_idx].set_title(f"HR Features {idx}", fontsize=font_size)
        ax_idx += 1

    remove_axes(ax)

    # Handle return_array case
    if return_array:
        # Turn off interactive mode temporarily
        was_interactive = plt.isinteractive()
        plt.ioff()

        # Convert figure to numpy array
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)

        # Convert to PIL Image then to numpy array
        pil_img = Image.open(buf)
        img_array = np.array(pil_img)

        # Close the figure and buffer
        plt.close(fig)
        buf.close()

        # Restore interactive mode if it was on
        if was_interactive:
            plt.ion()

        return img_array

    # Standard behavior: save and/or show
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()

    return None


def remove_axes(axes):
    def _remove_axes(ax):
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_xticks([])
        ax.set_yticks([])

    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    target_size = None
    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]

    def flatten(tensor, target_size=None):
        B, C, H, W = tensor.shape
        assert B == 1, "Batch size should be 1 for PCA flattening"
        if target_size is not None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear", align_corners=False)
        return rearrange(tensor, "b c h w -> (b h w) c").detach().cpu()

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    # Subsample the data if max_samples is set and the number of samples exceeds max_samples
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        B, C, H, W = feats.shape
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2))

    return reduced_feats, fit_pca


class TorchPCA(object):

    def __init__(self, n_components, skip=0):
        self.n_components = n_components
        self.skip = skip

    def fit(self, X):
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=20)
        self.components_ = V[:, self.skip:]
        self.singular_values_ = S
        return self

    def transform(self, X):
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_
        return projected