LChambon committed
Commit e4c8837 · 1 Parent(s): 6e54552

initial commit
Files changed (10)
  1. .gitignore +3 -0
  2. DEPLOYMENT.md +139 -0
  3. README.md +76 -6
  4. README_SPACE.md +84 -0
  5. app.py +251 -0
  6. deploy_to_hf.sh +109 -0
  7. requirements.txt +22 -0
  8. src/backbone/vit_wrapper.py +180 -0
  9. utils/training.py +231 -0
  10. utils/visualization.py +190 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__
+ *.pyc
+ .env
DEPLOYMENT.md ADDED
@@ -0,0 +1,139 @@
+ # Deploying NAF Demo to Hugging Face Spaces
+
+ ## Quick Setup
+
+ ### 1. Create a Hugging Face Space
+
+ 1. Go to [https://huggingface.co/spaces](https://huggingface.co/spaces)
+ 2. Click "Create new Space"
+ 3. Configure your Space:
+    - **Space name**: `naf-feature-upsampling` (or your choice)
+    - **License**: Apache 2.0 (or your choice)
+    - **Select the SDK**: Gradio
+    - **Space hardware**: CPU Basic (free) or GPU (T4 small recommended for faster inference)
+    - **Visibility**: Public or Private
+
+ ### 2. Required Files
+
+ Upload these files to your Hugging Face Space:
+
+ ```
+ your-space/
+ ├── app.py                  # Main application
+ ├── requirements.txt        # Python dependencies
+ ├── README.md               # Space documentation
+ ├── src/
+ │   └── backbone/
+ │       └── vit_wrapper.py  # Backbone wrapper
+ └── utils/
+     ├── visualization.py    # Visualization utilities
+     └── training.py         # Training utilities (for round_to_nearest_multiple)
+ ```
+
+ ### 3. Clone Your Space Repository
+
+ ```bash
+ # Install git-lfs if not already installed
+ git lfs install
+
+ # Clone your space
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+ cd YOUR_SPACE_NAME
+
+ # Copy files from your local NAF project
+ cp /home/lchambon/workspace/NAF_off/app.py .
+ cp /home/lchambon/workspace/NAF_off/requirements_demo.txt ./requirements.txt
+
+ # Copy source files
+ mkdir -p src/backbone utils
+ cp /home/lchambon/workspace/NAF_off/src/backbone/vit_wrapper.py src/backbone/
+ cp /home/lchambon/workspace/NAF_off/utils/visualization.py utils/
+ cp /home/lchambon/workspace/NAF_off/utils/training.py utils/
+ cp /home/lchambon/workspace/NAF_off/utils/img.py utils/  # If needed
+
+ # Copy sample images (optional)
+ mkdir -p asset
+ cp /home/lchambon/workspace/NAF_off/asset/*.png asset/
+ cp /home/lchambon/workspace/NAF_off/asset/*.jpg asset/
+
+ # Add all files
+ git add .
+ git commit -m "Initial commit: NAF Feature Upsampling Demo"
+ git push
+ ```
+
+ ### 4. Alternative: Use the Hugging Face Web Interface
+
+ 1. Navigate to your Space's "Files" tab
+ 2. Click "Add file" → "Upload files"
+ 3. Upload all required files, maintaining the directory structure
+ 4. Commit changes
+
+ ### 5. Monitor Deployment
+
+ - Once pushed, Hugging Face will automatically build your Space
+ - Check the "Logs" tab to monitor the build process
+ - The Space will be available at: `https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME`
+
+ ## Hardware Recommendations
+
+ - **CPU Basic (free)**: Works, but inference is slower (~10-30s per image)
+ - **T4 Small GPU**: Recommended for better performance (~2-5s per image)
+ - **T4 Medium/Large GPU**: For handling multiple concurrent users
+
+ ## Important Notes
+
+ ### Model Loading
+ The NAF model is loaded from torch.hub, which will download it on first run:
+ ```python
+ model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
+ ```
+
+ ### Memory Considerations
+ - Backbone models are loaded on demand per request
+ - Consider caching popular models if you upgrade to persistent storage
+ - GPU Spaces have more memory for handling larger images
+
+ ### Sample Images
+ - Upload sample images to the `asset/` folder
+ - Update the `SAMPLE_IMAGES` list in `app.py` to match the available images
+ - Or remove the examples section if images aren't available
+
+ ## Troubleshooting
+
+ ### Build Failures
+ - Check the "Logs" tab for error messages
+ - Ensure all dependencies are in `requirements.txt`
+ - Verify Python version compatibility (3.8-3.10 recommended)
+
+ ### Import Errors
+ - Make sure the `src/` and `utils/` directories are included
+ - Check that `__init__.py` files exist if needed
+ - Verify that relative imports are correct
+
+ ### Memory Issues
+ - Reduce the maximum resolution if needed
+ - Consider using CPU-only mode on the free tier
+ - Upgrade to GPU hardware if processing large images
+
+ ## Updating Your Space
+
+ ```bash
+ # Make changes to your files
+ git add .
+ git commit -m "Update: description of changes"
+ git push
+ ```
+
+ Hugging Face will automatically rebuild and redeploy your Space.
+
+ ## Making Your Space Public
+
+ 1. Go to the Space Settings
+ 2. Change visibility to "Public"
+ 3. Add a good README.md with a description and usage instructions
+ 4. Consider adding a thumbnail image
+
+ ## Example Space README
+
+ See `README_SPACE.md` for a template README to use on Hugging Face Spaces.
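The "Memory Considerations" notes above suggest caching backbones instead of reloading them on every request. Below is a minimal sketch of such a cache, assuming the `PretrainedViTWrapper` class added in this commit; the cache size and eviction policy are illustrative, not part of the demo code.

```python
# Minimal sketch: reuse loaded backbones across requests (cache policy is illustrative).
from typing import Dict

import torch

from src.backbone.vit_wrapper import PretrainedViTWrapper

_BACKBONE_CACHE: Dict[str, torch.nn.Module] = {}
_MAX_CACHED = 2  # keep memory bounded on small Spaces hardware


def get_backbone(name: str, device: str) -> torch.nn.Module:
    """Return a cached backbone, loading it on first use."""
    if name not in _BACKBONE_CACHE:
        if len(_BACKBONE_CACHE) >= _MAX_CACHED:
            # Drop the oldest entry to stay within the memory budget.
            oldest = next(iter(_BACKBONE_CACHE))
            del _BACKBONE_CACHE[oldest]
        backbone = PretrainedViTWrapper(name, norm=True).to(device)
        backbone.eval()
        _BACKBONE_CACHE[name] = backbone
    return _BACKBONE_CACHE[name]
```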
README.md CHANGED
@@ -1,14 +1,84 @@
  ---
- title: NAF
- emoji: 📊
+ title: NAF Zero-Shot Feature Upsampling
+ emoji: 🎯
  colorFrom: blue
- colorTo: indigo
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.0.1
+ sdk_version: 3.50.0
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: 'NAF: Zero-Shot Feature Upsampling via Neighborhood Attention'
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🎯 NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering
+
+ This Space demonstrates **NAF (Neighborhood Attention Filtering)**, a method for upsampling features from Vision Foundation Models to any resolution without model-specific training.
+
+ ## 🚀 Features
+
+ - **Universal Upsampling**: Works with any Vision Foundation Model (DINOv2, DINOv3, RADIO, DINO, SigLIP, etc.)
+ - **Arbitrary Resolutions**: Upsample features to any target resolution while maintaining the aspect ratio
+ - **Zero-Shot**: No model-specific training or fine-tuning required
+ - **Interactive Demo**: Upload your own images or try sample images from various domains
+
+ ## 🎨 How to Use
+
+ 1. **Upload an Image**: Click "Upload Your Image" or select one of the sample images
+ 2. **Choose a Model**: Select a Vision Foundation Model from the dropdown
+ 3. **Set Resolution**: Choose the target resolution for the upsampled features (64-512)
+ 4. **Click "Upsample Features"**: See the comparison between low- and high-resolution features
+
+ ## 📊 Visualization
+
+ The output shows three panels:
+ - **Left**: Your input image
+ - **Center**: Low-resolution features from the backbone (PCA visualization)
+ - **Right**: High-resolution features upsampled by NAF
+
+ Features are visualized with PCA, mapping the first 3 principal components to RGB channels.
+
+ ## 🔬 Supported Models
+
+ - **DINOv3**: Latest self-supervised vision models
+ - **RADIO v2.5**: High-performance vision backbones
+ - **DINOv2**: Self-supervised learning with registers
+ - **DINO**: Original self-supervised ViT
+ - **SigLIP**: Contrastive vision-language models
+
+ ## 📖 Learn More
+
+ - **Paper**: [NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering](https://arxiv.org/abs/2501.01535)
+ - **Code**: [GitHub Repository](https://github.com/valeoai/NAF)
+ - **Organization**: [Valeo.ai](https://www.valeo.com/en/valeo-ai/)
+
+ ## 💡 Use Cases
+
+ NAF enables better feature representations for:
+ - Dense prediction tasks (segmentation, depth estimation)
+ - High-resolution visual understanding
+ - Feature matching and correspondence
+ - Vision-language alignment
+
+ ## ⚙️ Technical Details
+
+ - **Input**: Images up to 512px (aspect ratio is preserved)
+ - **Processing**: Backbone feature extraction → NAF upsampling
+ - **Output**: High-resolution features at the target resolution
+ - **Device**: Runs on CPU (free tier) or GPU (faster inference)
+
+ ## 🤝 Citation
+
+ If you use NAF in your research, please cite:
+
+ ```bibtex
+ @article{chambon2025naf,
+   title={NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering},
+   author={Chambon, Lucas and others},
+   journal={arXiv preprint arXiv:2501.01535},
+   year={2025}
+ }
+ ```
+
+ ## 📜 License
+
+ This demo is released under the Apache 2.0 license.
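To make the "Processing: Backbone feature extraction → NAF upsampling" bullet concrete, here is a minimal sketch of the same pipeline outside the Gradio UI. It mirrors the calls in `app.py` from this commit (`torch.hub.load("valeoai/NAF", "naf", ...)`, `PretrainedViTWrapper`, and `model(image, lr_feats, (out_h, out_w))`); the image path, input size, and target resolution are placeholders.

```python
# Minimal sketch of the NAF pipeline used by app.py (paths and sizes are placeholders).
import PIL.Image
import torch
import torchvision.transforms as T

from src.backbone.vit_wrapper import PretrainedViTWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"
naf = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
naf.eval()

backbone = PretrainedViTWrapper("vit_base_patch14_reg4_dinov2", norm=True).to(device).eval()

img = PIL.Image.open("asset/Natural.png").convert("RGB").resize((448, 448))
x = T.ToTensor()(img).unsqueeze(0).to(device)

with torch.no_grad():
    # The backbone gets its own normalization statistics, the guidance image
    # uses ImageNet statistics, exactly as in app.py.
    back_norm = T.Normalize(backbone.config["mean"], backbone.config["std"])
    ups_norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    lr_feats = backbone(back_norm(x))                  # [1, C, 32, 32] for a ViT-B/14 at 448x448
    hr_feats = naf(ups_norm(x), lr_feats, (448, 448))  # [1, C, 448, 448]
```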
README_SPACE.md ADDED
@@ -0,0 +1,84 @@
+ ---
+ title: NAF Zero-Shot Feature Upsampling
+ emoji: 🎯
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.50.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # 🎯 NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering
+
+ This Space demonstrates **NAF (Neighborhood Attention Filtering)**, a method for upsampling features from Vision Foundation Models to any resolution without model-specific training.
+
+ ## 🚀 Features
+
+ - **Universal Upsampling**: Works with any Vision Foundation Model (DINOv2, DINOv3, RADIO, DINO, SigLIP, etc.)
+ - **Arbitrary Resolutions**: Upsample features to any target resolution while maintaining the aspect ratio
+ - **Zero-Shot**: No model-specific training or fine-tuning required
+ - **Interactive Demo**: Upload your own images or try sample images from various domains
+
+ ## 🎨 How to Use
+
+ 1. **Upload an Image**: Click "Upload Your Image" or select one of the sample images
+ 2. **Choose a Model**: Select a Vision Foundation Model from the dropdown
+ 3. **Set Resolution**: Choose the target resolution for the upsampled features (64-512)
+ 4. **Click "Upsample Features"**: See the comparison between low- and high-resolution features
+
+ ## 📊 Visualization
+
+ The output shows three panels:
+ - **Left**: Your input image
+ - **Center**: Low-resolution features from the backbone (PCA visualization)
+ - **Right**: High-resolution features upsampled by NAF
+
+ Features are visualized with PCA, mapping the first 3 principal components to RGB channels.
+
+ ## 🔬 Supported Models
+
+ - **DINOv3**: Latest self-supervised vision models
+ - **RADIO v2.5**: High-performance vision backbones
+ - **DINOv2**: Self-supervised learning with registers
+ - **DINO**: Original self-supervised ViT
+ - **SigLIP**: Contrastive vision-language models
+
+ ## 📖 Learn More
+
+ - **Paper**: [NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering](https://arxiv.org/abs/2501.01535)
+ - **Code**: [GitHub Repository](https://github.com/valeoai/NAF)
+ - **Organization**: [Valeo.ai](https://www.valeo.com/en/valeo-ai/)
+
+ ## 💡 Use Cases
+
+ NAF enables better feature representations for:
+ - Dense prediction tasks (segmentation, depth estimation)
+ - High-resolution visual understanding
+ - Feature matching and correspondence
+ - Vision-language alignment
+
+ ## ⚙️ Technical Details
+
+ - **Input**: Images up to 512px (aspect ratio is preserved)
+ - **Processing**: Backbone feature extraction → NAF upsampling
+ - **Output**: High-resolution features at the target resolution
+ - **Device**: Runs on CPU (free tier) or GPU (faster inference)
+
+ ## 🤝 Citation
+
+ If you use NAF in your research, please cite:
+
+ ```bibtex
+ @article{chambon2025naf,
+   title={NAF: Zero-Shot Feature Upsampling via Neighborhood Attention Filtering},
+   author={Chambon, Lucas and others},
+   journal={arXiv preprint arXiv:2501.01535},
+   year={2025}
+ }
+ ```
+
+ ## 📜 License
+
+ This demo is released under the Apache 2.0 license.
app.py ADDED
@@ -0,0 +1,251 @@
+ import io
+ import sys
+ from pathlib import Path
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+
+ # Add project root to path
+ sys.path.append(str(Path(__file__).parent))
+ from src.backbone.vit_wrapper import PretrainedViTWrapper
+ from utils.training import round_to_nearest_multiple
+ from utils.visualization import plot_feats
+
+ # Load NAF model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
+ model.eval()
+
+ # Normalization for upsampling
+ ups_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ # Sample images
+ SAMPLE_IMAGES = [
+     "asset/Cartoon.png",
+     "asset/Natural.png",
+     "asset/Satellite.png",
+     "asset/Medical.png",
+     "asset/Ecosystems.png",
+     "asset/Driving.jpg",
+     "asset/Manufacturing.png",
+ ]
+
+
+ def resize_with_aspect_ratio(img, max_size, patch_size):
+     """Resize image maintaining aspect ratio with max dimension and patch size constraints"""
+     w, h = img.size
+
+     # Calculate scaling factor to fit within max_size
+     scale = min(max_size / w, max_size / h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+
+     # Round to nearest patch size multiple
+     new_w = round_to_nearest_multiple(new_w, patch_size)
+     new_h = round_to_nearest_multiple(new_h, patch_size)
+
+     # Ensure minimum size
+     new_w = max(new_w, patch_size)
+     new_h = max(new_h, patch_size)
+
+     return new_w, new_h
+
+
+ @torch.no_grad()
+ def process_image(image, model_name, output_resolution):
+     """Process image with selected model and resolution"""
+     try:
+         # Load the backbone using vit_wrapper
+         backbone = PretrainedViTWrapper(model_name, norm=True).to(device)
+         backbone.eval()
+
+         # Get model config for normalization and input size
+         mean = backbone.config["mean"]
+         std = backbone.config["std"]
+         patch_size = backbone.patch_size
+         back_norm = T.Normalize(mean=mean, std=std)
+
+         # Prepare image at model's expected resolution
+         img = PIL.Image.fromarray(image).convert("RGB")
+         new_w, new_h = resize_with_aspect_ratio(img, max_size=512, patch_size=patch_size)
+
+         transform = T.Compose(
+             [
+                 T.Resize((new_h, new_w)),
+                 T.ToTensor(),
+             ]
+         )
+         img_tensor = transform(img).unsqueeze(0).to(device)
+
+         # Normalize for backbone
+         img_back = back_norm(img_tensor)
+         lr_feats = backbone(img_back)
+
+         # vit_wrapper already returns features in [B, C, H, W] format
+         if not isinstance(lr_feats, torch.Tensor):
+             raise ValueError(f"Unexpected feature type: {type(lr_feats)}")
+
+         if len(lr_feats.shape) != 4:
+             raise ValueError(f"Unexpected feature shape: {lr_feats.shape}. Expected [B, C, H, W].")
+
+         # Normalize for upsampling
+         img_ups = ups_norm(img_tensor)
+
+         # Calculate output resolution maintaining aspect ratio
+         _, _, h, w = lr_feats.shape
+         aspect_ratio = w / h
+         if aspect_ratio > 1:  # Width > Height
+             out_h = round_to_nearest_multiple(int(output_resolution / aspect_ratio), patch_size)
+             out_w = output_resolution
+         else:  # Height >= Width
+             out_h = output_resolution
+             out_w = round_to_nearest_multiple(int(output_resolution * aspect_ratio), patch_size)
+
+         upsampled_feats = model(img_ups, lr_feats, (out_h, out_w))
+
+         # Create visualization using plot_feats
+         plot_feats(
+             img_tensor[0],
+             lr_feats[0],
+             [upsampled_feats[0]],
+             legend=["Image", f"Low-Res: {h}x{w}", f"High-Res: {out_h}x{out_w}"],
+             font_size=14,
+         )
+
+         # Convert matplotlib figure to PIL Image
+         fig = plt.gcf()  # Get current figure
+         buf = io.BytesIO()
+         fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
+         buf.seek(0)
+         result_img = PIL.Image.open(buf)
+         plt.close(fig)
+
+         return result_img
+
+     except Exception as e:
+         print(f"Error processing image: {e}")
+         import traceback
+
+         traceback.print_exc()
+         return None
+
+
+ # Popular vision models for the dropdown (from vit_wrapper.py)
+ POPULAR_MODELS = [
+     "vit_base_patch16_dinov3.lvd1689m",
+     "radio_v2.5-b",
+     "vit_base_patch14_reg4_dinov2",
+     "vit_base_patch14_dinov2.lvd142m",
+     "vit_base_patch16_224.dino",
+     "vit_base_patch16_siglip_512.v2_webli",
+ ]
+
+ # Create Gradio interface
+ with gr.Blocks(title="NAF: Zero-Shot Feature Upsampling") as demo:
+     gr.HTML(
+         """
+         <div style="text-align: center; margin-bottom: 2rem;">
+             <h1 class="title-text" style="font-size: 3rem; margin-bottom: 0.5rem;">
+                 🎯 NAF: Zero-Shot Feature Upsampling
+             </h1>
+             <p style="font-size: 1.2rem; color: #666; margin-bottom: 1rem;">
+                 via Neighborhood Attention Filtering
+             </p>
+             <div class="info-box" style="max-width: 900px; margin: 0 auto;">
+                 <p style="font-size: 1.1rem; margin-bottom: 0.8rem;">
+                     🚀 <strong>Upsample features from any Vision Foundation Model to any resolution!</strong>
+                 </p>
+                 <p style="font-size: 0.95rem; margin: 0;">
+                     Upload an image, select a model, choose your target resolution, and see NAF in action.
+                 </p>
+             </div>
+         </div>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 📤 Input Configuration")
+
+             image_input = gr.Image(label="Upload Your Image", type="numpy")
+
+             # Sample images
+             if any(Path(p).exists() for p in SAMPLE_IMAGES):
+                 gr.Examples(
+                     examples=[[p] for p in SAMPLE_IMAGES if Path(p).exists()],
+                     inputs=image_input,
+                     label="🖼️ Try Sample Images",
+                     examples_per_page=4,
+                 )
+
+             gr.Markdown("### ⚙️ Model Settings")
+
+             model_dropdown = gr.Dropdown(
+                 choices=POPULAR_MODELS,
+                 value=POPULAR_MODELS[0],
+                 label="🤖 Vision Foundation Model",
+             )
+
+             resolution_slider = gr.Slider(
+                 minimum=64,
+                 maximum=512,
+                 step=64,
+                 value=448,
+                 label="📏 Output Resolution (max dimension)",
+             )
+
+             process_btn = gr.Button("✨ Upsample Features", variant="primary")
+
+         with gr.Column(scale=2):
+             gr.Markdown("### 🎨 Visualization Results")
+             output_image = gr.Image(label="Feature Comparison", type="pil")
+
+             gr.Markdown(
+                 """
+                 <div style="background: #f0f7ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #667eea;">
+                     <strong>📊 Visualization Guide:</strong>
+                     <ul style="margin: 0.5rem 0;">
+                         <li><strong>Left:</strong> Original input image</li>
+                         <li><strong>Center:</strong> Low-resolution features (PCA visualization)</li>
+                         <li><strong>Right:</strong> High-resolution features upsampled by NAF</li>
+                     </ul>
+                     <p style="margin-top: 0.5rem; font-size: 0.9rem; color: #555;">
+                         <em>Note: Output features maintain the aspect ratio of the input image.</em>
+                     </p>
+                 </div>
+                 """
+             )
+
+     process_btn.click(fn=process_image, inputs=[image_input, model_dropdown, resolution_slider], outputs=output_image)
+
+     gr.Markdown(
+         """
+         ---
+         <div style="text-align: center; padding: 2rem 0;">
+             <h3 style="color: #667eea;">💡 About NAF</h3>
+             <p style="max-width: 800px; margin: 1rem auto; font-size: 1.05rem; color: #555;">
+                 NAF enables <strong>zero-shot feature upsampling</strong> from any Vision Foundation Model
+                 to any resolution. It learns to filter and combine features using neighborhood attention,
+                 without requiring model-specific training.
+             </p>
+             <div style="margin-top: 1.5rem;">
+                 <a href="https://github.com/valeoai/NAF" target="_blank"
+                    style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
+                     📦 GitHub Repository
+                 </a>
+                 <a href="https://arxiv.org/abs/2501.01535" target="_blank"
+                    style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
+                     📄 Research Paper
+                 </a>
+             </div>
+         </div>
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
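For a quick smoke test of `process_image` without opening the UI, something along these lines should work (a sketch: the sample path and resolution are placeholders, and importing `app` loads the NAF model, so the dependencies from `requirements.txt` must be installed).

```python
# Sketch: call the Gradio handler directly for a local smoke test.
import numpy as np
import PIL.Image

from app import POPULAR_MODELS, process_image

image = np.array(PIL.Image.open("asset/Driving.jpg").convert("RGB"))
result = process_image(image, POPULAR_MODELS[0], output_resolution=256)
if result is not None:
    result.save("naf_demo_output.png")  # the three-panel comparison figure
```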
deploy_to_hf.sh ADDED
@@ -0,0 +1,109 @@
+ #!/bin/bash
+
+ # Deploy NAF Demo to Hugging Face Spaces
+ # Usage: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME
+
+ set -e
+
+ if [ "$#" -ne 2 ]; then
+     echo "Usage: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME"
+     echo "Example: ./deploy_to_hf.sh myusername naf-demo"
+     exit 1
+ fi
+
+ USERNAME=$1
+ SPACE_NAME=$2
+ SPACE_URL="https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}"
+
+ echo "🚀 Deploying NAF Demo to Hugging Face Spaces"
+ echo "Space URL will be: ${SPACE_URL}"
+ echo ""
+
+ # Check that git-lfs is installed, then initialize its hooks
+ if ! command -v git-lfs &> /dev/null; then
+     echo "⚠️  git-lfs is not installed. Please install it first, then re-run this script."
+     exit 1
+ fi
+ git lfs install
+
+ # Create temporary directory
+ TEMP_DIR=$(mktemp -d)
+ echo "📁 Created temporary directory: ${TEMP_DIR}"
+
+ # Clone the space
+ echo "📥 Cloning space repository..."
+ git clone "https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}" "${TEMP_DIR}"
+ cd "${TEMP_DIR}"
+
+ # Copy main files
+ echo "📋 Copying files..."
+ cp /home/lchambon/workspace/NAF_off/app.py .
+ cp /home/lchambon/workspace/NAF_off/requirements_demo.txt ./requirements.txt
+ cp /home/lchambon/workspace/NAF_off/README_SPACE.md ./README.md
+
+ # Create directory structure
+ mkdir -p src/backbone utils asset
+
+ # Copy source files
+ cp /home/lchambon/workspace/NAF_off/src/backbone/vit_wrapper.py src/backbone/
+ cp /home/lchambon/workspace/NAF_off/utils/visualization.py utils/
+ cp /home/lchambon/workspace/NAF_off/utils/training.py utils/
+
+ # Copy utils/__init__.py if it exists
+ if [ -f /home/lchambon/workspace/NAF_off/utils/__init__.py ]; then
+     cp /home/lchambon/workspace/NAF_off/utils/__init__.py utils/
+ else
+     touch utils/__init__.py
+ fi
+
+ # Copy src/__init__.py if it exists
+ if [ -f /home/lchambon/workspace/NAF_off/src/__init__.py ]; then
+     cp /home/lchambon/workspace/NAF_off/src/__init__.py src/
+ else
+     touch src/__init__.py
+ fi
+
+ # Copy src/backbone/__init__.py if it exists
+ if [ -f /home/lchambon/workspace/NAF_off/src/backbone/__init__.py ]; then
+     cp /home/lchambon/workspace/NAF_off/src/backbone/__init__.py src/backbone/
+ else
+     touch src/backbone/__init__.py
+ fi
+
+ # Copy sample images if they exist
+ echo "🖼️ Copying sample images..."
+ if [ -d /home/lchambon/workspace/NAF_off/asset ]; then
+     cp /home/lchambon/workspace/NAF_off/asset/*.png asset/ 2>/dev/null || true
+     cp /home/lchambon/workspace/NAF_off/asset/*.jpg asset/ 2>/dev/null || true
+ fi
+
+ # Add .gitignore (quoted delimiter so nothing inside the heredoc is expanded)
+ cat > .gitignore << 'EOF'
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ .env
+ .venv
+ *.egg-info/
+ .DS_Store
+ EOF
+
+ # Git operations
+ echo "📤 Pushing to Hugging Face..."
+ git add .
+ git commit -m "Deploy NAF Feature Upsampling Demo"
+ git push
+
+ echo ""
+ echo "✅ Deployment complete!"
+ echo "🌐 Your Space will be available at: ${SPACE_URL}"
+ echo "⏳ It may take a few minutes to build..."
+ echo ""
+ echo "To check build status:"
+ echo "  - Visit: ${SPACE_URL}"
+ echo "  - Click on the 'Logs' tab"
+
+ # Cleanup
+ cd -
+ rm -rf "${TEMP_DIR}"
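As an alternative to the git-based flow in `deploy_to_hf.sh`, the same upload can be done with the `huggingface_hub` Python API. A rough sketch (the repo id is a placeholder, and it assumes you are already authenticated, e.g. via `huggingface-cli login`):

```python
# Sketch: programmatic upload of a prepared demo folder to a Space (repo id is a placeholder).
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("YOUR_USERNAME/naf-feature-upsampling", repo_type="space",
                space_sdk="gradio", exist_ok=True)
api.upload_folder(
    folder_path=".",  # directory laid out as in deploy_to_hf.sh
    repo_id="YOUR_USERNAME/naf-feature-upsampling",
    repo_type="space",
    commit_message="Deploy NAF Feature Upsampling Demo",
)
```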
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ einops==0.8.0
+ numpy==1.24.4
+ timm==1.0.22
+ plotly==6.0.0
+ tensorboard==2.20.0
+ hydra-core==1.3.2
+ matplotlib==3.7.0
+ rich==14.2.0
+ torchmetrics==1.6.2
+ scipy==1.15.2
+ kornia==0.8.2
+ ipykernel
+ ipympl
+ pytest
+
+ # Torch + CUDA 11.8 (Hugging Face compatible install)
+ torch==2.4.0
+ torchvision==0.19.0
+ --extra-index-url https://download.pytorch.org/whl/cu118
+
+ # NATTEN wheels (--find-links must be on its own line in a requirements file)
+ --find-links https://shi-labs.com/natten/wheels/
+ natten==0.17.4+torch240cu118
src/backbone/vit_wrapper.py ADDED
@@ -0,0 +1,180 @@
+ import re
+ import types
+ from typing import List, Tuple, Union
+
+ import timm
+ import timm.data
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from timm.models.vision_transformer import VisionTransformer
+ from torch import nn
+ from torchvision import transforms
+
+ # We provide a list of timm model names, more are available on their official repo.
+ MODEL_LIST = [
+     # DINO
+     "vit_base_patch16_224.dino",
+     # DINOv2
+     "vit_base_patch14_dinov2.lvd142m",
+     # DINOv2-R
+     "vit_base_patch14_reg4_dinov2",
+     # Franca
+     "franca_vitb14",
+     # DINOv3-ViT
+     "vit_base_patch16_dinov3.lvd1689m",
+     "vit_large_patch16_dinov3.lvd1689m",
+     "vit_7b_patch16_dinov3.lvd1689m",
+     # SigLIP2
+     "vit_base_patch16_siglip_512.v2_webli",
+     # PE Core
+     "vit_pe_core_small_patch16_384.fb",
+     # PE Spatial
+     "vit_pe_spatial_tiny_patch16_512.fb",
+     # RADIO
+     "radio_v2.5-b",
+     # CAPI
+     "capi_vitl14_lvd",
+     # MAE
+     "vit_large_patch16_224.mae",
+ ]
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+ class PretrainedViTWrapper(nn.Module):
+
+     def __init__(
+         self,
+         name,
+         norm: bool = True,
+         dynamic_img_size: bool = True,
+         dynamic_img_pad: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         self.name = name
+
+         # Optional prefixes select externally fine-tuned weights (DVT / FiT3D checkpoints)
+         load_weights = False
+         if "dvt_" == name[:4]:
+             load_weights = True
+             load_tag = "dvt"
+             name = name.replace("dvt_", "")
+         if "fit3d_" == name[:6]:
+             load_weights = True
+             load_tag = "fit3d"
+             name = name.replace("fit3d_", "")
+
+         # Set patch size from the model name, with fallbacks for special families
+         try:
+             self.patch_size = int(re.search(r"patch(\d+)", name).group(1))
+         except AttributeError:
+             self.patch_size = 16
+         if "franca" in name or "capi" in name:
+             self.patch_size = 14
+         if "convnext" in name:
+             self.patch_size = 32
+
+         self.dynamic_img_size = dynamic_img_size
+         self.dynamic_img_pad = dynamic_img_pad
+         self.model, self.config = self.create_model(name, **kwargs)
+         self.config["ps"] = self.patch_size
+         self.embed_dim = self.model.embed_dim
+         self.norm = norm
+
+         if load_weights:
+             ckpt = torch.load(f"/home/lchambon/workspace/JAFAR/ckpts/{load_tag}_{name}.pth", map_location="cpu")
+             if load_tag == "dvt":
+                 self.load_state_dict(ckpt["model"], strict=True)
+             elif load_tag == "fit3d":
+                 self.model.load_state_dict(ckpt, strict=True)
+
+     def create_model(self, name: str, **kwargs) -> Tuple[VisionTransformer, transforms.Compose]:
+         if "radio" in self.name:
+             model = torch.hub.load(
+                 "NVlabs/RADIO",
+                 "radio_model",
+                 version=name,
+                 progress=True,
+                 skip_validation=True,
+             )
+             data_config = {
+                 "mean": torch.tensor([0.0, 0.0, 0.0]),
+                 "std": torch.tensor([1.0, 1.0, 1.0]),
+                 "input_size": (3, 512, 512),
+             }
+
+         elif "franca" in self.name:
+             model = torch.hub.load("valeoai/Franca", name, use_rasa_head=True)
+             data_config = {"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "input_size": (3, 448, 448)}
+
+         elif "capi" in self.name:
+             model = torch.hub.load("facebookresearch/capi:main", name, force_reload=False)
+             data_config = {"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "input_size": (3, 448, 448)}
+
+         else:
+             timm_kwargs = dict(
+                 pretrained=True,
+                 num_classes=0,
+                 patch_size=self.patch_size,
+             )
+
+             if "sam" not in self.name and "convnext" not in self.name:
+                 timm_kwargs["dynamic_img_size"] = self.dynamic_img_size
+                 timm_kwargs["dynamic_img_pad"] = self.dynamic_img_pad
+
+             timm_kwargs.update(kwargs)
+             model = timm.create_model(name, **timm_kwargs)
+             data_config = timm.data.resolve_model_data_config(model=model)
+
+         model = model.eval()
+
+         return model, data_config
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         n: Union[int, List[int], Tuple[int]] = 1,
+         return_prefix_tokens: bool = False,
+     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+         """Intermediate layer accessor inspired by the DINO / DINOv2 interface.
+
+         Args:
+             x: Input tensor.
+             n: Take the last n blocks if int, all if None, select matching indices if a sequence.
+             return_prefix_tokens: Whether to also request prefix (CLS/register) tokens from timm models.
+         """
+         common_kwargs = dict(
+             norm=self.norm,
+             output_fmt="NCHW",
+             intermediates_only=True,
+         )
+
+         if "sam" not in self.name and return_prefix_tokens:
+             common_kwargs["return_prefix_tokens"] = return_prefix_tokens
+
+         if "franca" in self.name:
+             B, C, H, W = x.shape
+             feats = self.model.forward_features(x, use_rasa_head=True)
+             out = feats["patch_token_rasa"]
+             out = rearrange(out, "b (h w) c -> b c h w", h=H // self.patch_size, w=W // self.patch_size)
+
+         elif "capi" in self.name:
+             *_, out = self.model(x)
+             out = out.permute(0, 3, 1, 2)
+
+         else:
+             out = self.model.forward_intermediates(x, n, **common_kwargs)
+
+         # "sam" models return feats only, others may return (feats, prefix)
+         if not isinstance(out, list) and not isinstance(out, tuple):
+             out = [out]
+             return out[0]
+         else:
+             assert len(out) == 1, f"Out contains {len(out)} elements, expected 1."
+             return out[0]
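A minimal usage sketch for `PretrainedViTWrapper` (the model name comes from `MODEL_LIST` above; the input size is a placeholder and must be a multiple of the patch size):

```python
# Sketch: extract [B, C, H/ps, W/ps] features with the wrapper defined above.
import torch

from src.backbone.vit_wrapper import PretrainedViTWrapper

backbone = PretrainedViTWrapper("vit_base_patch16_224.dino", norm=True).eval()
x = torch.randn(1, 3, 224, 224)  # H and W must be multiples of backbone.patch_size
with torch.no_grad():
    feats = backbone(x)
print(feats.shape)  # e.g. torch.Size([1, 768, 14, 14]) for a ViT-B/16 at 224x224
```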
utils/training.py ADDED
@@ -0,0 +1,231 @@
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ import torchvision.transforms as T
+ from hydra.utils import instantiate
+ from omegaconf import ListConfig
+ from torch.utils.tensorboard import SummaryWriter
+ from torchvision.transforms.functional import InterpolationMode
+
+ from src.backbone.vit_wrapper import PretrainedViTWrapper
+ from utils.img import PILToTensor
+
+
+ def seed_worker():
+     worker_seed = torch.initial_seed() % 2**32
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
+
+
+ def round_to_nearest_multiple(value, multiple=14):
+     return multiple * round(value / multiple)
+
+
+ def compute_feats(cfg, backbone, image_batch, min_rescale=0.60, max_rescale=0.25):
+     _, _, H, W = image_batch.shape  # Get original height and width
+
+     with torch.no_grad():
+         hr_feats = backbone(image_batch)
+
+         if cfg.get("lr_img_size", None) is not None:
+             size = (cfg.lr_img_size, cfg.lr_img_size)
+         else:
+             # Downscale
+             if cfg.down_factor == "random":
+                 downscale_factor = np.random.uniform(min_rescale, max_rescale)
+             elif cfg.down_factor == "fixed":
+                 downscale_factor = 0.5
+
+             new_H = round_to_nearest_multiple(H * downscale_factor, backbone.patch_size)
+             new_W = round_to_nearest_multiple(W * downscale_factor, backbone.patch_size)
+             size = (new_H, new_W)
+         low_res_batch = F.interpolate(image_batch, size=size, mode="bilinear")
+         lr_feats = backbone(low_res_batch)
+
+     return hr_feats, lr_feats
+
+
+ def logger(args, base_log_dir):
+     os.makedirs(base_log_dir, exist_ok=True)
+     existing_versions = [
+         int(d.split("_")[-1])
+         for d in os.listdir(base_log_dir)
+         if os.path.isdir(os.path.join(base_log_dir, d)) and d.startswith("version_")
+     ]
+     new_version = max(existing_versions, default=-1) + 1
+     new_log_dir = os.path.join(base_log_dir, f"version_{new_version}")
+
+     # Create the SummaryWriter with the new log directory
+     writer = SummaryWriter(log_dir=new_log_dir)
+     return writer, new_version, new_log_dir
+
+
+ def get_dataloaders(cfg, shuffle=True):
+     """Get dataloaders for either training or evaluation.
+
+     Args:
+         cfg: Configuration object
+         shuffle: Whether to shuffle the training data
+     """
+     transforms = {
+         "image": T.Compose(
+             [
+                 T.Resize(cfg.img_size, interpolation=InterpolationMode.BILINEAR),
+                 T.CenterCrop((cfg.img_size, cfg.img_size)),
+                 T.ToTensor(),
+             ]
+         )
+     }
+
+     transforms["label"] = T.Compose(
+         [
+             # T.ToTensor(),
+             T.Resize(cfg.target_size, interpolation=InterpolationMode.NEAREST_EXACT),
+             T.CenterCrop((cfg.target_size, cfg.target_size)),
+             PILToTensor(),
+         ]
+     )
+     train_dataset = cfg.dataset
+     val_dataset = cfg.dataset.copy()
+     if hasattr(val_dataset, "split"):
+         val_dataset.split = "val"
+
+     train_dataset = instantiate(
+         train_dataset,
+         transform=transforms["image"],
+         target_transform=transforms["label"],
+     )
+     val_dataset = instantiate(
+         val_dataset,
+         transform=transforms["image"],
+         target_transform=transforms["label"],
+     )
+
+     # Create generator for reproducibility
+     if not shuffle:
+         g = torch.Generator()
+         g.manual_seed(0)
+     else:
+         g = None
+
+     # Prepare dataloader configs - set worker_init_fn to None when shuffling for randomness
+     train_dataloader_cfg = cfg.train_dataloader.copy()
+     val_dataloader_cfg = cfg.val_dataloader.copy()
+
+     if shuffle:
+         # Set worker_init_fn to None to allow true randomness when shuffling
+         if "worker_init_fn" in train_dataloader_cfg:
+             train_dataloader_cfg["worker_init_fn"] = None
+         if "worker_init_fn" in val_dataloader_cfg:
+             val_dataloader_cfg["worker_init_fn"] = None
+
+     return (
+         instantiate(train_dataloader_cfg, dataset=train_dataset, generator=g),
+         instantiate(val_dataloader_cfg, dataset=val_dataset, generator=g),
+     )
+
+
+ def get_batch(batch, device):
+     """Process batch and return required tensors."""
+     batch["image"] = batch["image"].to(device)
+     return batch
+
+
+ def setup_training_optimizations(model, cfg):
+     """
+     Setup training optimizations based on configuration.
+
+     Args:
+         model: The model to apply optimizations to
+         cfg: Configuration object with use_bf16 and use_checkpointing flags
+
+     Returns:
+         tuple: (scaler, use_bf16, use_checkpointing) for use in the training loop
+     """
+     # Get configuration values with defaults
+     use_bf16 = getattr(cfg, "use_bf16", False)
+     use_checkpointing = getattr(cfg, "use_checkpointing", False)
+
+     # Initialize gradient scaler for mixed precision
+     scaler = torch.amp.GradScaler("cuda", enabled=use_bf16)
+
+     # Enable gradient checkpointing if requested
+     if use_checkpointing:
+         if hasattr(model, "gradient_checkpointing_enable"):
+             model.gradient_checkpointing_enable()
+             print("  ✓ Using built-in gradient checkpointing")
+         else:
+             # For custom models, wrap forward methods
+             def checkpoint_wrapper(module):
+                 if hasattr(module, "forward"):
+                     original_forward = module.forward
+
+                     def checkpointed_forward(*args, **kwargs):
+                         return checkpoint.checkpoint(original_forward, *args, **kwargs)
+
+                     module.forward = checkpointed_forward
+
+             # Apply to key modules (adjust based on your model structure)
+             checkpointed_modules = []
+             for name, module in model.named_modules():
+                 if any(key in name for key in ["cross_decode", "encoder", "sft"]):
+                     checkpoint_wrapper(module)
+                     checkpointed_modules.append(name)
+
+             if checkpointed_modules:
+                 print(f"  ✓ Applied custom gradient checkpointing to: {checkpointed_modules}")
+             else:
+                 print("  ⚠ No modules found for gradient checkpointing")
+
+     print("Training optimizations:")
+     print(f"  Mixed precision (bfloat16): {use_bf16}")
+     print(f"  Gradient checkpointing: {use_checkpointing}")
+
+     return scaler, use_bf16, use_checkpointing
+
+
+ def load_multiple_backbones(cfg, backbone_configs, device):
+     """
+     Load multiple backbone models based on configuration.
+
+     Args:
+         cfg: Hydra configuration object
+         backbone_configs: One backbone config or a list of backbone configs
+         device: PyTorch device to load models on
+
+     Returns:
+         tuple: (backbones, backbone_names, backbone_img_sizes)
+             - backbones: List of loaded backbone models
+             - backbone_names: List of backbone names
+             - backbone_img_sizes: List of (H, W) input sizes from each backbone config
+     """
+     backbones = []
+     backbone_names = []
+     backbone_img_sizes = []
+
+     if not isinstance(backbone_configs, list) and not isinstance(backbone_configs, ListConfig):
+         backbone_configs = [backbone_configs]
+     print(f"Loading {len(backbone_configs)} backbone(s)...")
+
+     for i, backbone_config in enumerate(backbone_configs):
+         name = backbone_config["name"]
+         if name == "rgb":
+             backbone = instantiate(cfg.backbone)
+         else:
+             backbone = PretrainedViTWrapper(name=name)
+         print(f"  [{i}] Loaded {backbone_config['name']}")
+
+         # Move to device and set to eval mode
+         backbone = backbone.to(device)
+         backbone.eval()  # Set to eval mode for feature extraction
+
+         # Store backbone and name
+         backbones.append(backbone)
+         backbone_names.append(backbone_config["name"])
+         backbone_img_sizes.append(backbone.config["input_size"][1:])
+
+     return backbones, backbone_names, backbone_img_sizes
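`round_to_nearest_multiple` is the only helper that `app.py` imports from this file; a small sketch of how it snaps resolutions onto the backbone's patch grid (the sizes are illustrative):

```python
# Sketch: snapping arbitrary sizes to the patch grid, as app.py does before inference.
from utils.training import round_to_nearest_multiple

patch_size = 14
for size in (500, 512, 448):
    print(size, "->", round_to_nearest_multiple(size, patch_size))
# 500 -> 504, 512 -> 518, 448 -> 448
```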
utils/visualization.py ADDED
@@ -0,0 +1,190 @@
+ # Visualization code from https://github.com/Tsingularity/dift/blob/main/src/utils/visualization.py
+
+ import io
+ from pathlib import Path
+
+ import matplotlib.colors as mcolors
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from PIL import Image
+
+ FONT_SIZE = 40
+
+
+ @torch.no_grad()
+ def plot_feats(
+     image,
+     target,
+     pred,
+     legend=["Image", "HR Features", "Pred Features"],
+     save_path=None,
+     return_array=False,
+     show_legend=True,
+     font_size=FONT_SIZE,
+ ):
+     """
+     Create a plot_feats visualization.
+     """
+     # Ensure pred is a list
+     if not isinstance(pred, list):
+         pred = [pred]
+
+     # Prepare inputs for PCA
+     feats_for_pca = [target.unsqueeze(0)] + [_.unsqueeze(0) for _ in pred]
+     reduced_feats, _ = pca(feats_for_pca)  # pca outputs a list of reduced tensors
+
+     target_imgs = reduced_feats[0]
+     pred_imgs = reduced_feats[1:]
+
+     # --- Plot ---
+     # Determine number of columns based on whether image is provided
+     n_cols = (1 if image is not None else 0) + 1 + len(pred_imgs)
+     fig, ax = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))
+     # Reduce space between images
+     plt.subplots_adjust(wspace=0.05, hspace=0.05)
+
+     # Handle single subplot case
+     if n_cols == 1:
+         ax = np.array([ax])
+
+     # Current axis index
+     ax_idx = 0
+
+     # Plot original image if provided
+     if image is not None:
+         if image.dim() == 3:
+             ax[ax_idx].imshow(image.permute(1, 2, 0).detach().cpu())
+         elif image.dim() == 2:
+             ax[ax_idx].imshow(image.detach().cpu(), cmap="inferno")
+         if show_legend:
+             ax[ax_idx].set_title(legend[0], fontsize=font_size)
+         ax_idx += 1
+
+     # Plot the low-resolution features or segmentation mask
+     ax[ax_idx].imshow(target_imgs[0].permute(1, 2, 0).detach().cpu())
+     if show_legend:
+         legend_idx = 1 if image is not None else 0
+         ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
+     ax_idx += 1
+
+     # Plot HR features or segmentation masks
+     for idx, pred_img in enumerate(pred_imgs):
+         ax[ax_idx].imshow(pred_img[0].permute(1, 2, 0).detach().cpu())
+         if show_legend:
+             legend_idx = (2 if image is not None else 1) + idx
+             if len(legend) > legend_idx:
+                 ax[ax_idx].set_title(legend[legend_idx], fontsize=font_size)
+             else:
+                 ax[ax_idx].set_title(f"HR Features {idx}", fontsize=font_size)
+         ax_idx += 1
+
+     remove_axes(ax)
+
+     # Handle return_array case
+     if return_array:
+         # Turn off interactive mode temporarily
+         was_interactive = plt.isinteractive()
+         plt.ioff()
+
+         # Convert figure to numpy array
+         buf = io.BytesIO()
+         plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+         buf.seek(0)
+
+         # Convert to PIL Image then to numpy array
+         pil_img = Image.open(buf)
+         img_array = np.array(pil_img)
+
+         # Close the figure and buffer
+         plt.close(fig)
+         buf.close()
+
+         # Restore interactive mode if it was on
+         if was_interactive:
+             plt.ion()
+
+         return img_array
+
+     # Standard behavior: save and/or show
+     if save_path is not None:
+         plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
+     plt.show()
+
+     return None
+
+
+ def remove_axes(axes):
+     def _remove_axes(ax):
+         ax.xaxis.set_major_formatter(plt.NullFormatter())
+         ax.yaxis.set_major_formatter(plt.NullFormatter())
+         ax.set_xticks([])
+         ax.set_yticks([])
+
+     if len(axes.shape) == 2:
+         for ax1 in axes:
+             for ax in ax1:
+                 _remove_axes(ax)
+     else:
+         for ax in axes:
+             _remove_axes(ax)
+
+
+ def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
+     target_size = None
+     if len(image_feats_list) > 1 and fit_pca is None:
+         target_size = image_feats_list[0].shape[2]
+
+     def flatten(tensor, target_size=None):
+         B, C, H, W = tensor.shape
+         assert B == 1, "Batch size should be 1 for PCA flattening"
+         if target_size is not None:
+             tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear", align_corners=False)
+         return rearrange(tensor, "b c h w -> (b h w) c").detach().cpu()
+
+     flattened_feats = []
+     for feats in image_feats_list:
+         flattened_feats.append(flatten(feats, target_size))
+     x = torch.cat(flattened_feats, dim=0)
+
+     # Subsample the data if max_samples is set and the number of samples exceeds max_samples
+     if max_samples is not None and x.shape[0] > max_samples:
+         indices = torch.randperm(x.shape[0])[:max_samples]
+         x = x[indices]
+
+     if fit_pca is None:
+         fit_pca = TorchPCA(n_components=dim).fit(x)
+
+     reduced_feats = []
+     for feats in image_feats_list:
+         B, C, H, W = feats.shape
+         x_red = fit_pca.transform(flatten(feats))
+         if isinstance(x_red, np.ndarray):
+             x_red = torch.from_numpy(x_red)
+         x_red -= x_red.min(dim=0, keepdim=True).values
+         x_red /= x_red.max(dim=0, keepdim=True).values
+         reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2))
+
+     return reduced_feats, fit_pca
+
+
+ class TorchPCA(object):
+
+     def __init__(self, n_components, skip=0):
+         self.n_components = n_components
+         self.skip = skip
+
+     def fit(self, X):
+         self.mean_ = X.mean(dim=0)
+         unbiased = X - self.mean_
+         U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=20)
+         self.components_ = V[:, self.skip :]
+         self.singular_values_ = S
+         return self
+
+     def transform(self, X):
+         t0 = X - self.mean_.unsqueeze(0)
+         projected = t0 @ self.components_
+         return projected
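A minimal sketch exercising `pca` and `plot_feats` on random tensors (the shapes are illustrative; real features come from a backbone, as in `app.py`):

```python
# Sketch: PCA-to-RGB feature visualization with the utilities above.
import torch

from utils.visualization import plot_feats

image = torch.rand(3, 224, 224)        # RGB image in [0, 1], shape [3, H, W]
lr_feats = torch.randn(768, 16, 16)    # low-resolution features [C, h, w]
hr_feats = torch.randn(768, 224, 224)  # upsampled features [C, H, W]

# plot_feats expects unbatched tensors and a list of predictions;
# with return_array=True it returns the rendered figure as a numpy array.
arr = plot_feats(
    image,
    lr_feats,
    [hr_feats],
    legend=["Image", "LR 16x16", "HR 224x224"],
    font_size=14,
    return_array=True,
)
print(arr.shape)  # (height, width, channels) array of the three-panel figure
```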