HarleyCooper committed on
Commit
af7be05
·
verified ·
1 Parent(s): 92f42c9

Sync docs from GitHub

Browse files
.env.template ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # nanochatAquaRat Environment Template
2
+ # Copy this file to .env and fill in actual values before running any scripts.
3
+
4
+ # -----------------------------------------------------------------------------
5
+ # Cloud GPU Providers
6
+ # -----------------------------------------------------------------------------
7
+
8
+ # Lambda Labs automation (scripts/launch_lambda_training.py, launch_lambda.py)
9
+ # https://cloud.lambdalabs.com/api-keys
10
+ LAMBDA_API_KEY=your-lambda-api-key-here
11
+
12
+ # Hyperbolic Labs automation (scripts/launch_hyperbolic_training.py)
13
+ # https://app.hyperbolic.ai/settings/api-keys
14
+ HYPERBOLIC_API_KEY=your-hyperbolic-api-key-here
15
+
16
+ # Optional Hyperbolic defaults (override with CLI args if needed)
17
+ # HYPERBOLIC_REGION=us-east
18
+ # HYPERBOLIC_MAX_PRICE=6.00
19
+
20
+ # -----------------------------------------------------------------------------
21
+ # Experiment Tracking (Weights & Biases)
22
+ # -----------------------------------------------------------------------------
23
+ # Get your key at https://wandb.ai/authorize
24
+ WANDB_API_KEY=your-wandb-api-key-here
25
+ WANDB_PROJECT=nanochat-aquarat
26
+ WANDB_ENTITY=your-wandb-username-or-team-name
27
+
28
+ # Optional run metadata (NOTE: the $(date ...) substitution below only expands
+ # when this file is sourced by a shell; plain dotenv loaders treat it literally)
29
+ WANDB_RUN=aquarat-$(date +%Y%m%d-%H%M%S)
30
+ WANDB_MODE=online # online | offline | disabled
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Google Cloud Storage Uploads (scripts/upload_to_gcs.sh)
34
+ # See GCS_UPLOAD_GUIDE.md for detailed instructions
35
+ # -----------------------------------------------------------------------------
36
+
37
+ # GCP project that owns your storage bucket
38
+ GCP_PROJECT_ID=your-gcp-project-id
39
+
40
+ # Default bucket for model artifacts (used by upload scripts and automation)
41
+ GCS_BUCKET=gs://your-model-bucket
42
+
43
+ # Service account credentials (recommended for automation). Point to the JSON
44
+ # key you download with `gcloud iam service-accounts keys create ...`
45
+ GOOGLE_APPLICATION_CREDENTIALS=/path/to/nanochat-gcs-key.json
46
+
47
+ # -----------------------------------------------------------------------------
48
+ # Optional cache/directories
49
+ # -----------------------------------------------------------------------------
50
+
51
+ # Override the default ~/.cache/nanochat location (used for checkpoints/data)
52
+ # NANOCHAT_BASE_DIR=/mnt/nanochat-cache
53
+
54
+ # If you pre-convert AQuA-RAT with scripts/prepare_aqua.py, set this so tasks
55
+ # and training scripts reuse the cached JSONL splits instead of downloading.
56
+ # AQUA_DATA_DIR=/mnt/datasets/aqua
57
+
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ aquarat2.png filter=lfs diff=lfs merge=lfs -text
GCS_UPLOAD_GUIDE.md ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Google Cloud Storage Upload Guide
2
+
3
+ After training your nanochat model on Lambda Labs, use this guide to upload all weights and artifacts to Google Cloud Storage.
4
+
5
+ ## Quick Start
6
+
7
+ ```bash
8
+ # After training completes, SSH to your Lambda instance
9
+ ssh ubuntu@<INSTANCE_IP>
10
+
11
+ # Navigate to project directory
12
+ cd ~/nanochatAquaRat
13
+
14
+ # Run upload script
15
+ bash scripts/upload_to_gcs.sh --bucket gs://your-bucket-name
16
+ ```
17
+
18
+ The script will:
19
+ 1. Check/install gcloud CLI if needed
20
+ 2. Verify authentication and bucket access
21
+ 3. Show what will be uploaded and ask for confirmation
22
+ 4. Upload all artifacts with progress
23
+ 5. Ask if you want to terminate the Lambda instance
24
+
25
+ ## Prerequisites
26
+
27
+ ### 1. Create a GCS Bucket
28
+
29
+ ```bash
30
+ # From your local machine
31
+ gcloud storage buckets create gs://your-bucket-name \
32
+ --location=us-central1 \
33
+ --uniform-bucket-level-access
34
+ ```
35
+
36
+ Or create via console: https://console.cloud.google.com/storage/create-bucket
37
+
38
+ ### 2. Set Up Authentication
39
+
40
+ #### Option A: Service Account (Recommended for Automation)
41
+
42
+ On your local machine:
43
+
44
+ ```bash
45
+ # Create service account
46
+ gcloud iam service-accounts create nanochat-uploader \
47
+ --display-name="Nanochat Model Uploader"
48
+
49
+ # Grant storage permissions
50
+ gcloud projects add-iam-policy-binding YOUR_PROJECT_ID \
51
+ --member="serviceAccount:nanochat-uploader@YOUR_PROJECT_ID.iam.gserviceaccount.com" \
52
+ --role="roles/storage.objectCreator"
53
+
54
+ # Create and download key
55
+ gcloud iam service-accounts keys create ~/nanochat-key.json \
56
+ --iam-account=nanochat-uploader@YOUR_PROJECT_ID.iam.gserviceaccount.com
57
+ ```
58
+
59
+ Copy key to Lambda instance:
60
+
61
+ ```bash
62
+ scp ~/nanochat-key.json ubuntu@<INSTANCE_IP>:~/
63
+ ```
64
+
65
+ On Lambda instance:
66
+
67
+ ```bash
68
+ gcloud auth activate-service-account --key-file=~/nanochat-key.json
69
+ ```
70
+
71
+ #### Option B: User Account (Simpler for Manual Use)
72
+
73
+ On Lambda instance:
74
+
75
+ ```bash
76
+ gcloud auth login
77
+ # Follow the prompts in your browser
78
+ ```
79
+
80
+ ## Usage
81
+
82
+ ### Basic Upload
83
+
84
+ ```bash
85
+ bash scripts/upload_to_gcs.sh --bucket gs://my-models
86
+ ```
87
+
88
+ ### Custom Run Name
89
+
90
+ ```bash
91
+ bash scripts/upload_to_gcs.sh \
92
+ --bucket gs://my-models \
93
+ --run-name depth20-experiment1
94
+ ```
95
+
96
+ ### Exclude Large Dataset Files
97
+
98
+ ```bash
99
+ bash scripts/upload_to_gcs.sh \
100
+ --bucket gs://my-models \
101
+ --exclude-data
102
+ ```
103
+
104
+ ### Dry Run (Preview Only)
105
+
106
+ ```bash
107
+ bash scripts/upload_to_gcs.sh \
108
+ --bucket gs://my-models \
109
+ --dry-run
110
+ ```
111
+
112
+ ### Auto-Terminate After Upload
113
+
114
+ ```bash
115
+ bash scripts/upload_to_gcs.sh \
116
+ --bucket gs://my-models \
117
+ --auto-terminate
118
+ ```
119
+
120
+ ## What Gets Uploaded
121
+
122
+ From `~/.cache/nanochat/`:
123
+
124
+ | Directory | Contents | Typical Size |
125
+ |-----------|----------|--------------|
126
+ | `checkpoints/` | Model weights (.pt, .pkl files) | 500MB - 2GB |
127
+ | `report/` | Training reports and markdown summaries | 1-10MB |
128
+ | `tokenizer/` | BPE tokenizer files | 10-50MB |
129
+ | `eval_bundle/` | Evaluation datasets | 50-200MB |
130
+ | `aqua/` | AQuA-RAT dataset (optional) | 100-500MB |
131
+ | `mechanistic_interpretability/` | DeepMind interp tools | 10-100MB |
132
+
133
+ **Total**: Typically 1-5 GB per training run
134
+
135
+ ## Upload Structure
136
+
137
+ Files are organized in GCS as:
138
+
139
+ ```
140
+ gs://your-bucket/
141
+ └── runs/
142
+ ├── aquarat-20251023-143022/
143
+ │ ├── checkpoints/
144
+ │ │ ├── base_final.pt
145
+ │ │ ├── mid_final.pt
146
+ │ │ ├── sft_final.pt
147
+ │ │ └── rl_final.pt
148
+ │ ├── report/
149
+ │ │ └── report.md
150
+ │ ├── tokenizer/
151
+ │ └── ...
152
+ └── depth20-experiment1/
153
+ └── ...
154
+ ```
155
+
156
+ ## Download Weights Later
157
+
158
+ ### Download Entire Run
159
+
160
+ ```bash
161
+ gsutil -m rsync -r \
162
+ gs://your-bucket/runs/aquarat-20251023-143022/ \
163
+ ./local_checkpoints/
164
+ ```
165
+
166
+ ### Download Just Checkpoints
167
+
168
+ ```bash
169
+ gsutil -m cp -r \
170
+ gs://your-bucket/runs/aquarat-20251023-143022/checkpoints/ \
171
+ ./checkpoints/
172
+ ```
173
+
174
+ ### Download Single File
175
+
176
+ ```bash
177
+ gsutil cp \
178
+ gs://your-bucket/runs/aquarat-20251023-143022/checkpoints/rl_final.pt \
179
+ ./rl_final.pt
180
+ ```
181
+
182
+ ## Cost Considerations
183
+
184
+ ### Storage Costs
185
+
186
+ - Standard storage: ~$0.02/GB/month
187
+ - Nearline storage (30+ days): ~$0.01/GB/month
188
+ - Coldline storage (90+ days): ~$0.004/GB/month
189
+
190
+ **Example**: 2GB model stored for 1 month = $0.04
191
+
192
+ ### Network Egress
193
+
194
+ - Upload (ingress): **Free**
195
+ - Download to same region: **Free**
196
+ - Download to internet: ~$0.12/GB
197
+
198
+ **Tip**: Keep your GCS bucket in the same region as your compute for free transfers.
199
+
200
+ ### Lifecycle Management
201
+
202
+ Auto-delete or move to cheaper storage after 90 days:
203
+
204
+ ```bash
205
+ cat > lifecycle.json << EOF
206
+ {
207
+ "lifecycle": {
208
+ "rule": [
209
+ {
210
+ "action": {"type": "Delete"},
211
+ "condition": {"age": 90}
212
+ }
213
+ ]
214
+ }
215
+ }
216
+ EOF
217
+
218
+ gsutil lifecycle set lifecycle.json gs://your-bucket
219
+ ```
220
+
221
+ ## Troubleshooting
222
+
223
+ ### "gcloud: command not found"
224
+
225
+ The script auto-installs gcloud on Linux. If it fails:
226
+
227
+ ```bash
228
+ curl https://sdk.cloud.google.com | bash
229
+ exec -l $SHELL
230
+ ```
231
+
232
+ ### "Permission denied" Error
233
+
234
+ Check your service account has `roles/storage.objectCreator`:
235
+
236
+ ```bash
237
+ gcloud projects get-iam-policy YOUR_PROJECT_ID \
238
+ --flatten="bindings[].members" \
239
+ --filter="bindings.members:serviceAccount:nanochat-uploader*"
240
+ ```
241
+
242
+ ### Upload Interrupted
243
+
244
+ The script uses `gsutil rsync`, so re-running will resume:
245
+
246
+ ```bash
247
+ bash scripts/upload_to_gcs.sh --bucket gs://your-bucket
248
+ # Will skip already-uploaded files
249
+ ```
250
+
251
+ ### Verify Upload
252
+
253
+ ```bash
254
+ # List all files in the run
255
+ gsutil ls -r gs://your-bucket/runs/your-run-name/
256
+
257
+ # Check specific checkpoints
258
+ gsutil ls gs://your-bucket/runs/your-run-name/checkpoints/
259
+ ```
260
+
261
+ ## Integration with Lambda Launcher
262
+
263
+ You can add GCS credentials to the automated launcher:
264
+
265
+ ```python
266
+ # In scripts/launch_lambda_training.py
267
+ # Add to the cloud-init user-data:
268
+
269
+ write_files:
270
+ - path: /home/ubuntu/.config/gcloud/application_default_credentials.json
271
+ content: |
272
+ {your service account key JSON}
273
+ ```
274
+
275
+ Or pass as environment variable:
276
+
277
+ ```bash
278
+ export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
279
+
280
+ python scripts/launch_lambda_training.py \
281
+ --inject-env GOOGLE_APPLICATION_CREDENTIALS \
282
+ ...
283
+ ```
284
+
285
+ ## Best Practices
286
+
287
+ 1. **Name runs descriptively**: Use `--run-name depth20-lr1e4-batch32`
288
+ 2. **Exclude data when iterating**: Use `--exclude-data` to save bandwidth
289
+ 3. **Dry run first**: Always use `--dry-run` to preview
290
+ 4. **Service accounts for automation**: Easier than user auth
291
+ 5. **Regional buckets**: Match Lambda instance region when possible
292
+ 6. **Lifecycle policies**: Auto-archive old models
293
+ 7. **Download to Lambda**: If re-training, download previous checkpoints to Lambda first
294
+
295
+ ## Security Notes
296
+
297
+ - Service account keys are sensitive - treat like passwords
298
+ - Use least-privilege IAM roles (don't grant `roles/owner`)
299
+ - Rotate service account keys regularly
300
+ - Consider Workload Identity if using GKE
301
+ - Don't commit keys to git (add to `.gitignore`)
302
+
303
+ ## Support
304
+
305
+ - GCS Documentation: https://cloud.google.com/storage/docs
306
+ - gsutil Reference: https://cloud.google.com/storage/docs/gsutil
307
+ - IAM Permissions: https://cloud.google.com/storage/docs/access-control/iam-permissions
308
+
309
+ ---
310
+
311
+ **Quick Reference**:
312
+ ```bash
313
+ # Upload
314
+ bash scripts/upload_to_gcs.sh --bucket gs://my-bucket
315
+
316
+ # Download
317
+ gsutil -m cp -r gs://my-bucket/runs/NAME/checkpoints/ ./
318
+
319
+ # List runs
320
+ gsutil ls gs://my-bucket/runs/
321
+
322
+ # Delete old run
323
+ gsutil -m rm -r gs://my-bucket/runs/old-run/
LAMBDA_MANUAL_SETUP.md ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Manual Setup Guide for Lambda Labs
2
+
3
+ This guide walks you through manually launching and configuring a Lambda Labs GPU instance for training the nanochatAquaRat model with RL on AQuA-RAT.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. **Lambda Labs Account**: Sign up at https://cloud.lambdalabs.com
8
+ 2. **SSH Key**: Add your SSH public key to Lambda Labs
9
+ 3. **W&B Account**: Sign up at https://wandb.ai and get your API key
10
+ 4. **Sufficient Credits**: Ensure you have enough credits (~$24/hour on 8x H100, so roughly $192 for an 8-hour run)
11
+
12
+ ## Step 1: Add SSH Key to Lambda Labs
13
+
14
+ 1. Go to https://cloud.lambdalabs.com/ssh-keys
15
+ 2. Click "Add SSH Key"
16
+ 3. Paste your public SSH key (from `~/.ssh/id_rsa.pub` or `~/.ssh/id_ed25519.pub`)
17
+ 4. Give it a name (e.g., "my-laptop")
18
+ 5. Click "Add SSH Key"
19
+
20
+ ## Step 2: Launch Instance via Web Dashboard
21
+
22
+ 1. Navigate to https://cloud.lambdalabs.com/instances
23
+ 2. Click **"Launch instance"**
24
+ 3. Configure your instance:
25
+ - **Instance type**: Select `gpu_8x_h100_sxm5` (8x NVIDIA H100 80GB SXM5)
26
+ - For testing: Use `gpu_1x_a10` or smaller
27
+ - **Region**: Choose a region with availability (e.g., `us-west-1`)
28
+ - **SSH Keys**: Select your SSH key
29
+ - **Filesystem**: (Optional) If you have persistent storage
30
+ 4. Click **"Launch instance"**
31
+ 5. Wait 1-2 minutes for the instance to boot
32
+
33
+ ## Step 3: Note Instance Details
34
+
35
+ Once the instance is running, note:
36
+ - **Instance ID**: (e.g., `0123456789abcdef`)
37
+ - **IP Address**: (e.g., `123.45.67.89`)
38
+ - **SSH Command**: Shown in the web interface
39
+
40
+ ## Step 4: Connect to Instance
41
+
42
+ Open your terminal and connect:
43
+
44
+ ```bash
45
+ ssh ubuntu@<INSTANCE_IP>
46
+ ```
47
+
48
+ Example:
49
+ ```bash
50
+ ssh ubuntu@123.45.67.89
51
+ ```
52
+
53
+ ## Step 5: Set Up Environment
54
+
55
+ Once connected, run these commands:
56
+
57
+ ### 5.1 Create Environment File
58
+
59
+ ```bash
60
+ # Create .env file with your credentials
61
+ cat > ~/.env << 'EOF'
62
+ WANDB_API_KEY=your-wandb-api-key-here
63
+ WANDB_PROJECT=nanochat-aquarat
64
+ WANDB_ENTITY=your-wandb-username-or-team
65
+ EOF
66
+ ```
67
+
68
+ Replace `your-wandb-api-key-here` with your actual W&B API key (get it from https://wandb.ai/authorize)
69
+
70
+ ### 5.2 Clone Repository
71
+
72
+ ```bash
73
+ cd ~
74
+ git clone https://github.com/HarleyCoops/nanochatAquaRat.git
75
+ cd nanochatAquaRat
76
+ ```
77
+
78
+ ### 5.3 Copy Environment Variables
79
+
80
+ ```bash
81
+ # Copy the .env file to the project directory
82
+ cp ~/.env .env
83
+ ```
84
+
85
+ ## Step 6: Start Training
86
+
87
+ You have two options for running the training:
88
+
89
+ ### Option A: Run in Screen Session (Recommended)
90
+
91
+ This allows you to detach and let training continue even if you disconnect:
92
+
93
+ ```bash
94
+ # Start a screen session
95
+ screen -S training
96
+
97
+ # Run the training script
98
+ bash run_aquarat_small.sh
99
+ ```
100
+
101
+ **Screen Commands:**
102
+ - **Detach from screen**: Press `Ctrl+A` then `D`
103
+ - **Reattach to screen**: `screen -r training`
104
+ - **List all screen sessions**: `screen -ls`
105
+ - **Kill a screen session**: `screen -X -S training quit`
106
+
107
+ ### Option B: Run Directly (Blocks Terminal)
108
+
109
+ ```bash
110
+ # Run training directly (terminal will be blocked)
111
+ bash run_aquarat_small.sh 2>&1 | tee training.log
112
+ ```
113
+
114
+ This saves output to `training.log` for later review.
115
+
116
+ ## Step 7: Monitor Training
117
+
118
+ ### Monitor via Terminal
119
+
120
+ If using screen:
121
+ ```bash
122
+ # Reattach to see live output
123
+ screen -r training
124
+
125
+ # Or tail the report
126
+ tail -f ~/.cache/nanochat/report/report.md
127
+ ```
128
+
129
+ ### Monitor via Weights & Biases
130
+
131
+ 1. Go to https://wandb.ai
132
+ 2. Navigate to your project: `nanochat-aquarat`
133
+ 3. View real-time metrics, losses, and generated samples
134
+
135
+ Key metrics to watch:
136
+ - `rl/acc` - Accuracy on AQuA-RAT
137
+ - `rl/mean_reward` - Average reward per sample
138
+ - `rl/kl_letter_mean` - KL divergence from initial policy
139
+ - `rl/letter_margin_mean` - Confidence in letter choices
140
+ - `attn/entropy_mean` - Attention mechanism entropy
141
+
142
+ ## Step 8: Training Timeline
143
+
144
+ For the **small (depth=8) model**:
145
+ - **Base pretraining**: ~1-2 hours
146
+ - **Mid-training**: ~30 minutes
147
+ - **SFT**: ~30 minutes
148
+ - **RL**: ~30 minutes
149
+ - **Total**: ~3-4 hours
150
+
151
+ For the **d-20 model** (561M params):
152
+ - **Base pretraining**: ~3-4 hours
153
+ - **Mid-training**: ~1 hour
154
+ - **SFT**: ~1 hour
155
+ - **RL**: ~1 hour
156
+ - **Total**: ~6-8 hours
157
+
158
+ ## Step 9: Check Results
159
+
160
+ After training completes:
161
+
162
+ ```bash
163
+ # View the final report
164
+ cat ~/.cache/nanochat/report/report.md
165
+
166
+ # Check RL checkpoint
167
+ ls -lh ~/.cache/nanochat/checkpoints/
168
+
169
+ # View evaluation results
170
+ ls ~/.cache/nanochat/evals/
171
+ ```
172
+
173
+ ## Step 10: Download Artifacts (Optional)
174
+
175
+ If you want to save the trained model locally:
176
+
177
+ ```bash
178
+ # From your local machine (not on the Lambda instance):
179
+ # Download checkpoints
180
+ scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/checkpoints ./local_checkpoints/
181
+
182
+ # Download reports
183
+ scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/report ./local_reports/
184
+
185
+ # Download training logs
186
+ scp ubuntu@<INSTANCE_IP>:~/nanochatAquaRat/training.log ./training.log
187
+ ```
188
+
189
+ ## Step 11: Terminate Instance
190
+
191
+ **IMPORTANT**: Remember to terminate your instance when done to avoid charges!
192
+
193
+ ### Via Web Dashboard:
194
+ 1. Go to https://cloud.lambdalabs.com/instances
195
+ 2. Find your instance
196
+ 3. Click the **"..."** menu
197
+ 4. Select **"Terminate"**
198
+ 5. Confirm termination
199
+
200
+ ### Via SSH (before disconnecting):
201
+ ```bash
202
+ # Shutdown the instance (will auto-terminate if configured)
203
+ sudo shutdown -h now
204
+ ```
205
+
206
+ ## Troubleshooting
207
+
208
+ ### Issue: "Out of memory" Error
209
+
210
+ **Solution**: Reduce batch size in the training script
211
+ ```bash
212
+ # Edit run_aquarat_small.sh and add these flags to the torchrun commands:
213
+ --device_batch_size=2 # Reduce from default
214
+ ```
215
+
216
+ ### Issue: W&B Not Logging
217
+
218
+ **Solution**: Check your API key
219
+ ```bash
220
+ # Test W&B login
221
+ wandb login
222
+
223
+ # Verify environment variable
224
+ echo $WANDB_API_KEY
225
+
226
+ # Re-run with explicit login
227
+ export WANDB_API_KEY=your-key-here
228
+ bash run_aquarat_small.sh
229
+ ```
230
+
231
+ ### Issue: Screen Session Lost
232
+
233
+ **Solution**: Reattach to screen
234
+ ```bash
235
+ # List all screen sessions
236
+ screen -ls
237
+
238
+ # Reattach to the training session
239
+ screen -r training
240
+
241
+ # If screen says "Detached", force attach
242
+ screen -d -r training
243
+ ```
244
+
245
+ ### Issue: Dataset Download Slow
246
+
247
+ **Solution**: The script downloads data in parallel. Wait for completion or reduce number of shards.
248
+
249
+ ### Issue: SSH Connection Drops
250
+
251
+ **Solution**: Use `screen` or `tmux` to keep processes running
252
+ ```bash
253
+ # If you didn't use screen initially and got disconnected:
254
+ # Reconnect and check if the process is still running
255
+ ps aux | grep python
256
+
257
+ # If running, you can monitor the log files:
258
+ tail -f ~/.cache/nanochat/report/report.md
259
+ tail -f ~/nanochatAquaRat/training.log
260
+ ```
261
+
262
+ ## Cost Estimation
263
+
264
+ **8x H100 SXM5** pricing (as of reference):
265
+ - ~$3.00/hour per GPU
266
+ - 8 GPUs = $24/hour
267
+ - Small model (4 hours) = ~$96
268
+ - d-20 model (8 hours) = ~$192
269
+
270
+ **Budget-friendly testing options:**
271
+ - 1x A10 (24GB): ~$0.60/hour - Good for testing pipeline
272
+ - 1x A6000 (48GB): ~$0.80/hour - Can run small model
273
+ - 2x A100 (40GB): ~$2.20/hour - Can run d-20 with reduced batch size
274
+
275
+ ## Quick Reference Commands
276
+
277
+ ```bash
278
+ # SSH to instance
279
+ ssh ubuntu@<INSTANCE_IP>
280
+
281
+ # Start training in screen
282
+ screen -S training
283
+ bash run_aquarat_small.sh
284
+
285
+ # Detach from screen
286
+ Ctrl+A then D
287
+
288
+ # Reattach to screen
289
+ screen -r training
290
+
291
+ # Monitor W&B
292
+ Open: https://wandb.ai
293
+
294
+ # View live report
295
+ tail -f ~/.cache/nanochat/report/report.md
296
+
297
+ # Check GPU usage
298
+ nvidia-smi
299
+
300
+ # Terminate instance (via dashboard)
301
+ https://cloud.lambdalabs.com/instances
302
+ ```
303
+
304
+ ## Support
305
+
306
+ - **Lambda Labs Support**: https://lambdalabs.com/support
307
+ - **W&B Support**: https://docs.wandb.ai
308
+ - **nanochat Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
309
+
310
+ ## Next Steps
311
+
312
+ After your model is trained:
313
+ 1. Download checkpoints for inference
314
+ 2. Use the web interface: `python -m scripts.chat_web`
315
+ 3. Test via CLI: `python -m scripts.chat_cli`
316
+ 4. Share your results on W&B
317
+ 5. Fine-tune on additional datasets if desired
QUICKSTART.md ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lambda Labs Training Quickstart
2
+
3
+ This project provides two ways to run model training on Lambda Labs:
4
+
5
+ 1. **Automated API Script** (Recommended) - Fully automated deployment
6
+ 2. **Manual Setup** - Step-by-step web dashboard approach
7
+
8
+ ## Prerequisites
9
+
10
+ Both methods require:
11
+ - Lambda Labs account with API key
12
+ - SSH key added to Lambda Labs
13
+ - Weights & Biases account with API key
14
+ - Sufficient credits (~$24/hour for 8x H100)
15
+
16
+ ## Method 1: Automated API Script (Recommended)
17
+
18
+ ### Setup
19
+
20
+ 1. **Set environment variables:**
21
+
22
+ ```bash
23
+ export LAMBDA_API_KEY='your-lambda-api-key'
24
+ export WANDB_API_KEY='your-wandb-api-key'
25
+ ```
26
+
27
+ Get your Lambda API key from: https://cloud.lambdalabs.com/api-keys
28
+ Get your W&B API key from: https://wandb.ai/authorize
29
+
30
+ 2. **Install dependencies:**
31
+
32
+ ```bash
33
+ pip install lambda-cloud-client
34
+ ```
35
+
36
+ ### Usage
37
+
38
+ **Check available instances:**
39
+
40
+ ```bash
41
+ python launch_lambda.py --list-types
42
+ ```
43
+
44
+ **Launch and start training:**
45
+
46
+ ```bash
47
+ # Launch 8x H100 instance (recommended for d-20 model)
48
+ python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --region us-west-1
49
+
50
+ # Launch smaller instance for testing (depth-8 model)
51
+ python launch_lambda.py --instance-type gpu_1x_a100 --region us-west-1
52
+ ```
53
+
54
+ **Just launch without deploying:**
55
+
56
+ ```bash
57
+ python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --no-deploy
58
+ ```
59
+
60
+ The script will:
61
+ 1. ✓ Launch the instance
62
+ 2. ✓ Wait for it to be ready
63
+ 3. ✓ Deploy the code
64
+ 4. ✓ Start training in a screen session
65
+ 5. ✓ Provide connection details
66
+
67
+ ### Monitor Training
68
+
69
+ After launching, SSH to the instance:
70
+
71
+ ```bash
72
+ ssh ubuntu@<INSTANCE_IP>
73
+ ```
74
+
75
+ Then attach to the screen session:
76
+
77
+ ```bash
78
+ screen -r training
79
+ ```
80
+
81
+ Or view logs:
82
+
83
+ ```bash
84
+ tail -f ~/nanochatAquaRat/training.log
85
+ ```
86
+
87
+ ## Method 2: Manual Setup
88
+
89
+ For detailed step-by-step instructions, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md)
90
+
91
+ **Quick summary:**
92
+ 1. Go to https://cloud.lambdalabs.com/instances
93
+ 2. Launch instance manually
94
+ 3. SSH to instance
95
+ 4. Clone repo and set up .env
96
+ 5. Run `bash run_aquarat_small.sh`
97
+
98
+ ## Training Configuration
99
+
100
+ The `run_aquarat_small.sh` script trains a **depth-8 (smaller) model** which takes approximately **3-4 hours** on 8x H100.
101
+
102
+ ### What Gets Trained:
103
+
104
+ 1. **Base Model** (depth-8, ~60M params)
105
+ - Pretrained on limited corpus (24 shards)
106
+ - Faster iteration for testing
107
+
108
+ 2. **Mid-Training**
109
+ - Conversation format adaptation
110
+ - Tool use capabilities
111
+
112
+ 3. **Supervised Fine-Tuning (SFT)**
113
+ - Fine-tuned on AQuA-RAT dataset
114
+ - Multiple-choice math reasoning
115
+
116
+ 4. **Reinforcement Learning (RL)**
117
+ - GRPO-style RL on AQuA-RAT
118
+ - KL divergence tracking
119
+ - Letter-choice logit margin analysis
120
+ - Attention mechanism logging
121
+
122
+ ### W&B Metrics Logged:
123
+
124
+ - `rl/acc` - Answer accuracy
125
+ - `rl/mean_reward` - Average reward
126
+ - `rl/kl_letter_mean` - Policy drift (letter-level)
127
+ - `rl/kl_sequence_mean` - Policy drift (sequence-level)
128
+ - `rl/letter_margin_mean` - Confidence in answers
129
+ - `attn/entropy_mean` - Attention patterns
130
+
131
+ ## Model Sizes Available
132
+
133
+ You can modify `run_aquarat_small.sh` to change the model depth:
134
+
135
+ | Depth | Params | Training Time | Recommended Instance |
136
+ |-------|--------|---------------|---------------------|
137
+ | 8 | ~60M | 3-4 hours | 1x A100 / 2x A100 |
138
+ | 12 | ~180M | 4-5 hours | 4x A100 |
139
+ | 20 | ~561M | 6-8 hours | 8x H100 |
140
+ | 26 | ~1.1B | 10-12 hours | 8x H100 |
141
+
142
+ To change depth, edit the `--depth` parameter in `run_aquarat_small.sh`:
143
+
144
+ ```bash
145
+ torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
146
+ ```
147
+
148
+ ## Cost Estimates
149
+
150
+ Based on Lambda Labs pricing:
151
+
152
+ | Instance Type | GPUs | Cost/Hour | Small (4h) | d-20 (8h) |
153
+ |------------------|--------------|-----------|------------|-----------|
154
+ | gpu_8x_h100_sxm5 | 8x H100 80GB | ~$24.00 | ~$96 | ~$192 |
155
+ | gpu_4x_a100 | 4x A100 40GB | ~$8.80 | ~$35 | ~$70 |
156
+ | gpu_2x_a100 | 2x A100 40GB | ~$4.40 | ~$18 | ~$35 |
157
+ | gpu_1x_a100 | 1x A100 40GB | ~$2.20 | ~$9 | ~$18 |
158
+
159
+ ## Monitoring Options
160
+
161
+ ### 1. SSH + Screen
162
+
163
+ ```bash
164
+ ssh ubuntu@<INSTANCE_IP>
165
+ screen -r training
166
+ # Ctrl+A then D to detach
167
+ ```
168
+
169
+ ### 2. Weights & Biases
170
+
171
+ Dashboard: https://wandb.ai
172
+
173
+ Real-time metrics, attention heatmaps, sample completions
174
+
175
+ ### 3. Log Files
176
+
177
+ ```bash
178
+ # Training log
179
+ tail -f ~/nanochatAquaRat/training.log
180
+
181
+ # Progress report
182
+ tail -f ~/.cache/nanochat/report/report.md
183
+ ```
184
+
185
+ ## After Training
186
+
187
+ ### Download Checkpoints
188
+
189
+ From your local machine:
190
+
191
+ ```bash
192
+ scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/checkpoints ./checkpoints/
193
+ ```
194
+
195
+ ### Run Inference
196
+
197
+ On the Lambda instance:
198
+
199
+ ```bash
200
+ # Web interface
201
+ python -m scripts.chat_web
202
+
203
+ # CLI interface
204
+ python -m scripts.chat_cli -p "What is 25 * 37?"
205
+
206
+ # Evaluate on test set
207
+ python -m scripts.chat_eval -- -i rl -a AQUA
208
+ ```
209
+
210
+ ### Don't Forget to Terminate!
211
+
212
+ **Via Dashboard:**
213
+ https://cloud.lambdalabs.com/instances → Terminate
214
+
215
+ **Via CLI:**
216
+ ```bash
217
+ sudo shutdown -h now
218
+ ```
219
+
220
+ ## Troubleshooting
221
+
222
+ ### Issue: API Key Not Working
223
+
224
+ ```bash
225
+ # Verify keys are set
226
+ echo $LAMBDA_API_KEY
227
+ echo $WANDB_API_KEY
228
+
229
+ # Re-export if needed
230
+ export LAMBDA_API_KEY='your-key'
231
+ export WANDB_API_KEY='your-key'
232
+ ```
233
+
234
+ ### Issue: No Available Instances
235
+
236
+ Lambda Labs instances can be in high demand. Try:
237
+ - Different regions (`--region us-east-1`)
238
+ - Smaller instance types (`gpu_1x_a100`)
239
+ - Check availability: `python launch_lambda.py --list-types`
240
+
241
+ ### Issue: Out of Memory
242
+
243
+ Edit `run_aquarat_small.sh` and reduce batch size:
244
+
245
+ ```bash
246
+ # Add to torchrun commands:
247
+ --device_batch_size=2
248
+ ```
249
+
250
+ ### Issue: Training Stuck
251
+
252
+ Check GPU utilization:
253
+
254
+ ```bash
255
+ nvidia-smi
256
+ ```
257
+
258
+ If GPUs are idle, check for errors:
259
+
260
+ ```bash
261
+ tail -100 ~/nanochatAquaRat/training.log
262
+ ```
263
+
264
+ ## Files in This Repository
265
+
266
+ - `launch_lambda.py` - Automated Lambda Labs launcher
267
+ - `run_aquarat_small.sh` - Training script (depth-8 model)
268
+ - `LAMBDA_MANUAL_SETUP.md` - Detailed manual setup guide
269
+ - `QUICKSTART.md` - This file
270
+ - `.env.template` - Environment variable template
271
+
272
+ ## Support
273
+
274
+ - **Lambda Labs**: https://lambdalabs.com/support
275
+ - **Weights & Biases**: https://docs.wandb.ai
276
+ - **Project Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
277
+
278
+ ## Next Steps
279
+
280
+ 1. ✓ Set up Lambda Labs account and API key
281
+ 2. ✓ Set up Weights & Biases account
282
+ 3. ✓ Choose your method (API script or manual)
283
+ 4. ✓ Launch instance and start training
284
+ 5. ✓ Monitor via W&B dashboard
285
+ 6. ✓ Download checkpoints when complete
286
+ 7. ✓ Terminate instance to stop charges
287
+
288
+ Happy training!
README.md CHANGED
@@ -13,440 +13,440 @@ tags:
13
  ---
14
 
15
  <div align="center">
16
-
17
- ![AQuA-RAT Training](./aquarat2.png)
18
-
19
- # nanochatAquaRat
20
-
21
- **Training Language Models with Reinforcement Learning on Mathematical Reasoning**
22
-
23
- [![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?logo=github)](https://github.com/HarleyCoops/nanochatAquaRat)
24
- [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
25
- [![Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/downloads/)
26
-
27
- A modified version of [nanochat](https://github.com/karpathy/nanochat) trained with reinforcement learning on the [DeepMind AQuA-RAT dataset](https://huggingface.co/datasets/deepmind/aqua_rat) for algebraic reasoning and multiple-choice problem solving.
28
-
29
- [Quick Start](#quick-start) • [Dataset](#dataset-structure) • [Modifications](#modifications-from-base-nanochat) • [Training](#training-pipeline) • [Results](#results)
30
-
31
- </div>
32
-
33
- ---
34
-
35
- ## Table of Contents
36
-
37
- - [Overview](#overview)
38
- - [The Base: nanochat Framework](#the-base-nanochat-framework)
39
- - [Dataset Structure](#dataset-structure)
40
- - [Modifications from Base nanochat](#modifications-from-base-nanochat)
41
- - [Training Pipeline](#training-pipeline)
42
- - [Quick Start](#quick-start)
43
- - [File Structure](#file-structure)
44
- - [Monitoring & Visualization](#monitoring--visualization)
45
- - [Results](#results)
46
-
47
- ---
48
-
49
- ## Overview
50
-
51
- This project adapts the **nanochat** training framework (originally designed for GSM8K numerical reasoning) to work with **AQuA-RAT** (Algebra Question Answering with Rationales), a dataset of ~97,000 algebraic word problems with multiple-choice answers (A-E) and natural language solution rationales.
52
-
53
- ### Why This Matters
54
-
55
- - **Domain Transfer**: Demonstrates how to adapt a mathematical reasoning pipeline from free-form numeric answers to multiple-choice format
56
- - **RL on Math**: Implements GRPO-style reinforcement learning with reward shaping for categorical outputs
57
- - **Mechanistic Interpretability**: Integrates attention analysis during training to understand model reasoning patterns
58
- - **Production-Ready**: Includes automated Lambda Labs and Hyperbolic Labs deployment helpers for cloud GPU training
59
-
60
- ### Key Results
61
-
62
- | Model | Parameters | Training Time | AQuA-RAT Dev Accuracy |
63
- |-------|------------|---------------|----------------------|
64
- | depth-8 | ~60M | 3-4 hours | 30-50% |
65
- | depth-20 | ~561M | 6-8 hours | 40-60% |
66
-
67
- ---
68
-
69
- ## The Base: nanochat Framework
70
-
71
- **nanochat** is a minimalist yet complete pipeline for training transformer language models from scratch, created by Andrej Karpathy. It implements:
72
-
73
- - **Custom tokenizer**: BPE tokenizer written in Rust for performance
74
- - **Training stages**: Pretraining → Mid-training → SFT → RL
75
- - **Evaluation suite**: CORE benchmarks and task-specific metrics
76
- - **Optimizations**: Memory-efficient training, gradient accumulation, distributed training
77
-
78
- **Original focus**: Training on GSM8K (Grade School Math 8K) with free-form numeric answers.
79
-
80
-
81
- ---
82
-
83
- ## Dataset Structure
84
-
85
- ### AQuA-RAT Format
86
-
87
- The [DeepMind AQuA-RAT dataset](https://github.com/deepmind/AQuA) contains algebraic reasoning problems in JSON format:
88
-
89
- ```json
90
- {
91
- "question": "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?",
92
- "options": [
93
- "A) 53 km",
94
- "B) 55 km",
95
- "C) 52 km",
96
- "D) 60 km",
97
- "E) 50 km"
98
- ],
99
- "rationale": "The distance that the person traveled = 20 * 2.5 = 50 km. Answer: E",
100
- "correct": "E"
101
- }
102
- ```
103
-
104
- **Dataset splits**:
105
- - Training: 97,467 problems
106
- - Development: 254 problems
107
- - Test: 254 problems
108
-
109
- **Key characteristics**:
110
- - Multiple-choice (A-E) format
111
- - Algebraic word problems
112
- - Natural language rationales
113
- - Topics: arithmetic, algebra, geometry, probability
114
-
115
- ### Comparison: GSM8K vs AQuA-RAT
116
-
117
- | Aspect | GSM8K (Original) | AQuA-RAT (This Project) |
118
- |--------|------------------|-------------------------|
119
- | **Format** | Free-form numeric | Multiple choice (A-E) |
120
- | **Answer** | Single number | Letter choice |
121
- | **Size** | 8,500 problems | 97,975 problems |
122
- | **Difficulty** | Elementary school | High school algebra |
123
- | **Rationale** | Step-by-step | Natural language |
124
- | **Evaluation** | Exact match on number | Categorical accuracy |
125
-
126
- ---
127
-
128
- ## Modifications from Base nanochat
129
-
130
- To adapt nanochat from GSM8K to AQuA-RAT, we modified the following components:
131
-
132
- ### 1. Dataset Loader (`scripts/prepare_aqua.py`)
133
-
134
- **Created new file** to download and format AQuA-RAT:
135
-
136
- ```python
137
- # New file: scripts/prepare_aqua.py
138
- ```
-
- ### 1. Dataset Preparation (`scripts/prepare_aqua.py`)
139
-
140
- - Uses `datasets.load_dataset("deepmind/aqua_rat")` and optionally caps split sizes.
141
- - Emits JSONL files (`train.jsonl`, `validation.jsonl`, `test.jsonl`) compatible with
142
- the conversation schema used throughout nanochat.
143
- - Defaults to `~/.cache/nanochat/aqua`, but accepts `--output_dir` overrides so
144
- launchers can bundle their own artifact.
145
-
146
- ```python
147
- def format_example(row):
148
- options = row["options"]
149
- assistant_content = [
150
- {"type": "text", "text": row["rationale"].strip()},
151
- {"type": "text", "text": f"Answer: {row['correct'].strip().upper()}"},
152
- ]
153
- return {
154
- "messages": [
155
- {"role": "user", "content": _render_user_prompt(row["question"], options)},
156
- {"role": "assistant", "content": assistant_content},
157
- ],
158
- "letters": letters,
159
- "answer_letter": correct,
160
- }
161
- ```
162
-
163
- ### 2. Task Module (`tasks/aqua.py`)
164
-
165
- - Accepts optional `data_dir` (or `AQUA_DATA_DIR` / `NANOCHAT_AQUA_DIR`) so the task
166
- can read the cached JSONL; otherwise falls back to Hugging Face.
167
- - Provides `_render_user_prompt` to format the question/options using the common
168
- multiple-choice helper and `_extract_letter` to score completions.
169
- - Returns conversations whose assistant messages include both the rationale and a
170
- final `Answer: <LETTER>` line for SFT, while `evaluate()` only cares about the letter.
171
-
172
- ```python
173
- def _extract_letter(text, default=None):
174
- answer_match = re.search(r"answer\s*[:\-]\s*([A-E])", text, flags=re.IGNORECASE)
175
- if answer_match:
176
- return answer_match.group(1).upper()
177
- match = LETTER_RE.search(text)
178
- return match.group(1).upper() if match else default
179
- ```
180
-
181
- **Key differences from GSM8K**:
182
- - Numeric extraction → Letter extraction
183
- - Free-form answer → Fixed choices A-E
184
- - Exact number match → Categorical match
185
-
186
- ### 3. RL Training (`scripts/chat_rl.py`)
187
-
188
- **Modified** to target AQuA-RAT instead of GSM8K:
189
-
190
- Key updates:
191
-
192
- - `train_task` / `val_task` now instantiate `AQUA(...)` instead of `GSM8K(...)`.
193
- - Rewards reuse the task's `evaluate()` helper so any completion containing
194
- “Answer: X” (or the first bare letter) is scored correctly.
195
- - The validation helper became `run_aqua_eval`, still reporting pass@k accuracy
196
- across sampled completions.
197
- - CLI overrides remain the same because the script continues to rely on the
198
- nanochat configurator (`--run`, `--temperature`, `--max_new_tokens`, …).
199
-
200
- ### 4. Evaluation (`scripts/chat_eval.py`)
201
-
202
- - Registered `'AQUA'` in the task registry so `-a AQUA` just works.
203
- - Added a 20% random-guess baseline when aggregating the ChatCORE metric.
204
- - The categorical evaluation path reuses `run_categorical_eval`, clamping logits
205
- to the available letters before scoring.
206
-
207
- ### 5. Training Script (`run_aquarat_small.sh`)
208
-
209
- **What changed vs upstream nanochat**:
210
-
211
- ```bash
212
- # (Optional) Cache the dataset locally as JSONL
213
- python -m scripts.prepare_aqua --output_dir "$NANOCHAT_BASE_DIR/aqua"
214
-
215
- # Mid-training now samples from the AQuA mixture
216
- torchrun -m scripts.mid_train -- --run=demo --num_iterations=200
217
-
218
- # SFT stage emphasises AQuA problems
219
- torchrun -m scripts.sft_train -- --run=demo --aqua_train_examples=20000
220
-
221
- # RL fine-tuning rewards the correct letter on AQuA-RAT
222
- torchrun -m scripts.chat_rl -- --run=demo --temperature=0.7 --max_new_tokens=64
223
- ```
224
-
225
- - **`tasks/aqua.py`** loads AQuA-RAT either from Hugging Face or the cached JSONL
226
- splits, formats questions as conversations, and scores completions by letter.
227
- - **`scripts/mid_train.py`** extends the original Reasoning+Chat mixture with a
228
- 50k slice of AQuA so the model sees multiple-choice algebra earlier.
229
- - **`scripts/chat_sft.py`** replaces the GSM8K component with AQuA, keeping ARC,
230
- SmolTalk, and identity prompts for general chat coverage.
231
- - **`scripts/chat_rl.py`** retools the GRPO loop to sample, reward, and evaluate
232
- AQuA answers (categorical accuracy instead of GSM8K free-form math).
233
- - **`scripts/chat_eval.py`** registers the new AQuA task so `chat_eval` can report
234
- categorical accuracy alongside ARC/MMLU/GSM8K/HumanEval.
235
-
236
- ---
237
-
238
- ## Training Pipeline
239
-
240
- ### Stage 1: Base Pretraining (50-60% of time)
241
-
242
- **What happens**: Model learns language from scratch on FineWeb corpus
243
-
244
- ```bash
245
- torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=8
246
- ```
247
-
248
- **Duration**: 1.5-2 hours on 8x H100
249
- **Output**: Base checkpoint with general language understanding
250
- **Metrics**: Validation loss, CORE benchmark scores
251
-
252
- ### Stage 2: Mid-Training (12-15% of time)
253
-
254
- **What happens**: Teach conversation format and special tokens
255
-
256
- ```bash
257
- torchrun --nproc_per_node=8 -m scripts.mid_train
258
- ```
259
-
260
- **Duration**: 30 minutes
261
- **Output**: Conversational checkpoint
262
- **Metrics**: Format adherence, tool use capability
263
-
264
- ### Stage 3: Supervised Fine-Tuning (12-15% of time)
265
-
266
- **What happens**: Fine-tune on AQuA-RAT with ground-truth solutions
267
-
268
- ```bash
269
- torchrun --nproc_per_node=8 -m scripts.sft_train -- \
270
- --aqua_train_examples=20000 \
271
- --aqua_val_examples=254
272
- ```
273
-
274
- **Duration**: 30 minutes
275
- **Output**: AQuA-tuned checkpoint
276
- **Metrics**: Dev set accuracy (categorical)
277
-
278
- ### Stage 4: Reinforcement Learning (12-15% of time)
279
-
280
- **What happens**: Policy gradient learning with GRPO algorithm
281
-
282
- ```bash
283
- torchrun --nproc_per_node=1 -m scripts.chat_rl -- \
284
- --temperature=0.7 \
285
- --max_new_tokens=64
286
- ```
287
-
288
- **Duration**: 30 minutes
289
- **Algorithm**: Group Relative Policy Optimization (GRPO)
290
- **Reward**: +1.0 for correct letter, +0.1 for valid letter format
291
- **Output**: RL-optimized checkpoint
292
-
293
- **Logged metrics**:
294
- - `rl/acc` - Accuracy on training samples
295
- - `rl/mean_reward` - Average reward per generation
296
- - `rl/kl_letter_mean` - KL divergence at decision point
297
- - `rl/kl_sequence_mean` - Full sequence KL
298
- - `rl/letter_margin_mean` - Confidence (logit gap)
299
- - `attn/entropy_mean` - Attention mechanism patterns
300
-
301
- ---
302
-
303
- ## Quick Start
304
-
305
- ### Repo Setup & Rust Toolchain
306
-
307
- - Clone with submodules so the `rustbpe` tokenizer sources are present:
308
- ```bash
309
- git clone --recurse-submodules https://github.com/HarleyCoops/nanochatAquaRat.git
310
- ```
311
- For existing clones run `git submodule update --init --recursive` before building.
312
- - Install Rust (needed for the tokenizer build). On Linux/macOS follow [https://rustup.rs](https://rustup.rs). On Windows, after installing rustup, ensure the toolchain is MSVC x86\_64 and the cargo bin directory is on `PATH`:
313
- ```powershell
314
- $env:Path += ";$env:USERPROFILE\.cargo\bin"
315
- setx PATH "$env:Path"
316
- setx CARGO_HOME "$env:USERPROFILE\.cargo"
317
- setx RUSTUP_HOME "$env:USERPROFILE\.rustup"
318
- rustup set default-host x86_64-pc-windows-msvc
319
- rustup default stable-x86_64-pc-windows-msvc
320
- cargo --version
321
- rustup --version
322
- ```
323
- - Build the tokenizer once per machine:
324
- ```bash
325
- uv run maturin develop
326
- ```
327
-
328
- ### Option 1: Lambda Labs Cloud (Automated)
329
-
330
- Use the automation helper for one-command deployment:
331
-
332
- ```bash
333
- # Set credentials
334
- export LAMBDA_API_KEY='your-lambda-api-key'
335
- export WANDB_API_KEY='your-wandb-api-key'
336
-
337
- # Launch with auto-start
338
- python scripts/launch_lambda_training.py \
339
- --ssh-key-name your_lambda_ssh_key \
340
- --instance-type gpu_8x_h100_sxm5 \
341
- --region us-west-1 \
342
- --auto-start \
343
- --inject-env WANDB_API_KEY
344
- ```
345
-
346
- The script provisions the instance, clones this repository, sets up environment variables, and starts training in a tmux session.
347
-
348
- **Monitor training**:
349
- ```bash
350
- # SSH to instance
351
- ssh ubuntu@<INSTANCE_IP>
352
-
353
- # Attach to tmux session
354
- tmux attach -t nanochat-train
355
-
356
- # Or view logs
357
- tail -f ~/nanochatAquaRat/training.log
358
- ```
359
-
360
- ### Option 2: Hyperbolic Labs Cloud (Automated)
361
-
362
- Spin up on-demand GPUs via Hyperbolic's marketplace API:
363
-
364
- ```bash
365
- # Set credentials
366
- export HYPERBOLIC_API_KEY='your-hyperbolic-api-key'
367
- export WANDB_API_KEY='your-wandb-api-key'
368
-
369
- # Launch with auto-start
370
- python scripts/launch_hyperbolic_training.py \
371
- --gpu-count 1 \
372
- --region us-east \
373
- --auto-start \
374
- --inject-env WANDB_API_KEY
375
- ```
376
-
377
- The launcher discovers an available node (respecting `--region`, `--supplier`, or `--max-price` filters), provisions it, copies your `.env`, and optionally starts training in tmux. Use `--list` to inspect available marketplace inventory without launching.
378
-
379
- ### Option 3: Lambda Labs Cloud (Manual)
380
-
381
- For step-by-step control, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md).
382
-
383
- **Quick summary**:
384
- 1. Launch instance at https://cloud.lambdalabs.com/instances
385
- 2. SSH to instance: `ssh ubuntu@<IP>`
386
- 3. Clone repo: `git clone <repo-url> && cd nanochatAquaRat`
387
- 4. Set up credentials: `echo "WANDB_API_KEY=..." > .env`
388
- 5. Run training: `bash run_aquarat_small.sh`
389
-
390
- ### Option 4: Hyperbolic VM (Manual)
391
-
392
- For marketplace nodes without automation access, follow this lightweight bootstrap:
393
-
394
- 1. Provision a GPU VM from the Hyperbolic console and copy the SSH command (including `-p <port>` and username).
395
- 2. SSH in and install prerequisites:
396
- ```bash
397
- sudo apt-get update
398
- sudo apt-get install -y git curl unzip build-essential python3 python3-venv tmux
399
- git clone https://github.com/HarleyCoops/nanochatAquaRat.git
400
- cd nanochatAquaRat
401
- ```
402
- 3. Create `.env` with the required keys (WANDB, GCS bucket, AQUA path) and upload your GCP service-account JSON to the VM, e.g. `scp -P <port> C:\path\to\credentials.json user@<ip>:/home/user/gcp-sa.json`.
403
- 4. Install tooling and build the tokenizer:
404
- ```bash
405
- curl -LsSf https://astral.sh/uv/install.sh | sh
406
- curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
407
- source "$HOME/.cargo/env"
408
- export PATH="$HOME/.local/bin:$PATH"
409
- uv venv && uv sync --extra gpu
410
- source .venv/bin/activate
411
- uv run maturin develop
412
- uv run python -m scripts.tok_train
413
- ```
414
- 5. Install the Google Cloud SDK, authenticate, and stage the cached AQuA splits (or regenerate them):
415
- ```bash
416
- curl -sSL https://sdk.cloud.google.com | bash
417
- source "$HOME/.bashrc"
418
- gcloud auth login --no-launch-browser
419
- gcloud config set project <your-project-id>
420
- gcloud storage cp gs://nanochat-aquarat-datasets/datasets/aqua/aqua_cache.zip .
421
- unzip -o aqua_cache.zip -d ~/aqua_cache
422
- export AQUA_DATA_DIR=$HOME/aqua_cache
423
- ```
424
- 6. Fetch the identity conversation bundle (required for SFT) and the evaluation bundle once so CORE metrics don’t fail:
425
- ```bash
426
- cd ~/.cache/nanochat
427
- curl -L -o identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
428
- curl -L -o eval_bundle.zip https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
429
- unzip -q eval_bundle.zip && rm eval_bundle.zip
430
- cd ~/nanochatAquaRat
431
- ```
432
- 7. Launch the desired script, e.g. `CUDA_VISIBLE_DEVICES=0 bash run_aquarat_lite.sh` or the full `run_aquarat_small.sh`.
433
- 8. Monitor training via tmux/W&B and terminate the VM from Hyperbolic when the run finishes to stop billing.
434
-
435
- ### Alternative: Simplified Launcher Script
436
-
437
- A simplified launcher is also available:
438
-
439
- ```bash
440
- export LAMBDA_API_KEY='your-key'
441
- export WANDB_API_KEY='your-key'
442
-
443
- python launch_lambda.py \
444
- --instance-type gpu_8x_h100_sxm5 \
445
- --region us-west-1
446
- ```
447
-
448
- See [QUICKSTART.md](QUICKSTART.md) for details.
449
-
450
  ### Option 5: Local/Custom Setup
451
 
452
  ```bash
@@ -464,175 +464,191 @@ bash run_aquarat_small.sh
464
  - 40GB+ GPU memory per GPU
465
  - ~100GB disk space
466
 
467
- ---
468
-
469
- ## File Structure
470
 
471
- ```
472
- nanochatAquaRat/
473
- ├── nanochat/… # Vendored upstream nanochat package
474
- ├── scripts/
475
- │ ├── base_train.py # Base pretraining stage
476
- │ ├── mid_train.py # Mid-training (now includes AQuA)
477
- │ ├── chat_sft.py # Chat SFT pipeline
478
- │ ├── sft_train.py # Shim so `-m scripts.sft_train` still works
479
- │ ├── chat_rl.py # Reinforcement learning on AQuA-RAT
480
- │ ├── chat_eval.py # Evaluation harness (adds AQuA task)
481
- │ ├── prepare_aqua.py # AQuA-RAT JSONL exporter
482
- │ ├── launch_lambda_training.py # Lambda Labs automation
483
- │ ├── launch_hyperbolic_training.py # Hyperbolic Labs automation
484
- │ └── upload_to_gcs.sh # Artifact helper
485
- ├── tasks/
486
- │ ├── aqua.py # AQuA-RAT task implementation
487
- │ ├── arc.py / gsm8k.py / mmlu.py # Other reasoning tasks
488
- │ └── …
489
- ├── run_aquarat_small.sh # End-to-end orchestration
490
- ├── pyproject.toml / uv.lock # Environment definitions
491
- └── README.md
492
- ```
493
- ### Summary of Code Changes
494
 
495
- | File | Type | Description |
496
- |------|------|-------------|
497
- | `tasks/aqua.py` | NEW | Conversation + evaluation wrapper for AQuA-RAT |
498
- | `scripts/prepare_aqua.py` | NEW | Materializes train/validation/test JSONL splits for offline use |
499
- | `scripts/mid_train.py` | MODIFIED | Adds AQuA to the mid-training mixture |
500
- | `scripts/chat_sft.py` | MODIFIED | SFT mixture now includes AQuA controls |
501
- | `scripts/sft_train.py` | NEW | Thin compatibility shim around `chat_sft` |
502
- | `scripts/chat_rl.py` | MODIFIED | RL loop retargeted from GSM8K to AQuA-RAT |
503
- | `scripts/chat_eval.py` | MODIFIED | Registers AQuA for categorical evaluation |
504
- | `run_aquarat_small.sh` | MODIFIED | Pipeline glue aligned with AQuA staging |
505
- | `scripts/launch_hyperbolic_training.py` | NEW | Hyperbolic Labs automation helper |
506
- | `launch_lambda.py` / `scripts/launch_lambda_training.py` | EXISTING | Lambda Labs support retained |
507
 
508
  ---
509
 
510
- ## Monitoring & Visualization
511
-
512
- All metrics stream to [Weights & Biases](https://wandb.ai) in real-time:
513
-
514
- **Training Metrics**:
515
- - Loss curves (pretraining, SFT, RL)
516
- - Learning rate schedules
517
- - Gradient norms
518
-
519
- **RL Metrics**:
520
- - Policy performance (accuracy, rewards)
521
- - KL divergence from initial policy
522
- - Letter-choice distributions (A-E)
523
- - Confidence margins
524
-
525
- **Interpretability**:
526
- - Attention heatmaps per layer
527
- - Entropy evolution across training
528
- - Token-level attention weights
529
 
530
- Example W&B dashboard:
531
- ```
532
- rl/acc ━━━━━━━━━━ 0.45
533
- rl/kl_letter_mean ━━━━━━━━━━ 0.12
534
- rl/letter_margin_mean ━━━━━━━━━━ 2.34
535
- attn/entropy_mean ━━━━━━━━━━ 3.21
536
  ```
537
-
538
- ---
539
-
540
- ## Results
541
-
542
- ### Model Configurations
543
-
544
- | Depth | Parameters | Training Time | Best Instance Type | Estimated Cost |
545
- |-------|------------|---------------|-------------------|----------------|
546
- | 8 | ~60M | 3-4 hours | 1-2x A100 | ~$18-35 |
547
- | 12 | ~180M | 4-5 hours | 4x A100 | ~$35-45 |
548
- | 20 | ~561M | 6-8 hours | 8x H100 | ~$144-192 |
549
- | 26 | ~1.1B | 10-12 hours | 8x H100 | ~$240-288 |
550
-
551
- To change model depth, edit the `--depth` parameter in `run_aquarat_small.sh`.
552
-
553
- ### Expected Performance
554
-
555
- **After SFT** (before RL):
556
- - Dev accuracy: 20-30% (depth-8), 30-40% (depth-20)
557
- - Basic problem-solving capability
558
- - Some format errors (invalid letters)
559
-
560
- **After RL**:
561
- - Dev accuracy: 30-50% (depth-8), 40-60% (depth-20)
562
- - Improved reasoning coherence
563
- - Better multiple-choice selection confidence
564
- - Reduced format errors
565
- - Stable attention patterns
566
-
567
- ### Cost Management
568
-
569
- Lambda Labs pricing (8x H100 SXM5 @ ~$24/hour):
570
-
571
- | Model | Training Time | Total Cost |
572
- |-------|---------------|------------|
573
- | depth-8 (60M) | 3-4 hours | ~$96 |
574
- | depth-20 (561M) | 6-8 hours | ~$192 |
575
-
576
- Budget options:
577
- - Test pipeline: 1x A10 @ $0.60/hr
578
- - Small model: 2x A100 @ $4.40/hr
579
- - Production: 8x H100 @ $24/hr
580
-
581
- ---
582
-
583
- ## Important Notes
584
-
585
- ### For Lambda Labs Users
586
- - **Always terminate instances** after training to avoid charges
587
- - Monitor spending in the Lambda Labs dashboard
588
- - Check instance availability before launching (high demand periods)
589
-
590
- ### Known Limitations
591
- - RL on AQuA-RAT is experimental; results may vary
592
- - Attention logging adds ~5-10% overhead
593
- - KL computation can be expensive with large batch sizes
594
- - Smaller models (<100M params) may struggle with complex reasoning
595
-
596
- ---
597
-
598
- ## Documentation
599
-
600
- - **[scripts/launch_lambda_training.py](scripts/launch_lambda_training.py)** - Full-featured automation
601
- - **[scripts/launch_hyperbolic_training.py](scripts/launch_hyperbolic_training.py)** - Hyperbolic marketplace automation
602
- - **[launch_lambda.py](launch_lambda.py)** - Simplified launcher
603
- - **[QUICKSTART.md](QUICKSTART.md)** - Fast track guide
604
- - **[LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md)** - Manual setup walkthrough
605
- - **[GCS_UPLOAD_GUIDE.md](GCS_UPLOAD_GUIDE.md)** - Upload weights to Google Cloud Storage
606
- - **[.env.template](.env.template)** - Environment configuration
607
-
608
- ---
609
-
610
- ## Contributing
611
-
612
- This project is based on the nanochat framework. For issues specific to:
613
- - **AQuA-RAT training**: Open an issue in this repository
614
- - **Base nanochat framework**: Refer to the upstream nanochat project
615
- - **Lambda Labs deployment**: See documentation above
616
-
617
- ---
618
-
619
- ## License
620
-
621
- This project inherits the license from the base nanochat project.
622
-
623
- ---
624
-
625
- ## Acknowledgments
626
-
627
- - **Andrej Karpathy** - nanochat framework
628
- - **DeepMind** - AQuA-RAT dataset and mechanistic interpretability tools
629
- - **Lambda Labs** - Cloud GPU infrastructure
630
- - **Weights & Biases** - Experiment tracking and visualization
631
-
632
- ---
633
-
634
- ## Support
635
-
636
- - **Lambda Labs Support**: https://lambdalabs.com/support
637
- - **Weights & Biases Docs**: https://docs.wandb.ai
638
- - **Project Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ---
14
 
15
  <div align="center">
16
+
17
+ ![AQuA-RAT Training](./aquarat2.png)
18
+
19
+ # nanochatAquaRat
20
+
21
+ **Training Language Models with Reinforcement Learning on Mathematical Reasoning**
22
+
23
+ [![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?logo=github)](https://github.com/HarleyCoops/nanochatAquaRat)
24
+ [![License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
25
+ [![Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/downloads/)
26
+
27
+ A modified version of [nanochat](https://github.com/karpathy/nanochat) trained with reinforcement learning on the [DeepMind AQuA-RAT dataset](https://huggingface.co/datasets/deepmind/aqua_rat) for algebraic reasoning and multiple-choice problem solving.
28
+
29
+ [Quick Start](#quick-start) • [Dataset](#dataset-structure) • [Modifications](#modifications-from-base-nanochat) • [Training](#training-pipeline) • [Results](#results)
30
+
31
+ </div>
32
+
33
+ ---
34
+
35
+ ## Table of Contents
36
+
37
+ - [Overview](#overview)
38
+ - [The Base: nanochat Framework](#the-base-nanochat-framework)
39
+ - [Dataset Structure](#dataset-structure)
40
+ - [Modifications from Base nanochat](#modifications-from-base-nanochat)
41
+ - [Training Pipeline](#training-pipeline)
42
+ - [Quick Start](#quick-start)
43
+ - [File Structure](#file-structure)
44
+ - [Monitoring & Visualization](#monitoring--visualization)
45
+ - [Results](#results)
46
+
47
+ ---
48
+
49
+ ## Overview
50
+
51
+ This project adapts the **nanochat** training framework (originally designed for GSM8K numerical reasoning) to work with **AQuA-RAT** (Algebra Question Answering with Rationales), a dataset of ~97,000 algebraic word problems with multiple-choice answers (A-E) and natural language solution rationales.
52
+
53
+ ### Why This Matters
54
+
55
+ - **Domain Transfer**: Demonstrates how to adapt a mathematical reasoning pipeline from free-form numeric answers to multiple-choice format
56
+ - **RL on Math**: Implements GRPO-style reinforcement learning with reward shaping for categorical outputs
57
+ - **Mechanistic Interpretability**: Integrates attention analysis during training to understand model reasoning patterns
58
+ - **Production-Ready**: Includes automated Lambda Labs and Hyperbolic Labs deployment helpers for cloud GPU training
59
+
60
+ ### Key Results
61
+
62
+ | Model | Parameters | Training Time | AQuA-RAT Dev Accuracy |
63
+ |-------|------------|---------------|----------------------|
64
+ | depth-8 | ~60M | 3-4 hours | 30-50% |
65
+ | depth-20 | ~561M | 6-8 hours | 40-60% |
66
+
67
+ ---
68
+
69
+ ## The Base: nanochat Framework
70
+
71
+ **nanochat** is a minimalist yet complete pipeline for training transformer language models from scratch, created by Andrej Karpathy. It implements:
72
+
73
+ - **Custom tokenizer**: BPE tokenizer written in Rust for performance
74
+ - **Training stages**: Pretraining → Mid-training → SFT → RL
75
+ - **Evaluation suite**: CORE benchmarks and task-specific metrics
76
+ - **Optimizations**: Memory-efficient training, gradient accumulation, distributed training
77
+
78
+ **Original focus**: Training on GSM8K (Grade School Math 8K) with free-form numeric answers.
79
+
80
+
81
+ ---
82
+
83
+ ## Dataset Structure
84
+
85
+ ### AQuA-RAT Format
86
+
87
+ The [DeepMind AQuA-RAT dataset](https://github.com/deepmind/AQuA) contains algebraic reasoning problems in JSON format:
88
+
89
+ ```json
90
+ {
91
+ "question": "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?",
92
+ "options": [
93
+ "A) 53 km",
94
+ "B) 55 km",
95
+ "C) 52 km",
96
+ "D) 60 km",
97
+ "E) 50 km"
98
+ ],
99
+ "rationale": "The distance that the person traveled = 20 * 2.5 = 50 km. Answer: E",
100
+ "correct": "E"
101
+ }
102
+ ```
103
+
104
+ **Dataset splits**:
105
+ - Training: 97,467 problems
106
+ - Development: 254 problems
107
+ - Test: 254 problems
108
+
109
+ **Key characteristics**:
110
+ - Multiple-choice (A-E) format
111
+ - Algebraic word problems
112
+ - Natural language rationales
113
+ - Topics: arithmetic, algebra, geometry, probability
114
+
115
+ ### Comparison: GSM8K vs AQuA-RAT
116
+
117
+ | Aspect | GSM8K (Original) | AQuA-RAT (This Project) |
118
+ |--------|------------------|-------------------------|
119
+ | **Format** | Free-form numeric | Multiple choice (A-E) |
120
+ | **Answer** | Single number | Letter choice |
121
+ | **Size** | 8,500 problems | 97,975 problems |
122
+ | **Difficulty** | Elementary school | High school algebra |
123
+ | **Rationale** | Step-by-step | Natural language |
124
+ | **Evaluation** | Exact match on number | Categorical accuracy |
125
+
126
+ ---
127
+
128
+ ## Modifications from Base nanochat
129
+
130
+ To adapt nanochat from GSM8K to AQuA-RAT, we modified the following components:
131
+
132
+ ### 1. Dataset Loader (`scripts/prepare_aqua.py`)
133
+
134
+ **Created new file** to download and format AQuA-RAT:
135
+
136
+ ```python
137
+ # New file: scripts/prepare_aqua.py
138
+ ### 1. Dataset Preparation (`scripts/prepare_aqua.py`)
139
+
140
+ - Uses `datasets.load_dataset("deepmind/aqua_rat")` and optionally caps split sizes.
141
+ - Emits JSONL files (`train.jsonl`, `validation.jsonl`, `test.jsonl`) compatible with
142
+ the conversation schema used throughout nanochat.
143
+ - Defaults to `~/.cache/nanochat/aqua`, but accepts `--output_dir` overrides so
144
+ launchers can bundle their own artifact.
145
+
146
+ ```python
147
+ def format_example(row):
148
+ options = row["options"]
149
+ assistant_content = [
150
+ {"type": "text", "text": row["rationale"].strip()},
151
+ {"type": "text", "text": f"Answer: {row['correct'].strip().upper()}"},
152
+ ]
153
+ return {
154
+ "messages": [
155
+ {"role": "user", "content": _render_user_prompt(row["question"], options)},
156
+ {"role": "assistant", "content": assistant_content},
157
+ ],
158
+ "letters": letters,
159
+ "answer_letter": correct,
160
+ }
161
+ ```
162
+
163
+ ### 2. Task Module (`tasks/aqua.py`)
164
+
165
+ - Accepts optional `data_dir` (or `AQUA_DATA_DIR` / `NANOCHAT_AQUA_DIR`) so the task
166
+ can read the cached JSONL; otherwise falls back to Hugging Face.
167
+ - Provides `_render_user_prompt` to format the question/options using the common
168
+ multiple-choice helper and `_extract_letter` to score completions.
169
+ - Returns conversations whose assistant messages include both the rationale and a
170
+ final `Answer: <LETTER>` line for SFT, while `evaluate()` only cares about the letter.
171
+
172
+ ```python
173
+ def _extract_letter(text, default=None):
174
+ answer_match = re.search(r"answer\s*[:\-]\s*([A-E])", text, flags=re.IGNORECASE)
175
+ if answer_match:
176
+ return answer_match.group(1).upper()
177
+ match = LETTER_RE.search(text)
178
+ return match.group(1).upper() if match else default
179
+ ```
180
+
181
+ **Key differences from GSM8K**:
182
+ - Numeric extraction → Letter extraction
183
+ - Free-form answer → Fixed choices A-E
184
+ - Exact number match → Categorical match
185
+
186
+ ### 3. RL Training (`scripts/chat_rl.py`)
187
+
188
+ **Modified** to target AQuA-RAT instead of GSM8K:
189
+
190
+ Key updates:
191
+
192
+ - `train_task` / `val_task` now instantiate `AQUA(...)` instead of `GSM8K(...)`.
193
+ - Rewards reuse the task's `evaluate()` helper so any completion containing
194
+ “Answer: X” (or the first bare letter) is scored correctly.
195
+ - The validation helper became `run_aqua_eval`, still reporting pass@k accuracy
196
+ across sampled completions.
197
+ - CLI overrides remain the same because the script continues to rely on the
198
+ nanochat configurator (`--run`, `--temperature`, `--max_new_tokens`, …).
199
+
200
+ ### 4. Evaluation (`scripts/chat_eval.py`)
201
+
202
+ - Registered `'AQUA'` in the task registry so `-a AQUA` just works.
203
+ - Added a 20% random-guess baseline when aggregating the ChatCORE metric.
204
+ - The categorical evaluation path reuses `run_categorical_eval`, clamping logits
205
+ to the available letters before scoring.
206
+
207
+ ### 5. Training Script (`run_aquarat_small.sh`)
208
+
209
+ **What changed vs upstream nanochat**:
210
+
211
+ ```bash
212
+ # (Optional) Cache the dataset locally as JSONL
213
+ python -m scripts.prepare_aqua --output_dir "$NANOCHAT_BASE_DIR/aqua"
214
+
215
+ # Mid-training now samples from the AQuA mixture
216
+ torchrun -m scripts.mid_train -- --run=demo --num_iterations=200
217
+
218
+ # SFT stage emphasises AQuA problems
219
+ torchrun -m scripts.sft_train -- --run=demo --aqua_train_examples=20000
220
+
221
+ # RL fine-tuning rewards the correct letter on AQuA-RAT
222
+ torchrun -m scripts.chat_rl -- --run=demo --temperature=0.7 --max_new_tokens=64
223
+ ```
224
+
225
+ - **`tasks/aqua.py`** loads AQuA-RAT either from Hugging Face or the cached JSONL
226
+ splits, formats questions as conversations, and scores completions by letter.
227
+ - **`scripts/mid_train.py`** extends the original Reasoning+Chat mixture with a
228
+ 50k slice of AQuA so the model sees multiple-choice algebra earlier.
229
+ - **`scripts/chat_sft.py`** replaces the GSM8K component with AQuA, keeping ARC,
230
+ SmolTalk, and identity prompts for general chat coverage.
231
+ - **`scripts/chat_rl.py`** retools the GRPO loop to sample, reward, and evaluate
232
+ AQuA answers (categorical accuracy instead of GSM8K free-form math).
233
+ - **`scripts/chat_eval.py`** registers the new AQuA task so `chat_eval` can report
234
+ categorical accuracy alongside ARC/MMLU/GSM8K/HumanEval.
235
+
236
+ ---
237
+
238
+ ## Training Pipeline
239
+
240
+ ### Stage 1: Base Pretraining (50-60% of time)
241
+
242
+ **What happens**: Model learns language from scratch on FineWeb corpus
243
+
244
+ ```bash
245
+ torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=8
246
+ ```
247
+
248
+ **Duration**: 1.5-2 hours on 8x H100
249
+ **Output**: Base checkpoint with general language understanding
250
+ **Metrics**: Validation loss, CORE benchmark scores
251
+
252
+ ### Stage 2: Mid-Training (12-15% of time)
253
+
254
+ **What happens**: Teach conversation format and special tokens
255
+
256
+ ```bash
257
+ torchrun --nproc_per_node=8 -m scripts.mid_train
258
+ ```
259
+
260
+ **Duration**: 30 minutes
261
+ **Output**: Conversational checkpoint
262
+ **Metrics**: Format adherence, tool use capability
263
+
264
+ ### Stage 3: Supervised Fine-Tuning (12-15% of time)
265
+
266
+ **What happens**: Fine-tune on AQuA-RAT with ground-truth solutions
267
+
268
+ ```bash
269
+ torchrun --nproc_per_node=8 -m scripts.sft_train -- \
270
+ --aqua_train_examples=20000 \
271
+ --aqua_val_examples=254
272
+ ```
273
+
274
+ **Duration**: 30 minutes
275
+ **Output**: AQuA-tuned checkpoint
276
+ **Metrics**: Dev set accuracy (categorical)
277
+
278
+ ### Stage 4: Reinforcement Learning (12-15% of time)
279
+
280
+ **What happens**: Policy gradient learning with GRPO algorithm
281
+
282
+ ```bash
283
+ torchrun --nproc_per_node=1 -m scripts.chat_rl -- \
284
+ --temperature=0.7 \
285
+ --max_new_tokens=64
286
+ ```
287
+
288
+ **Duration**: 30 minutes
289
+ **Algorithm**: Group Relative Policy Optimization (GRPO)
290
+ **Reward**: +1.0 for correct letter, +0.1 for valid letter format
291
+ **Output**: RL-optimized checkpoint
292
+
293
+ **Logged metrics**:
294
+ - `rl/acc` - Accuracy on training samples
295
+ - `rl/mean_reward` - Average reward per generation
296
+ - `rl/kl_letter_mean` - KL divergence at decision point
297
+ - `rl/kl_sequence_mean` - Full sequence KL
298
+ - `rl/letter_margin_mean` - Confidence (logit gap)
299
+ - `attn/entropy_mean` - Attention mechanism patterns
300
+
301
+ ---
302
+
303
+ ## Quick Start
304
+
305
+ ### Repo Setup & Rust Toolchain
306
+
307
+ - Clone with submodules so the `rustbpe` tokenizer sources are present:
308
+ ```bash
309
+ git clone --recurse-submodules https://github.com/HarleyCoops/nanochatAquaRat.git
310
+ ```
311
+ For existing clones run `git submodule update --init --recursive` before building.
312
+ - Install Rust (needed for the tokenizer build). On Linux/macOS follow [https://rustup.rs](https://rustup.rs). On Windows, after installing rustup, ensure the toolchain is MSVC x86\_64 and the cargo bin directory is on `PATH`:
313
+ ```powershell
314
+ $env:Path += ";$env:USERPROFILE\.cargo\bin"
315
+ setx PATH "$env:Path"
316
+ setx CARGO_HOME "$env:USERPROFILE\.cargo"
317
+ setx RUSTUP_HOME "$env:USERPROFILE\.rustup"
318
+ rustup set default-host x86_64-pc-windows-msvc
319
+ rustup default stable-x86_64-pc-windows-msvc
320
+ cargo --version
321
+ rustup --version
322
+ ```
323
+ - Build the tokenizer once per machine:
324
+ ```bash
325
+ uv run maturin develop
326
+ ```
327
+
328
+ ### Option 1: Lambda Labs Cloud (Automated)
329
+
330
+ Use the automation helper for one-command deployment:
331
+
332
+ ```bash
333
+ # Set credentials
334
+ export LAMBDA_API_KEY='your-lambda-api-key'
335
+ export WANDB_API_KEY='your-wandb-api-key'
336
+
337
+ # Launch with auto-start
338
+ python scripts/launch_lambda_training.py \
339
+ --ssh-key-name your_lambda_ssh_key \
340
+ --instance-type gpu_8x_h100_sxm5 \
341
+ --region us-west-1 \
342
+ --auto-start \
343
+ --inject-env WANDB_API_KEY
344
+ ```
345
+
346
+ The script provisions the instance, clones this repository, sets up environment variables, and starts training in a tmux session.
347
+
348
+ **Monitor training**:
349
+ ```bash
350
+ # SSH to instance
351
+ ssh ubuntu@<INSTANCE_IP>
352
+
353
+ # Attach to tmux session
354
+ tmux attach -t nanochat-train
355
+
356
+ # Or view logs
357
+ tail -f ~/nanochatAquaRat/training.log
358
+ ```
359
+
360
+ ### Option 2: Hyperbolic Labs Cloud (Automated)
361
+
362
+ Spin up on-demand GPUs via Hyperbolic's marketplace API:
363
+
364
+ ```bash
365
+ # Set credentials
366
+ export HYPERBOLIC_API_KEY='your-hyperbolic-api-key'
367
+ export WANDB_API_KEY='your-wandb-api-key'
368
+
369
+ # Launch with auto-start
370
+ python scripts/launch_hyperbolic_training.py \
371
+ --gpu-count 1 \
372
+ --region us-east \
373
+ --auto-start \
374
+ --inject-env WANDB_API_KEY
375
+ ```
376
+
377
+ The launcher discovers an available node (respecting `--region`, `--supplier`, or `--max-price` filters), provisions it, copies your `.env`, and optionally starts training in tmux. Use `--list` to inspect available marketplace inventory without launching.
378
+
379
+ ### Option 3: Lambda Labs Cloud (Manual)
380
+
381
+ For step-by-step control, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md).
382
+
383
+ **Quick summary**:
384
+ 1. Launch instance at https://cloud.lambdalabs.com/instances
385
+ 2. SSH to instance: `ssh ubuntu@<IP>`
386
+ 3. Clone repo: `git clone <repo-url> && cd nanochatAquaRat`
387
+ 4. Set up credentials: `echo "WANDB_API_KEY=..." > .env`
388
+ 5. Run training: `bash run_aquarat_small.sh`
389
+
390
+ ### Option 4: Hyperbolic VM (Manual)
391
+
392
+ For marketplace nodes without automation access, follow this lightweight bootstrap:
393
+
394
+ 1. Provision a GPU VM from the Hyperbolic console and copy the SSH command (including `-p <port>` and username).
395
+ 2. SSH in and install prerequisites:
396
+ ```bash
397
+ sudo apt-get update
398
+ sudo apt-get install -y git curl unzip build-essential python3 python3-venv tmux
399
+ git clone https://github.com/HarleyCoops/nanochatAquaRat.git
400
+ cd nanochatAquaRat
401
+ ```
402
+ 3. Create `.env` with the required keys (WANDB, GCS bucket, AQUA path) and upload your GCP service-account JSON to the VM, e.g. `scp -P <port> C:\path\to\credentials.json user@<ip>:/home/user/gcp-sa.json`.
403
+ 4. Install tooling and build the tokenizer:
404
+ ```bash
405
+ curl -LsSf https://astral.sh/uv/install.sh | sh
406
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
407
+ source "$HOME/.cargo/env"
408
+ export PATH="$HOME/.local/bin:$PATH"
409
+ uv venv && uv sync --extra gpu
410
+ source .venv/bin/activate
411
+ uv run maturin develop
412
+ uv run python -m scripts.tok_train
413
+ ```
414
+ 5. Install the Google Cloud SDK, authenticate, and stage the cached AQuA splits (or regenerate them):
415
+ ```bash
416
+ curl -sSL https://sdk.cloud.google.com | bash
417
+ source "$HOME/.bashrc"
418
+ gcloud auth login --no-launch-browser
419
+ gcloud config set project <your-project-id>
420
+ gcloud storage cp gs://nanochat-aquarat-datasets/datasets/aqua/aqua_cache.zip .
421
+ unzip -o aqua_cache.zip -d ~/aqua_cache
422
+ export AQUA_DATA_DIR=$HOME/aqua_cache
423
+ ```
424
+ 6. Fetch the identity conversation bundle (required for SFT) and the evaluation bundle once so CORE metrics don’t fail:
425
+ ```bash
426
+ cd ~/.cache/nanochat
427
+ curl -L -o identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
428
+ curl -L -o eval_bundle.zip https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
429
+ unzip -q eval_bundle.zip && rm eval_bundle.zip
430
+ cd ~/nanochatAquaRat
431
+ ```
432
+ 7. Launch the desired script, e.g. `CUDA_VISIBLE_DEVICES=0 bash run_aquarat_lite.sh` or the full `run_aquarat_small.sh`.
433
+ 8. Monitor training via tmux/W&B and terminate the VM from Hyperbolic when the run finishes to stop billing.
434
+
435
+ ### Option 4: Alternative Launcher Script
436
+
437
+ A simplified launcher is also available:
438
+
439
+ ```bash
440
+ export LAMBDA_API_KEY='your-key'
441
+ export WANDB_API_KEY='your-key'
442
+
443
+ python launch_lambda.py \
444
+ --instance-type gpu_8x_h100_sxm5 \
445
+ --region us-west-1
446
+ ```
447
+
448
+ See [QUICKSTART.md](QUICKSTART.md) for details.
449
+
450
  ### Option 5: Local/Custom Setup
451
 
452
  ```bash
 
464
  - 40GB+ GPU memory per GPU
465
  - ~100GB disk space
466
 
467
+ ## Hugging Face Sync
 
 
468
 
469
+ Keep the GitHub docs mirrored with the Hugging Face model card:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
+ 1. Edit `README.md` (and any linked docs) as usual.
472
+ 2. Stage the release payload locally:
473
+ ```bash
474
+ uv run python -m scripts.sync_hf_repo --no-push
475
+ ```
476
+ This copies every README dependency into `hf_release/`. The script warns if a referenced file such as `LICENSE` is missing.
477
+ 3. Push the staged contents to Hugging Face once you are satisfied:
478
+ ```bash
479
+ uv run python -m scripts.sync_hf_repo --repo-id HarleyCooper/nanochatAquaRat
480
+ ```
481
+ The command requires prior `huggingface-cli login` (or an `HF_TOKEN` env var). Use `--dry-run` to review operations without copying or uploading.
 
482
 
483
  ---
484
 
485
+ ## File Structure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
 
 
 
 
 
 
487
  ```
488
+ nanochatAquaRat/
489
+ ├── nanochat/… # Vendored upstream nanochat package
490
+ ├── scripts/
491
+ │ ├── base_train.py # Base pretraining stage
492
+ │ ├── mid_train.py # Mid-training (now includes AQuA)
493
+ │ ├── chat_sft.py # Chat SFT pipeline
494
+ │ ├── sft_train.py # Shim so `-m scripts.sft_train` still works
495
+ │ ├── chat_rl.py # Reinforcement learning on AQuA-RAT
496
+ │ ├── chat_eval.py # Evaluation harness (adds AQuA task)
497
+ │ ├── prepare_aqua.py # AQuA-RAT JSONL exporter
498
+ │ ├── launch_lambda_training.py # Lambda Labs automation
499
+ │ ├── launch_hyperbolic_training.py # Hyperbolic Labs automation
500
+ │ └── upload_to_gcs.sh # Artifact helper
501
+ ├── tasks/
502
+ │ ├── aqua.py # AQuA-RAT task implementation
503
+ │ ├── arc.py / gsm8k.py / mmlu.py # Other reasoning tasks
504
+ │ └── … # additional task modules
505
+ ├── run_aquarat_small.sh # End-to-end orchestration
506
+ ├── pyproject.toml / uv.lock # Environment definitions
507
+ └── README.md
508
+ ```
509
+ ### Summary of Code Changes
510
+
511
+ | File | Type | Description |
512
+ |------|------|-------------|
513
+ | `tasks/aqua.py` | NEW | Conversation + evaluation wrapper for AQuA-RAT |
514
+ | `scripts/prepare_aqua.py` | NEW | Materializes train/validation/test JSONL splits for offline use |
515
+ | `scripts/mid_train.py` | MODIFIED | Adds AQuA to the mid-training mixture |
516
+ | `scripts/chat_sft.py` | MODIFIED | SFT mixture now includes AQuA controls |
517
+ | `scripts/sft_train.py` | NEW | Thin compatibility shim around `chat_sft` |
518
+ | `scripts/chat_rl.py` | MODIFIED | RL loop retargeted from GSM8K to AQuA-RAT |
519
+ | `scripts/chat_eval.py` | MODIFIED | Registers AQuA for categorical evaluation |
520
+ | `run_aquarat_small.sh` | MODIFIED | Pipeline glue aligned with AQuA staging |
521
+ | `scripts/launch_hyperbolic_training.py` | NEW | Hyperbolic Labs automation helper |
522
+ | `launch_lambda.py` / `scripts/launch_lambda_training.py` | EXISTING | Lambda Labs support retained |
523
+
524
+ ---
525
+
526
+ ## Monitoring & Visualization
527
+
528
+ All metrics stream to [Weights & Biases](https://wandb.ai) in real-time:
529
+
530
+ **Training Metrics**:
531
+ - Loss curves (pretraining, SFT, RL)
532
+ - Learning rate schedules
533
+ - Gradient norms
534
+
535
+ **RL Metrics**:
536
+ - Policy performance (accuracy, rewards)
537
+ - KL divergence from initial policy
538
+ - Letter-choice distributions (A-E)
539
+ - Confidence margins
540
+
541
+ **Interpretability**:
542
+ - Attention heatmaps per layer
543
+ - Entropy evolution across training
544
+ - Token-level attention weights
545
+
546
+ Example W&B dashboard:
547
+ ```
548
+ rl/acc ━━━━━━━━━━ 0.45
549
+ rl/kl_letter_mean ━━━━━━━━━━ 0.12
550
+ rl/letter_margin_mean ━━━━━━━━━━ 2.34
551
+ attn/entropy_mean ━━━━━━━━━━ 3.21
552
+ ```
553
+
554
+ ---
555
+
556
+ ## Results
557
+
558
+ ### Model Configurations
559
+
560
+ | Depth | Parameters | Training Time | Best Instance Type | Estimated Cost |
561
+ |-------|------------|---------------|-------------------|----------------|
562
+ | 8 | ~60M | 3-4 hours | 1-2x A100 | ~$18-35 |
563
+ | 12 | ~180M | 4-5 hours | 4x A100 | ~$35-45 |
564
+ | 20 | ~561M | 6-8 hours | 8x H100 | ~$144-192 |
565
+ | 26 | ~1.1B | 10-12 hours | 8x H100 | ~$240-288 |
566
+
567
+ To change model depth, edit the `--depth` parameter in `run_aquarat_small.sh`.
568
+
569
+ ### Expected Performance
570
+
571
+ **After SFT** (before RL):
572
+ - Dev accuracy: 20-30% (depth-8), 30-40% (depth-20)
573
+ - Basic problem-solving capability
574
+ - Some format errors (invalid letters)
575
+
576
+ **After RL**:
577
+ - Dev accuracy: 30-50% (depth-8), 40-60% (depth-20)
578
+ - Improved reasoning coherence
579
+ - Better multiple-choice selection confidence
580
+ - Reduced format errors
581
+ - Stable attention patterns
582
+
583
+ ### Cost Management
584
+
585
+ Lambda Labs pricing (8x H100 SXM5 @ ~$24/hour):
586
+
587
+ | Model | Training Time | Total Cost |
588
+ |-------|---------------|------------|
589
+ | depth-8 (60M) | 3-4 hours | ~$96 |
590
+ | depth-20 (561M) | 6-8 hours | ~$192 |
591
+
592
+ Budget options:
593
+ - Test pipeline: 1x A10 @ $0.60/hr
594
+ - Small model: 2x A100 @ $4.40/hr
595
+ - Production: 8x H100 @ $24/hr
596
+
597
+ ---
598
+
599
+ ## Important Notes
600
+
601
+ ### For Lambda Labs Users
602
+ - **Always terminate instances** after training to avoid charges
603
+ - Monitor spending in the Lambda Labs dashboard
604
+ - Check instance availability before launching (high demand periods)
605
+
606
+ ### Known Limitations
607
+ - RL on AQuA-RAT is experimental; results may vary
608
+ - Attention logging adds ~5-10% overhead
609
+ - KL computation can be expensive with large batch sizes
610
+ - Smaller models (<100M params) may struggle with complex reasoning
611
+
612
+ ---
613
+
614
+ ## Documentation
615
+
616
+ - **[scripts/launch_lambda_training.py](scripts/launch_lambda_training.py)** - Full-featured automation
617
+ - **[scripts/launch_hyperbolic_training.py](scripts/launch_hyperbolic_training.py)** - Hyperbolic marketplace automation
618
+ - **[launch_lambda.py](launch_lambda.py)** - Simplified launcher
619
+ - **[QUICKSTART.md](QUICKSTART.md)** - Fast track guide
620
+ - **[LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md)** - Manual setup walkthrough
621
+ - **[GCS_UPLOAD_GUIDE.md](GCS_UPLOAD_GUIDE.md)** - Upload weights to Google Cloud Storage
622
+ - **[.env.template](.env.template)** - Environment configuration
623
+
624
+ ---
625
+
626
+ ## Contributing
627
+
628
+ This project is based on the nanochat framework. For issues specific to:
629
+ - **AQuA-RAT training**: Open an issue in this repository
630
+ - **Base nanochat framework**: Refer to the upstream nanochat project
631
+ - **Lambda Labs deployment**: See documentation above
632
+
633
+ ---
634
+
635
+ ## License
636
+
637
+ This project inherits the license from the base nanochat project.
638
+
639
+ ---
640
+
641
+ ## Acknowledgments
642
+
643
+ - **Andrej Karpathy** - nanochat framework
644
+ - **DeepMind** - AQuA-RAT dataset and mechanistic interpretability tools
645
+ - **Lambda Labs** - Cloud GPU infrastructure
646
+ - **Weights & Biases** - Experiment tracking and visualization
647
+
648
+ ---
649
+
650
+ ## Support
651
+
652
+ - **Lambda Labs Support**: https://lambdalabs.com/support
653
+ - **Weights & Biases Docs**: https://docs.wandb.ai
654
+ - **Project Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
aquarat2.png ADDED

Git LFS Details

  • SHA256: 21ba080c26864a6e08c59fe26a7add4d08989850765b6f420914fa590cd3d49f
  • Pointer size: 131 Bytes
  • Size of remote file: 559 kB
launch_lambda.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Lambda Labs GPU Instance Launcher for AQuA-RAT Training
4
+
5
+ This script automates launching an 8x H100 GPU instance on Lambda Labs
6
+ and deploying the nanochatAquaRat training pipeline.
7
+
8
+ Prerequisites:
9
+ 1. Lambda Labs API key (set as LAMBDA_API_KEY environment variable)
10
+ 2. Your SSH public key added to Lambda Labs account
11
+ 3. W&B API key for logging (set as WANDB_API_KEY environment variable)
12
+
13
+ Usage:
14
+ python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --region us-west-1
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import time
20
+ import argparse
21
+ import subprocess
22
+ from pathlib import Path
23
+
24
+ try:
25
+ import lambda_cloud_client
26
+ from lambda_cloud_client.rest import ApiException
27
+ except ImportError:
28
+ print("Installing lambda-cloud-client...")
29
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "lambda-cloud-client"])
30
+ import lambda_cloud_client
31
+ from lambda_cloud_client.rest import ApiException
32
+
33
+
34
def check_env_vars():
    """Abort with a helpful message unless all required credential env vars are set."""
    required_vars = {
        'LAMBDA_API_KEY': 'Lambda Labs API key',
        'WANDB_API_KEY': 'Weights & Biases API key'
    }

    # Collect the human-readable description of every unset variable.
    missing = [
        f" - {var} ({description})"
        for var, description in required_vars.items()
        if not os.getenv(var)
    ]

    if missing:
        print("ERROR: Missing required environment variables:")
        print("\n".join(missing))
        print("\nSet them with:")
        print(" export LAMBDA_API_KEY='your-lambda-api-key'")
        print(" export WANDB_API_KEY='your-wandb-api-key'")
        sys.exit(1)
53
+
54
+
55
def get_api_client():
    """Build a Lambda Cloud API client authenticated via the LAMBDA_API_KEY env var."""
    configuration = lambda_cloud_client.Configuration(
        host="https://cloud.lambdalabs.com/api/v1",
        access_token=os.getenv('LAMBDA_API_KEY'),
    )
    return lambda_cloud_client.ApiClient(configuration)
63
+
64
+
65
def list_available_instance_types(api_client):
    """Print every instance type that currently has capacity, then return the raw data."""
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    # Only the API call itself can raise ApiException; keep the try body minimal.
    try:
        response = api_instance.instance_types()
    except ApiException as e:
        print(f"Error fetching instance types: {e}")
        sys.exit(1)

    print("\nAvailable Instance Types:")
    print("-" * 80)

    for type_name, details in response.data.items():
        regions = details.instance_type.regions_with_capacity_available
        if not regions:
            continue  # skip types with no capacity anywhere
        specs = details.instance_type.specs
        print(f"\n{type_name}:")
        print(f" GPUs: {specs.gpus}")
        print(f" GPU Memory: {specs.memory_gbs} GB")
        print(f" Price: ${specs.price_cents_per_hour / 100}/hour")
        print(f" Available regions: {', '.join(regions)}")

    return response.data
86
+
87
+
88
def launch_instance(api_client, instance_type, region, name="nanochat-aquarat-training"):
    """Boot one instance of the requested type/region using every SSH key on the account.

    Returns the new instance id, or exits the process on any API failure.
    """
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    # An instance launched without SSH keys would be unreachable, so fetch them first.
    try:
        ssh_keys_response = api_instance.list_ssh_keys()
    except ApiException as e:
        print(f"Error fetching SSH keys: {e}")
        sys.exit(1)

    if not ssh_keys_response.data:
        print("ERROR: No SSH keys found in your Lambda Labs account.")
        print("Please add an SSH key at: https://cloud.lambdalabs.com/ssh-keys")
        sys.exit(1)

    ssh_key_names = [key.name for key in ssh_keys_response.data]
    print(f"Using SSH keys: {', '.join(ssh_key_names)}")

    launch_request = lambda_cloud_client.LaunchInstanceRequest(
        region_name=region,
        instance_type_name=instance_type,
        ssh_key_names=ssh_key_names,
        name=name,
        quantity=1,
    )

    print(f"\nLaunching {instance_type} instance in {region}...")

    try:
        response = api_instance.launch_instance(launch_request)
    except ApiException as e:
        print(f"Error launching instance: {e}")
        sys.exit(1)

    if response.data and response.data.instance_ids:
        instance_id = response.data.instance_ids[0]
        print(f"✓ Instance launched successfully!")
        print(f" Instance ID: {instance_id}")
        return instance_id

    print("ERROR: Instance launch failed")
    sys.exit(1)
132
+
133
+
134
def wait_for_instance(api_client, instance_id, timeout=300):
    """Poll the instance every 10s until it reports "active" or *timeout* seconds elapse.

    Returns the instance object once active; exits the process on timeout.
    """
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    print("\nWaiting for instance to be ready...")
    deadline = time.time() + timeout

    while time.time() < deadline:
        try:
            instance = api_instance.get_instance(instance_id).data
        except ApiException as e:
            # Transient API errors are retried until the deadline.
            print(f"Error checking instance status: {e}")
            time.sleep(10)
            continue

        if instance.status == "active":
            print(f"✓ Instance is ready!")
            print(f" IP Address: {instance.ip}")
            print(f" SSH Command: ssh ubuntu@{instance.ip}")
            return instance

        print(f" Status: {instance.status}... waiting")
        time.sleep(10)

    print("ERROR: Timeout waiting for instance to be ready")
    sys.exit(1)
161
+
162
+
163
def generate_startup_script():
    """Generate the bash bootstrap script executed on the freshly-launched instance.

    BUG FIX: the original script wrote `.env` into
    `/home/ubuntu/nanochatAquaRat/` *before* cloning the repository. On a
    fresh instance that directory does not exist yet, so the `cat >` fails
    and `set -euo pipefail` aborts the whole bootstrap. The clone now runs
    first and the `.env` write second.

    Returns:
        str: the full bash script text, with the local WANDB_API_KEY inlined.
    """
    wandb_key = os.getenv('WANDB_API_KEY')

    script = f"""#!/bin/bash
set -euo pipefail

# Clone repository if not exists (must happen before writing .env into it)
cd /home/ubuntu
if [ ! -d "nanochatAquaRat" ]; then
    git clone https://github.com/HarleyCoops/nanochatAquaRat.git
fi

# Create .env file with credentials.
# NOTE: the quoted 'EOF' delimiter suppresses remote expansion, so the
# WANDB_ENTITY line is written literally; wandb_key is inlined locally.
cat > /home/ubuntu/nanochatAquaRat/.env << 'EOF'
WANDB_API_KEY={wandb_key}
WANDB_PROJECT=nanochat-aquarat
WANDB_ENTITY=${{WANDB_ENTITY:-}}
EOF

cd nanochatAquaRat

# Make script executable
chmod +x run_aquarat_small.sh

# Run training in screen session
screen -dmS training bash -c './run_aquarat_small.sh 2>&1 | tee training.log'

echo "Training started in screen session 'training'"
echo "To attach: screen -r training"
echo "To detach: Ctrl+A then D"
echo "To view log: tail -f training.log"
"""

    return script
198
+
199
+
200
def deploy_and_run(instance_ip):
    """Copy the bootstrap script to the instance over SSH and kick off training there."""
    print("\nDeploying code and starting training...")

    # Stage the generated bootstrap script in a local temp file.
    script_path = Path("/tmp/lambda_startup.sh")
    script_path.write_text(generate_startup_script())

    print(" Copying startup script...")
    scp_cmd = [
        "scp", "-o", "StrictHostKeyChecking=no",
        str(script_path),
        f"ubuntu@{instance_ip}:/tmp/startup.sh",
    ]
    subprocess.run(scp_cmd, check=True)

    print(" Starting training...")
    ssh_cmd = [
        "ssh", "-o", "StrictHostKeyChecking=no",
        f"ubuntu@{instance_ip}",
        "bash /tmp/startup.sh",
    ]
    subprocess.run(ssh_cmd, check=True)

    banner = "=" * 80
    print("\n" + banner)
    print("✓ Training deployment complete!")
    print(banner)
    print("\nTo monitor your training:")
    print(f" 1. SSH: ssh ubuntu@{instance_ip}")
    print(f" 2. Attach to screen: screen -r training")
    print(f" 3. View log: tail -f ~/nanochatAquaRat/training.log")
    print(f" 4. W&B Dashboard: https://wandb.ai")
    print("\nTo detach from screen: Ctrl+A then D")
    print("\nRemember to terminate the instance when done to avoid charges!")
236
+
237
+
238
def main():
    """CLI entry point: parse args, launch the instance, and optionally deploy training."""
    parser = argparse.ArgumentParser(description="Launch Lambda Labs instance for AQuA-RAT training")
    parser.add_argument("--instance-type", default="gpu_8x_h100_sxm5",
                        help="Instance type (default: gpu_8x_h100_sxm5)")
    parser.add_argument("--region", default="us-west-1",
                        help="Region to launch in (default: us-west-1)")
    parser.add_argument("--name", default="nanochat-aquarat-training",
                        help="Instance name (default: nanochat-aquarat-training)")
    parser.add_argument("--list-types", action="store_true",
                        help="List available instance types and exit")
    parser.add_argument("--no-deploy", action="store_true",
                        help="Launch instance but don't deploy code")

    args = parser.parse_args()

    print("=" * 80)
    print("Lambda Labs GPU Instance Launcher for AQuA-RAT Training")
    print("=" * 80)

    # Fail fast if credentials are missing before touching the API.
    check_env_vars()

    api_client = get_api_client()

    if args.list_types:
        list_available_instance_types(api_client)
        return

    # Launch, wait for it to boot, then (optionally) deploy the training run.
    instance_id = launch_instance(api_client, args.instance_type, args.region, args.name)
    instance = wait_for_instance(api_client, instance_id)

    if not args.no_deploy:
        time.sleep(5)  # Give SSH a moment to be fully ready
        try:
            deploy_and_run(instance.ip)
        except subprocess.CalledProcessError as e:
            # Deployment failure is non-fatal: the instance is up, so tell the
            # user how to finish manually instead of aborting.
            print(f"\nWarning: Deployment encountered an error: {e}")
            print(f"You can manually SSH to the instance and run the training:")
            print(f" ssh ubuntu@{instance.ip}")
            print(f" cd nanochatAquaRat && bash run_aquarat_small.sh")

    print("\n" + "=" * 80)
    print("Instance Information")
    print("=" * 80)
    print(f"Instance ID: {instance_id}")
    print(f"IP Address: {instance.ip}")
    print(f"Status: {instance.status}")
    # BUG FIX: the parser defines no --terminate flag, so the previous hint
    # ("python launch_lambda.py --terminate <id>") pointed users at a
    # non-existent option. Direct them to the dashboard instead.
    print("\nTo terminate this instance:")
    print(" Visit https://cloud.lambdalabs.com/instances and terminate it there.")
+
294
+
295
+ if __name__ == "__main__":
296
+ main()
scripts/launch_hyperbolic_training.py ADDED
@@ -0,0 +1,701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Automation helper to launch a Hyperbolic Labs marketplace instance and kick off the
4
+ nanochat AQuA-RAT training run.
5
+
6
+ The workflow mirrors `launch_lambda_training.py`, but uses Hyperbolic's REST API.
7
+
8
+ Example:
9
+
10
+ python scripts/launch_hyperbolic_training.py \\
11
+ --gpu-count 1 \\
12
+ --region us-east \\
13
+ --max-price 4.5 \\
14
+ --auto-start \\
15
+ --inject-env WANDB_API_KEY
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import os
23
+ import shlex
24
+ import subprocess
25
+ import sys
26
+ import tempfile
27
+ import textwrap
28
+ import time
29
+ from pathlib import Path
30
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
31
+
32
+ import requests
33
+
34
+ API_BASE = "https://api.hyperbolic.xyz"
35
+ MARKETPLACE_BASE = f"{API_BASE}/v1/marketplace"
36
+ READY_STATUSES = {
37
+ "ready",
38
+ "running",
39
+ "instance_running",
40
+ "node_ready",
41
+ "active",
42
+ "online",
43
+ }
44
+
45
+
46
def log(msg: str) -> None:
    """Emit an informational message to stdout with the [info] prefix."""
    print("[info] " + msg)
48
+
49
+
50
def warn(msg: str) -> None:
    """Emit a warning to stderr with the [warn] prefix."""
    prefixed = f"[warn] {msg}"
    print(prefixed, file=sys.stderr)
52
+
53
+
54
def error(msg: str) -> None:
    """Emit an error message to stderr with the [error] prefix."""
    prefixed = f"[error] {msg}"
    print(prefixed, file=sys.stderr)
56
+
57
+
58
def shell_quote(value: str) -> str:
    """Return *value* quoted for safe interpolation into a bash command line."""
    quoted = shlex.quote(value)
    return quoted
60
+
61
+
62
def collect_env_pairs(cli_pairs: Sequence[str], inject_names: Sequence[str]) -> List[Tuple[str, str]]:
    """Merge KEY=VALUE pairs with env vars pulled from the local environment.

    CLI-supplied pairs are applied first, then injected names (which may
    override them). Raises ValueError on malformed pairs, empty keys, empty
    inject names, or inject names absent from the local environment.
    """
    merged: Dict[str, str] = {}

    for entry in cli_pairs:
        key, sep, value = entry.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE entries, got '{entry}'")
        key = key.strip()
        if not key:
            raise ValueError(f"Environment key is empty in '{entry}'")
        merged[key] = value

    for name in inject_names:
        if not name:
            raise ValueError("Encountered empty --inject-env name")
        if name not in os.environ:
            raise ValueError(f"--inject-env requested '{name}' but it is not set locally")
        merged[name] = os.environ[name]

    return list(merged.items())
83
+
84
+
85
def build_bootstrap_script(
    repo_dir: str,
    run_script: str,
    branch: str,
    repo_url: Optional[str],
    env_file_remote: str,
    auto_start: bool,
    tmux_session: str,
) -> str:
    """Compose the bash script executed on the instance to prepare the training run.

    Args:
        repo_dir: Directory name under $HOME where the repo lives on the instance.
        run_script: Path (relative to the repo root) of the training launch script.
        branch: Git branch to switch to after cloning/fetching.
        repo_url: Remote URL to clone; when None the repo dir is only created.
        env_file_remote: Remote path of a pre-uploaded env file copied to `.env`.
        auto_start: When True, append commands that launch the run script.
        tmux_session: tmux session name used for the detached training run.

    Returns:
        The full bash script text, newline-terminated.
    """
    # `set -euxo pipefail` makes the remote bootstrap fail fast and echo commands.
    lines: List[str] = [
        "#!/usr/bin/env bash",
        "set -euxo pipefail",
        f'REPO_DIR="$HOME/{repo_dir}"',
        f'RUN_SCRIPT="{run_script}"',
        f'ENV_FILE="{env_file_remote}"',
        f'AUTO_START="{1 if auto_start else 0}"',
    ]

    if repo_url:
        # Clone only when no checkout exists yet, then fast-forward to the branch.
        lines.append(f"REPO_URL={shell_quote(repo_url)}")
        lines.extend(
            [
                'if [ ! -d "$REPO_DIR/.git" ]; then',
                ' rm -rf "$REPO_DIR"',
                ' git clone "$REPO_URL" "$REPO_DIR"',
                "fi",
                'cd "$REPO_DIR"',
                "git fetch --all --prune",
                f"git switch {shell_quote(branch)}",
                # `|| true`: a diverged branch should not abort the whole bootstrap.
                "git pull --ff-only || true",
            ]
        )
    else:
        # No URL known: assume the repo already exists or will be synced manually.
        lines.extend(
            [
                'mkdir -p "$REPO_DIR"',
                'cd "$REPO_DIR"',
            ]
        )

    # Materialize the uploaded env file as `.env`, make the run script
    # executable, and disable auto-start when the script is missing.
    lines.extend(
        [
            'if [ -f "$ENV_FILE" ]; then',
            ' cp "$ENV_FILE" .env',
            "fi",
            'if [ -f "$RUN_SCRIPT" ]; then',
            ' chmod +x "$RUN_SCRIPT"',
            "else",
            ' echo "Run script $RUN_SCRIPT not found; auto-start will be skipped." >&2',
            ' AUTO_START="0"',
            "fi",
        ]
    )

    if auto_start:
        # Prefer tmux (inspectable, reattachable session); fall back to nohup
        # with output redirected to a log file in $HOME.
        tmux_line = (
            f'tmux new -d -s {shell_quote(tmux_session)} '
            '"cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\""'
        )
        nohup_line = (
            'nohup bash -lc "cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\"" '
            '> "$HOME/nanochat-train.log" 2>&1 &'
        )
        lines.extend(
            [
                'if [ "$AUTO_START" = "1" ]; then',
                " if command -v tmux >/dev/null 2>&1; then",
                f" {tmux_line}",
                " else",
                f" {nohup_line}",
                " fi",
                "fi",
            ]
        )

    return "\n".join(lines) + "\n"
162
+
163
+
164
class HyperbolicClient:
    """Thin wrapper around the Hyperbolic Labs marketplace REST API."""

    def __init__(self, api_key: Optional[str]):
        if not api_key:
            raise ValueError("Hyperbolic API key is required. Pass --api-key or set HYPERBOLIC_API_KEY.")
        self.api_key = api_key

    def _headers(self, with_auth: bool = True) -> Dict[str, str]:
        """Build JSON request headers, optionally including the bearer token."""
        headers = {"Content-Type": "application/json"}
        if with_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def list_marketplace(self) -> List[Dict[str, Any]]:
        """Return available marketplace nodes (endpoint requires no auth).

        The response schema varies; this tolerates a bare list or a dict
        keyed by "instances", "nodes", or "data".
        """
        response = requests.post(
            MARKETPLACE_BASE,
            headers=self._headers(with_auth=False),
            json={"filters": {}},
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        # Bug fix: the list check must come before any dict access — the old
        # code called payload.get(...) first, which raised AttributeError
        # whenever the API returned a bare list.
        if isinstance(payload, list):
            return payload
        if isinstance(payload, dict):
            return payload.get("instances") or payload.get("nodes") or payload.get("data") or []
        return []

    def create_instance(
        self,
        cluster_name: str,
        node_name: str,
        gpu_count: int,
        image: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Request a new instance on the given cluster/node.

        Raises:
            requests.HTTPError: on any 4xx/5xx response.
        """
        payload: Dict[str, Any] = {
            "cluster_name": cluster_name,
            "node_name": node_name,
            "gpu_count": gpu_count,
        }
        if image:
            payload["image"] = image
        response = requests.post(
            f"{MARKETPLACE_BASE}/instances/create",
            headers=self._headers(),
            json=payload,
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def list_instances(self) -> List[Dict[str, Any]]:
        """Return the account's current instances (tolerant of schema shape)."""
        response = requests.get(
            f"{MARKETPLACE_BASE}/instances",
            headers=self._headers(),
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, dict):
            return payload.get("instances") or payload.get("data") or []
        if isinstance(payload, list):
            return payload
        return []

    def terminate_instance(self, instance_id: str) -> Dict[str, Any]:
        """Terminate the given instance and return the API response body."""
        response = requests.post(
            f"{MARKETPLACE_BASE}/instances/terminate",
            headers=self._headers(),
            json={"id": instance_id},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def get_balance(self) -> Optional[float]:
        """Fetch the current credit balance; return None when unavailable.

        Best-effort by design: any API/network/parsing failure is reported
        as a warning rather than an exception.
        """
        try:
            response = requests.get(
                f"{API_BASE}/billing/get_current_balance",
                headers=self._headers(),
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()
            if isinstance(data, dict):
                # Explicit None checks (instead of `or`-chaining) so a
                # legitimate zero balance is still returned as 0.0.
                for key in ("balance", "amount", "credits"):
                    value = data.get(key)
                    if value is not None:
                        return float(value)
        # Bug fix: requests.RequestException covers connection errors and
        # timeouts too; the old except clause only caught HTTPError, so a
        # network hiccup here crashed the whole launcher.
        except (requests.RequestException, ValueError, TypeError):
            pass
        warn("Unable to fetch current account balance.")
        return None
256
+
257
+
258
def summarize_node(node: Dict[str, Any]) -> str:
    """Render a one-line, human-readable summary of a marketplace node."""
    # Describe each GPU as "MODEL (RAM GB)" when the RAM field is present.
    hardware = node.get("hardware") or {}
    gpu_descriptions: List[str] = []
    for gpu in hardware.get("gpus") or []:
        model = gpu.get("model")
        ram = gpu.get("ram") or gpu.get("memory") or gpu.get("vram")
        if model and ram:
            gpu_descriptions.append(f"{model} ({ram} GB)")
        elif model:
            gpu_descriptions.append(model)
    gpu_text = ", ".join(gpu_descriptions) if gpu_descriptions else "Unknown GPUs"

    amount = ((node.get("pricing") or {}).get("price") or {}).get("amount")
    price_str = f"${amount:.2f}/hr" if isinstance(amount, (int, float)) else "n/a"
    region = ((node.get("location") or {}).get("region")) or "unknown region"
    cluster = node.get("cluster_name") or "unknown cluster"
    free = (node.get("gpus_total") or 0) - (node.get("gpus_reserved") or 0)
    supplier = node.get("supplier_id") or "unknown supplier"
    return (
        f"{node.get('id', '<unknown>')} | {cluster} | {region} | "
        f"{free}/{node.get('gpus_total', '?')} GPUs free | "
        f"{gpu_text} | {price_str} | supplier: {supplier}"
    )
283
+
284
+
285
def filter_nodes(
    nodes: Iterable[Dict[str, Any]],
    gpu_count: int,
    region: Optional[str],
    supplier: Optional[str],
    max_price: Optional[float],
) -> List[Dict[str, Any]]:
    """Filter marketplace nodes by availability, region, supplier, and price.

    Args:
        nodes: Raw node records from the marketplace listing.
        gpu_count: Minimum number of free GPUs required.
        region: Case-insensitive substring matched against the node region.
        supplier: Case-insensitive substring matched against the supplier id.
        max_price: Hourly price cap (nodes with non-numeric prices pass).

    Returns:
        Matching nodes sorted by ascending hourly price; nodes without a
        numeric price sort last.
    """
    wanted_region = region.lower() if region else None
    wanted_supplier = supplier.lower() if supplier else None
    matches: List[Dict[str, Any]] = []

    for node in nodes:
        total = node.get("gpus_total") or 0
        reserved = node.get("gpus_reserved") or 0
        if total - reserved < gpu_count:
            continue

        if wanted_region:
            region_value = ((node.get("location") or {}).get("region") or "").lower()
            if wanted_region not in region_value:
                continue

        if wanted_supplier:
            supplier_value = (node.get("supplier_id") or "").lower()
            if wanted_supplier not in supplier_value:
                continue

        price = ((node.get("pricing") or {}).get("price") or {}).get("amount")
        if max_price is not None and isinstance(price, (int, float)) and price > max_price:
            continue

        matches.append(node)

    def _price_key(n: Dict[str, Any]) -> float:
        amount = ((n.get("pricing") or {}).get("price") or {}).get("amount")
        # Bug fix: `.get("amount", inf)` returned None when the key was
        # present with a null value, crashing the sort with a TypeError.
        # Treat any non-numeric amount as "most expensive" instead.
        return amount if isinstance(amount, (int, float)) else float("inf")

    matches.sort(key=_price_key)
    return matches
324
+
325
+
326
def extract_instance_id(payload: Dict[str, Any], before_ids: Sequence[str], client: HyperbolicClient) -> str:
    """Pull the new instance id out of the create-response payload.

    Falls back to diffing the account's instance list against `before_ids`
    when the payload does not carry an id directly.
    """
    id_keys = ("id", "instance_id", "instanceId")
    found: List[str] = [
        payload[k] for k in id_keys if isinstance(payload.get(k), str) and payload[k]
    ]
    nested = payload.get("instance") or payload.get("data")
    if isinstance(nested, dict):
        found.extend(
            nested[k] for k in id_keys if isinstance(nested.get(k), str) and nested[k]
        )
    if found:
        return found[0]

    # Fall back to diffing current instances.
    time.sleep(3)
    fresh_ids = {str(item.get("id")) for item in client.list_instances() if item.get("id")}
    new_ids = fresh_ids.difference(before_ids)
    if new_ids:
        return new_ids.pop()
    raise RuntimeError("Unable to determine instance ID from API response.")
349
+
350
+
351
def extract_ip(instance: Dict[str, Any]) -> Optional[str]:
    """Return the first plausible public IP found in the instance record."""
    network = instance.get("network") or {}
    possible: List[Any] = [
        instance.get("public_ip"),
        instance.get("ip_address"),
        instance.get("ip"),
        instance.get("ipv4"),
        network.get("public_ip"),
        network.get("ip"),
        network.get("ipv4"),
    ]
    # Some responses carry a list of address records instead.
    for entry in instance.get("ip_addresses") or []:
        possible += [
            entry.get("public_ip"),
            entry.get("ip"),
            entry.get("ipv4"),
            entry.get("address"),
        ]
    return next((c for c in possible if isinstance(c, str) and c), None)
377
+
378
+
379
def extract_ssh_port(instance: Dict[str, Any]) -> int:
    """Return the SSH port advertised by the instance, defaulting to 22."""
    network = instance.get("network") or {}
    for value in (
        instance.get("ssh_port"),
        network.get("ssh_port"),
        (instance.get("ssh") or {}).get("port"),
    ):
        if isinstance(value, int):
            return value
        # Some responses report the port as a numeric string.
        if isinstance(value, str) and value.isdigit():
            return int(value)
    return 22
393
+
394
+
395
def extract_status(instance: Dict[str, Any]) -> str:
    """Return the first status-like string field, or '' when none is present."""
    return next(
        (
            instance[key]
            for key in ("status", "instance_status", "state")
            if isinstance(instance.get(key), str)
        ),
        "",
    )
401
+
402
+
403
def wait_for_instance(
    client: HyperbolicClient,
    instance_id: str,
    poll_seconds: int,
    max_wait_minutes: int,
) -> Dict[str, Any]:
    """Poll the instance list until `instance_id` is ready with a public IP.

    Args:
        client: API client used to list instances.
        instance_id: Id returned by the create call.
        poll_seconds: Delay between listing attempts.
        max_wait_minutes: Overall deadline before giving up.

    Returns:
        The ready instance record.

    Raises:
        TimeoutError: when the deadline elapses without a ready instance.
    """
    log(f"Waiting for instance {instance_id} to become ready...")
    deadline = time.time() + max_wait_minutes * 60
    while time.time() < deadline:
        # Bug fix: `status` is now reset each poll so the progress log can
        # never reference a stale (or unbound) value when the target id is
        # missing from the listing.
        status = ""
        for instance in client.list_instances():
            identifiers = {
                str(instance.get("id")),
                str(instance.get("instance_id")),
                str(instance.get("instanceId")),
            }
            if instance_id not in identifiers:
                continue

            status = extract_status(instance).lower()
            ip = extract_ip(instance)
            if status in READY_STATUSES and ip:
                log(f"Instance is ready: status={status}, ip={ip}")
                return instance
            # Found the target but it is not ready yet; stop scanning.
            break

        log(f" status={status or '<unknown>'}; waiting for ready state...")
        time.sleep(poll_seconds)

    raise TimeoutError(f"Timed out waiting for instance {instance_id} to become ready.")
432
+
433
+
434
def build_env_content(pairs: Sequence[Tuple[str, str]]) -> str:
    """Serialize KEY=VALUE pairs into dotenv-style text (empty when no pairs)."""
    if not pairs:
        return ""
    return "".join(f"{key}={value}\n" for key, value in pairs)
436
+
437
+
438
def scp(
    local_path: Path,
    remote_path: str,
    ssh_user: str,
    host: str,
    port: int,
    ssh_key: Optional[str],
) -> None:
    """Copy a local file to the remote host via scp.

    Host-key checking is disabled because marketplace instances are
    freshly provisioned and not in known_hosts.
    """
    command = ["scp", "-o", "StrictHostKeyChecking=no"]
    if ssh_key:
        command += ["-i", ssh_key]
    if port != 22:
        command += ["-P", str(port)]
    command += [str(local_path), f"{ssh_user}@{host}:{remote_path}"]
    subprocess.run(command, check=True)
453
+
454
+
455
def ssh_command(
    ssh_user: str,
    host: str,
    port: int,
    ssh_key: Optional[str],
    *command: str,
) -> None:
    """Run a command on the remote host over ssh (interactive when none given)."""
    argv = ["ssh", "-o", "StrictHostKeyChecking=no"]
    if ssh_key:
        argv += ["-i", ssh_key]
    if port != 22:
        argv += ["-p", str(port)]
    argv.append(f"{ssh_user}@{host}")
    if command:
        # The pieces are joined into a single remote command string; callers
        # are responsible for quoting each piece appropriately.
        argv.append(" ".join(command))
    subprocess.run(argv, check=True)
471
+
472
+
473
def deploy_to_instance(
    instance: Dict[str, Any],
    bootstrap_script: str,
    env_pairs: Sequence[Tuple[str, str]],
    env_file_remote: str,
    ssh_user: str,
    ssh_key: Optional[str],
) -> None:
    """Upload the bootstrap script (and optional env file) and execute it remotely.

    Raises:
        RuntimeError: when the instance has no public IP yet.
        subprocess.CalledProcessError: when any scp/ssh step fails.
    """
    host = extract_ip(instance)
    if not host:
        raise RuntimeError("Instance does not report a public IP address yet.")
    port = extract_ssh_port(instance)
    log(f"Deploying bootstrap assets to {ssh_user}@{host}:{port} ...")

    with tempfile.TemporaryDirectory() as tmpdir:
        staging = Path(tmpdir)

        # Stage and upload the bootstrap script.
        script_file = staging / "bootstrap.sh"
        script_file.write_text(bootstrap_script)
        scp(script_file, "/tmp/nanochat_bootstrap.sh", ssh_user, host, port, ssh_key)

        # Stage and upload the env file, creating its remote directory first.
        if env_pairs:
            env_file = staging / "nanochat.env"
            env_file.write_text(build_env_content(env_pairs))
            ssh_command(ssh_user, host, port, ssh_key, f"mkdir -p {shell_quote(str(Path(env_file_remote).parent))}")
            scp(env_file, env_file_remote, ssh_user, host, port, ssh_key)

        log("Executing remote bootstrap script...")
        ssh_command(ssh_user, host, port, ssh_key, "bash /tmp/nanochat_bootstrap.sh")
501
+
502
+
503
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Explicit argument list; defaults to sys.argv[1:]. Added for
            testability and consistency with launch_lambda_training.py.

    Returns:
        The populated argparse namespace.
    """
    parser = argparse.ArgumentParser(description="Launch Hyperbolic Labs instance for AQuA-RAT training")
    # Node selection / marketplace filtering.
    parser.add_argument("--api-key", default=os.environ.get("HYPERBOLIC_API_KEY"), help="Hyperbolic API key")
    parser.add_argument("--gpu-count", type=int, default=1, help="Number of GPUs to request (default: 1)")
    parser.add_argument("--region", help="Preferred region substring (case-insensitive)")
    parser.add_argument("--supplier", help="Preferred supplier substring (case-insensitive)")
    parser.add_argument("--max-price", type=float, help="Maximum hourly price in USD")
    parser.add_argument("--node-name", help="Specify node name explicitly")
    parser.add_argument("--cluster-name", help="Cluster name when using --node-name")
    parser.add_argument("--list", action="store_true", help="List available marketplace nodes and exit")

    # Repository / run configuration deployed to the instance.
    parser.add_argument("--repo-url", help="Repository URL to clone (defaults to git remote origin)")
    parser.add_argument("--branch", default="main", help="Branch to checkout on the instance (default: main)")
    parser.add_argument("--run-script", default="run_aquarat_small.sh", help="Script to execute on the instance")
    parser.add_argument("--repo-dir", default="nanochatAquaRat", help="Directory name for the repo on the instance")

    # Deployment behavior.
    parser.add_argument("--auto-start", action="store_true", help="Automatically run the training script")
    parser.add_argument("--tmux-session", default="training", help="tmux session name when auto-start is enabled")
    parser.add_argument("--ssh-user", default="ubuntu", help="SSH username for the instance (default: ubuntu)")
    parser.add_argument("--ssh-key", help="Path to SSH private key for scp/ssh")
    parser.add_argument("--no-deploy", action="store_true", help="Skip deployment after the instance is ready")

    # Environment variables forwarded to the remote .env file.
    parser.add_argument("--env", action="append", default=[], help="Environment variable in KEY=VALUE form")
    parser.add_argument("--inject-env", action="append", default=[], help="Environment variable name to copy from local env")

    # Readiness polling.
    parser.add_argument("--poll-seconds", type=int, default=20, help="Polling interval while waiting (default: 20)")
    parser.add_argument("--max-wait-minutes", type=int, default=25, help="Maximum minutes to wait for ready state")

    # Passing argv=None preserves the original behavior (reads sys.argv).
    return parser.parse_args(argv)
532
+
533
+
534
def guess_repo_url() -> Optional[str]:
    """Return the `origin` remote URL of the current checkout, if any."""
    try:
        result = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not a git checkout, or git itself is unavailable.
        return None
    remote = result.stdout.strip()
    return remote if remote else None
546
+
547
+
548
def main() -> int:
    """CLI entry point: pick a marketplace node, launch it, wait for readiness,
    and deploy/launch the training run.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    args = parse_args()

    # Validate env-var plumbing before touching the API.
    try:
        env_pairs = collect_env_pairs(args.env, args.inject_env)
    except ValueError as exc:
        error(str(exc))
        return 1

    if not args.api_key:
        error("Hyperbolic API key not provided. Use --api-key or set HYPERBOLIC_API_KEY.")
        return 1

    client = HyperbolicClient(args.api_key)

    try:
        nodes = client.list_marketplace()
    except requests.HTTPError as exc:
        error(f"Failed to list marketplace nodes: {exc}")
        return 1

    # --list short-circuits: print the catalog and exit successfully.
    if args.list:
        log("Available marketplace nodes:")
        for node in nodes:
            print(summarize_node(node))
        return 0

    selected_node: Optional[Dict[str, Any]] = None

    if args.node_name:
        # Explicit node: look it up by id and inherit its cluster when not given.
        for node in nodes:
            if node.get("id") == args.node_name:
                selected_node = node
                break
        if not selected_node:
            error(f"Node '{args.node_name}' not found in marketplace list.")
            return 1
        if not args.cluster_name:
            args.cluster_name = selected_node.get("cluster_name")
    else:
        # Otherwise pick the cheapest node matching the CLI constraints.
        filtered_nodes = filter_nodes(nodes, args.gpu_count, args.region, args.supplier, args.max_price)
        if not filtered_nodes:
            error("No marketplace nodes match the specified constraints.")
            return 1
        selected_node = filtered_nodes[0]
        log("Selected node:")
        print(" " + summarize_node(selected_node))
        args.cluster_name = selected_node.get("cluster_name")
        args.node_name = selected_node.get("id")

    if not args.cluster_name:
        error("Cluster name is required; unable to determine cluster for the selected node.")
        return 1

    repo_url = args.repo_url or guess_repo_url()
    if repo_url:
        log(f"Using repository: {repo_url}")
    else:
        warn("Could not determine repository URL. Auto-start will clone existing repo on instance if present.")

    # Balance lookup is informational only; failures are non-fatal.
    balance = client.get_balance()
    if balance is not None:
        log(f"Current Hyperbolic balance: ${balance:.2f}")

    # Snapshot existing instance ids so the new one can be identified by diff.
    before_instances = client.list_instances()
    before_ids = {str(inst.get("id")) for inst in before_instances if inst.get("id")}
    log(f"Launching instance on cluster '{args.cluster_name}' node '{args.node_name}' "
        f"with {args.gpu_count} GPU(s)...")

    try:
        create_response = client.create_instance(
            cluster_name=args.cluster_name,
            node_name=args.node_name,
            gpu_count=args.gpu_count,
        )
    except requests.HTTPError as exc:
        error(f"Failed to launch instance: {exc}")
        # Best-effort dump of the API error body; the response attribute may
        # be absent depending on how the HTTPError was constructed.
        try:
            warn(f"Response payload: {exc.response.text}")  # type: ignore[attr-defined]
        except Exception:
            pass
        return 1

    instance_id = extract_instance_id(create_response, before_ids, client)
    log(f"Instance request acknowledged with id={instance_id}")

    try:
        instance = wait_for_instance(
            client=client,
            instance_id=instance_id,
            poll_seconds=args.poll_seconds,
            max_wait_minutes=args.max_wait_minutes,
        )
    except TimeoutError as exc:
        error(str(exc))
        return 1

    ip = extract_ip(instance)
    port = extract_ssh_port(instance)
    ssh_user = args.ssh_user

    log("Instance ready. Connection details:")
    if ip:
        # Build a copy-pasteable ssh command mirroring ssh_command()'s options.
        ssh_parts = ["ssh", "-o", "StrictHostKeyChecking=no"]
        if args.ssh_key:
            ssh_parts.extend(["-i", args.ssh_key])
        if port != 22:
            ssh_parts.extend(["-p", str(port)])
        ssh_parts.append(f"{ssh_user}@{ip}")
        print(" SSH:", " ".join(ssh_parts))
    else:
        warn("Instance IP not available; SSH command cannot be constructed.")

    if args.no_deploy:
        log("Skipping deployment (--no-deploy supplied).")
        return 0

    env_file_remote = f"/home/{ssh_user}/nanochat_aquarat.env"
    bootstrap_script = build_bootstrap_script(
        repo_dir=args.repo_dir,
        run_script=args.run_script,
        branch=args.branch,
        repo_url=repo_url,
        env_file_remote=env_file_remote,
        auto_start=args.auto_start,
        tmux_session=args.tmux_session,
    )

    try:
        deploy_to_instance(
            instance=instance,
            bootstrap_script=bootstrap_script,
            env_pairs=env_pairs,
            env_file_remote=env_file_remote,
            ssh_user=ssh_user,
            ssh_key=args.ssh_key,
        )
    except subprocess.CalledProcessError as exc:
        # Deployment failures leave the instance running for manual recovery.
        error(f"Deployment failed: {exc}")
        warn("You can manually SSH to the instance and run the training script.")
        return 1

    log("Deployment complete.")
    if args.auto_start:
        log("Training should now be running on the instance.")
    else:
        log("Auto-start disabled; after SSHing in, run the configured script manually.")

    return 0
697
+
698
+
699
if __name__ == "__main__":
    # Propagate main()'s integer status code to the shell.
    sys.exit(main())
701
+
scripts/launch_lambda_training.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Automation helper to launch a Lambda Labs instance and kick off the nanochat AQuA-RAT run.
4
+
5
+ Usage example:
6
+
7
+ python scripts/launch_lambda_training.py \
8
+ --ssh-key-name my-key \
9
+ --region us-west-1 \
10
+ --instance-type gpu_1x_a10 \
11
+ --repo-url https://github.com/your-org/nanochatAquaRat.git \
12
+ --auto-start \
13
+ --inject-env WANDB_API_KEY
14
+
15
+ By default the script will create cloud-init user-data that installs basic tooling,
16
+ clones the repository, copies an `.env` file when provided, and (optionally) runs
17
+ `run_aquarat_small.sh` inside a detached tmux session. The Lambda Cloud API key is
18
+ read from the `LAMBDA_API_KEY` environment variable or `--api-key`.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import shlex
27
+ import subprocess
28
+ import sys
29
+ import textwrap
30
+ import time
31
+ from collections import OrderedDict
32
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
33
+
34
+ import requests
35
+
36
API_BASE = "https://cloud.lambda.ai/api/v1"
# Baseline apt packages installed via cloud-init before the bootstrap runs.
DEFAULT_PACKAGES = ["git", "curl", "tmux", "build-essential"]
38
+
39
+
40
def log(msg: str) -> None:
    """Print an informational message with the standard prefix."""
    print("[info]", msg)
42
+
43
+
44
def warn(msg: str) -> None:
    """Print a warning message to stderr with the standard prefix."""
    print("[warn]", msg, file=sys.stderr)
46
+
47
+
48
def error(msg: str) -> None:
    """Print an error message to stderr with the standard prefix."""
    print("[error]", msg, file=sys.stderr)
50
+
51
+
52
def shell_quote(value: str) -> str:
    """Return a shell-escaped string for safe embedding inside scripts."""
    quoted = shlex.quote(value)
    return quoted
55
+
56
+
57
def guess_repo_url() -> Optional[str]:
    """Attempt to infer the git remote URL for the current repository."""
    try:
        result = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not inside a git checkout, or git is not installed.
        return None

    remote = result.stdout.strip()
    return remote if remote else None
71
+
72
+
73
def collect_env_pairs(cli_pairs: Sequence[str], inject_names: Sequence[str]) -> List[Tuple[str, str]]:
    """
    Merge KEY=VALUE pairs declared via the CLI with variables injected from the local env.
    Later occurrences with the same key take precedence.
    """
    result: "OrderedDict[str, str]" = OrderedDict()

    # Explicit KEY=VALUE entries from the CLI come first.
    for raw in cli_pairs:
        if "=" not in raw:
            raise ValueError(f"--env expects KEY=VALUE entries, got '{raw}'")
        name, _, val = raw.partition("=")
        name = name.strip()
        if not name:
            raise ValueError(f"Environment key is empty in '{raw}'")
        result[name] = val

    # Then values copied from the local environment (may override the above).
    for env_name in inject_names:
        if not env_name:
            raise ValueError("Encountered empty --inject-env name")
        if env_name not in os.environ:
            raise ValueError(f"--inject-env requested '{env_name}' but it is not set locally")
        result[env_name] = os.environ[env_name]

    return list(result.items())
97
+
98
+
99
def build_bootstrap_script(
    repo_dir: str,
    run_script: str,
    branch: str,
    repo_url: Optional[str],
    env_file_remote: str,
    auto_start: bool,
    tmux_session: str,
) -> str:
    """Compose the bash script executed on the instance to prepare the training run.

    Args:
        repo_dir: Directory name under $HOME where the repo lives on the instance.
        run_script: Path (relative to the repo root) of the training launch script.
        branch: Git branch to switch to after cloning/fetching.
        repo_url: Remote URL to clone; when None the repo dir is only created.
        env_file_remote: Remote path of a pre-written env file copied to `.env`.
        auto_start: When True, append commands that launch the run script.
        tmux_session: tmux session name used for the detached training run.

    Returns:
        The full bash script text, newline-terminated.
    """
    # `set -euxo pipefail` makes the remote bootstrap fail fast and echo commands.
    lines: List[str] = [
        "#!/usr/bin/env bash",
        "set -euxo pipefail",
        f'REPO_DIR="$HOME/{repo_dir}"',
        f'RUN_SCRIPT="{run_script}"',
        f'ENV_FILE="{env_file_remote}"',
        f'AUTO_START="{1 if auto_start else 0}"',
    ]

    if repo_url:
        # Clone only when no checkout exists yet, then fast-forward to the branch.
        lines.append(f"REPO_URL={shell_quote(repo_url)}")
        lines.extend(
            [
                'if [ ! -d "$REPO_DIR/.git" ]; then',
                ' rm -rf "$REPO_DIR"',
                ' git clone "$REPO_URL" "$REPO_DIR"',
                "fi",
                'cd "$REPO_DIR"',
                "git fetch --all --prune",
                f"git switch {shell_quote(branch)}",
                # `|| true`: a diverged branch should not abort the whole bootstrap.
                "git pull --ff-only || true",
            ]
        )
    else:
        # No URL known: assume the repo already exists or will be synced manually.
        lines.extend(
            [
                'mkdir -p "$REPO_DIR"',
                'cd "$REPO_DIR"',
            ]
        )

    # Materialize the env file as `.env`, make the run script executable,
    # and disable auto-start when the script is missing.
    lines.extend(
        [
            'if [ -f "$ENV_FILE" ]; then',
            ' cp "$ENV_FILE" .env',
            "fi",
            'if [ -f "$RUN_SCRIPT" ]; then',
            ' chmod +x "$RUN_SCRIPT"',
            "else",
            ' echo "Run script $RUN_SCRIPT not found; auto-start will be skipped." >&2',
            ' AUTO_START="0"',
            "fi",
        ]
    )

    if auto_start:
        # Prefer tmux (inspectable, reattachable session); fall back to nohup
        # with output redirected to a log file in $HOME.
        tmux_line = (
            f'tmux new -d -s {shell_quote(tmux_session)} '
            '"cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\""'
        )
        nohup_line = (
            'nohup bash -lc "cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\"" '
            '> "$HOME/nanochat-train.log" 2>&1 &'
        )
        lines.extend(
            [
                'if [ "$AUTO_START" = "1" ]; then',
                " if command -v tmux >/dev/null 2>&1; then",
                f" {tmux_line}",
                " else",
                f" {nohup_line}",
                " fi",
                "fi",
            ]
        )

    return "\n".join(lines) + "\n"
176
+
177
+
178
def build_user_data(
    packages: Sequence[str],
    bootstrap_script: str,
    env_pairs: Sequence[Tuple[str, str]],
    env_file_remote: str,
) -> str:
    """Render cloud-init user-data with package installs, env file, and bootstrap script.

    Args:
        packages: apt package names installed by cloud-init on first boot.
        bootstrap_script: Bash script written to /home/ubuntu/bootstrap_nanochat.sh.
        env_pairs: KEY=VALUE pairs written to `env_file_remote` (skipped when empty).
        env_file_remote: Absolute remote path for the generated environment file.

    Returns:
        The complete #cloud-config document as a string.
    """
    lines: List[str] = ["#cloud-config", "package_update: true", "package_upgrade: false"]

    if packages:
        lines.append("packages:")
        for package in packages:
            lines.append(f" - {package}")

    # NOTE(review): cloud-init YAML is indentation-sensitive; the whitespace
    # prefixes below (list markers, keys, and the textwrap.indent prefix for
    # `content: |` blocks) must nest correctly — confirm the rendered document
    # parses as valid YAML, as the original indentation may have been wider.
    lines.append("write_files:")
    if env_pairs:
        env_content = "\n".join(f"{key}={value}" for key, value in env_pairs) + "\n"
        lines.extend(
            [
                f" - path: {env_file_remote}",
                " owner: ubuntu:ubuntu",
                " permissions: '0640'",
                " content: |",
                textwrap.indent(env_content, " "),
            ]
        )

    lines.extend(
        [
            " - path: /home/ubuntu/bootstrap_nanochat.sh",
            " owner: ubuntu:ubuntu",
            " permissions: '0755'",
            " content: |",
            textwrap.indent(bootstrap_script, " "),
        ]
    )

    # Run the bootstrap as the ubuntu user once provisioning completes.
    lines.extend(
        [
            "runcmd:",
            " - \"su - ubuntu -c '/home/ubuntu/bootstrap_nanochat.sh'\"",
        ]
    )

    return "\n".join(lines) + "\n"
223
+
224
+
225
class LambdaClient:
    """Minimal wrapper around the Lambda Cloud REST API."""

    def __init__(self, api_key: str) -> None:
        # Fail fast: every endpoint below requires bearer authentication.
        if not api_key:
            raise ValueError("Lambda Cloud API key not provided")

        # A shared Session reuses the connection and carries auth headers.
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Accept": "application/json",
                "Content-Type": "application/json",
            }
        )

    def launch_instances(self, payload: Dict[str, object]) -> List[str]:
        """POST a launch request and return the new instance ids.

        Raises:
            requests.HTTPError: on any 4xx/5xx response (body text as message).
        """
        response = self.session.post(
            f"{API_BASE}/instance-operations/launch",
            data=json.dumps(payload),
            timeout=60,
        )
        if response.status_code >= 400:
            raise requests.HTTPError(response.text, response=response)
        data = response.json()
        instance_ids = data["data"]["instance_ids"]
        return instance_ids

    def get_instance(self, instance_id: str) -> Optional[Dict[str, object]]:
        """Fetch one instance record; None when the id is unknown (404)."""
        response = self.session.get(
            f"{API_BASE}/instances/{instance_id}",
            timeout=30,
        )
        if response.status_code == 404:
            return None
        if response.status_code >= 400:
            raise requests.HTTPError(response.text, response=response)
        return response.json()["data"]

    def wait_for_instance(
        self,
        instance_id: str,
        poll_seconds: int,
        max_wait_minutes: int,
    ) -> Dict[str, object]:
        """Poll until the instance reaches 'active' status.

        Raises:
            RuntimeError: when the instance hits a terminal status.
            TimeoutError: when max_wait_minutes elapses first.
        """
        deadline = time.time() + max_wait_minutes * 60
        last_status = "unknown"
        while time.time() < deadline:
            instance = self.get_instance(instance_id)
            if not instance:
                # The record may briefly 404 right after launch; keep polling.
                time.sleep(poll_seconds)
                continue

            status = str(instance.get("status", "unknown"))
            if status != last_status:
                # Log only on transitions to keep the output readable.
                log(f"Instance {instance_id} status: {status}")
                last_status = status

            if status == "active":
                return instance
            if status in {"terminated", "terminating", "preempted"}:
                raise RuntimeError(f"Instance {instance_id} entered terminal status '{status}'")
            if status == "unhealthy":
                # Observed to be transient; warn but keep waiting.
                warn(f"Instance {instance_id} reported unhealthy; continuing to poll")

            time.sleep(poll_seconds)

        raise TimeoutError(f"Timed out waiting for instance {instance_id} to become active")
293
+
294
+
295
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Define and evaluate the CLI for provisioning a Lambda Labs training box.

    Args:
        argv: Explicit argument list; ``None`` means use ``sys.argv[1:]``.

    Returns:
        The parsed :class:`argparse.Namespace`.
    """
    ap = argparse.ArgumentParser(
        description="Launch a Lambda Labs instance and prepare the nanochat training run."
    )

    # --- credentials, placement, and hardware ---
    ap.add_argument(
        "--api-key",
        default=os.getenv("LAMBDA_API_KEY"),
        help="Lambda Cloud API key (default: read from LAMBDA_API_KEY).",
    )
    ap.add_argument("--region", default="us-west-1", help="Lambda Cloud region name.")
    ap.add_argument(
        "--instance-type",
        default="gpu_1x_a10",
        help="Instance type to launch (see Lambda Cloud docs for the catalog).",
    )
    ap.add_argument(
        "--ssh-key-name",
        required=True,
        help="Name of the SSH key already registered with Lambda Cloud.",
    )
    ap.add_argument(
        "--quantity",
        type=int,
        default=1,
        help="Number of instances to launch (auto-start only supports 1).",
    )
    ap.add_argument("--name", help="Friendly name to assign to the instance(s).")

    # --- repository and training-run configuration ---
    ap.add_argument(
        "--repo-url",
        help="Git URL for the nanochat repository (default: auto-detect from current repo).",
    )
    ap.add_argument("--branch", default="main", help="Git branch to checkout on the instance.")
    ap.add_argument(
        "--repo-dir",
        default="nanochatAquaRat",
        help="Directory name to clone the repository into on the instance.",
    )
    ap.add_argument(
        "--run-script",
        default="run_aquarat_small.sh",
        help="Relative path to the training launch script inside the repo.",
    )
    ap.add_argument(
        "--tmux-session",
        default="nanochat-train",
        help="tmux session name used when --auto-start is active.",
    )
    ap.add_argument(
        "--auto-start",
        action="store_true",
        help="Kick off the training script automatically after provisioning completes.",
    )

    # --- remote environment file ---
    ap.add_argument(
        "--env",
        action="append",
        default=[],
        help="Additional KEY=VALUE pairs to write into the remote .env file (repeatable).",
    )
    ap.add_argument(
        "--inject-env",
        action="append",
        default=[],
        help="Names of local environment variables whose values should populate the remote .env.",
    )
    ap.add_argument(
        "--env-file-name",
        default=".env.lambda",
        help="Filename (relative to /home/ubuntu) for the generated environment file.",
    )

    # --- image selection and launch behavior ---
    ap.add_argument(
        "--image-id",
        help="Optional image ID to use instead of the default Lambda Stack image.",
    )
    ap.add_argument(
        "--image-family",
        help="Optional image family name to use instead of the default image.",
    )
    ap.add_argument(
        "--max-wait-minutes",
        type=int,
        default=25,
        help="Maximum minutes to wait for the instance to become active.",
    )
    ap.add_argument(
        "--poll-seconds",
        type=int,
        default=20,
        help="Polling interval while waiting for the instance to become active.",
    )
    ap.add_argument(
        "--skip-wait",
        action="store_true",
        help="Exit after requesting the launch without waiting for active status.",
    )
    ap.add_argument(
        "--no-user-data",
        action="store_true",
        help="Skip sending user-data; instance boots with the stock image configuration.",
    )
    ap.add_argument(
        "--print-user-data",
        action="store_true",
        help="Print the generated user-data to stdout before launching.",
    )

    return ap.parse_args(argv)
401
+
402
+
403
def _resolve_repo_url(args: argparse.Namespace) -> Optional[str]:
    """Return the repo URL from the CLI, falling back to git auto-detection."""
    repo_url = args.repo_url
    if not repo_url:
        repo_url = guess_repo_url()
        if repo_url:
            log(f"Discovered repository URL: {repo_url}")
        else:
            # Launch still proceeds; the instance just won't clone anything.
            warn(
                "Could not auto-detect repository URL; "
                "pass --repo-url if you want the instance to clone the repo automatically."
            )
    return repo_url


def _build_launch_payload(
    args: argparse.Namespace, user_data: Optional[str]
) -> Dict[str, object]:
    """Assemble the request body for the Lambda Cloud instance-launch call."""
    payload: Dict[str, object] = {
        "region_name": args.region,
        "instance_type_name": args.instance_type,
        "ssh_key_names": [args.ssh_key_name],
    }
    # Optional fields are only included when they carry a meaningful value.
    if args.quantity:
        payload["quantity"] = args.quantity
    if args.name:
        payload["name"] = args.name
    if user_data:
        payload["user_data"] = user_data
    if args.image_id:
        payload["image"] = {"id": args.image_id}
    elif args.image_family:
        payload["image"] = {"family": args.image_family}
    return payload


def _report_instances(
    args: argparse.Namespace, instances: List[Dict[str, object]]
) -> None:
    """Log connection details and next steps for each active instance."""
    for instance in instances:
        ip = instance.get("ip") or "<pending>"
        name = instance.get("name") or instance.get("id")
        log(f"Instance {name} is active with public IP {ip}")
        if ip != "<pending>":
            log(
                f"SSH command: ssh -i /path/to/key.pem ubuntu@{ip}"
            )

    if args.auto_start:
        log(
            "Auto-start enabled. Training is running inside tmux; attach with "
            f"`ssh ...` then `tmux attach -t {args.tmux_session}`."
        )
    else:
        log(
            "Auto-start disabled. After SSH'ing in, run "
            f"`cd ~/{args.repo_dir} && bash {args.run_script}`."
        )


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Validate arguments, launch the instance(s), and wait for them to boot.

    Args:
        argv: CLI arguments forwarded to :func:`parse_args`.

    Returns:
        Process exit status: 0 on success, 1 on any validation or API failure.
    """
    args = parse_args(argv)

    # Fail fast on invalid argument combinations before touching the API.
    if args.image_id and args.image_family:
        error("Provide only one of --image-id or --image-family.")
        return 1
    if args.auto_start and args.quantity != 1:
        error("--auto-start currently supports a single instance (set --quantity=1).")
        return 1
    if not args.api_key:
        error("Lambda Cloud API key not provided. Use --api-key or set LAMBDA_API_KEY.")
        return 1

    repo_url = _resolve_repo_url(args)

    try:
        env_pairs = collect_env_pairs(args.env, args.inject_env)
    except ValueError as exc:
        error(str(exc))
        return 1

    env_file_remote = f"/home/ubuntu/{args.env_file_name}"

    bootstrap_script = build_bootstrap_script(
        repo_dir=args.repo_dir,
        run_script=args.run_script,
        branch=args.branch,
        repo_url=repo_url,
        env_file_remote=env_file_remote,
        auto_start=args.auto_start,
        tmux_session=args.tmux_session,
    )

    user_data: Optional[str] = None
    if not args.no_user_data:
        user_data = build_user_data(
            packages=DEFAULT_PACKAGES,
            bootstrap_script=bootstrap_script,
            env_pairs=env_pairs,
            env_file_remote=env_file_remote,
        )
    if args.print_user_data:
        # With --no-user-data there is nothing generated, so say so explicitly.
        print(user_data if user_data is not None else "# user-data disabled (--no-user-data)")

    payload = _build_launch_payload(args, user_data)

    client = LambdaClient(args.api_key)

    log(
        "Requesting instance launch "
        f"(region={args.region}, type={args.instance_type}, quantity={args.quantity})"
    )

    try:
        instance_ids = client.launch_instances(payload)
    except requests.HTTPError as exc:
        error(f"Instance launch failed: {exc}")
        # Surface the API's error body when available; it usually names the cause.
        if exc.response is not None:
            warn(f"Response content: {exc.response.text}")
        return 1

    log(f"Requested instance IDs: {', '.join(instance_ids)}")

    if args.skip_wait:
        log("Skipping wait (--skip-wait supplied).")
        return 0

    instances: List[Dict[str, object]] = []
    for instance_id in instance_ids:
        try:
            instances.append(
                client.wait_for_instance(
                    instance_id=instance_id,
                    poll_seconds=args.poll_seconds,
                    max_wait_minutes=args.max_wait_minutes,
                )
            )
        except (RuntimeError, TimeoutError, requests.HTTPError) as exc:
            error(f"Failed while waiting for instance {instance_id}: {exc}")
            return 1

    _report_instances(args, instances)
    return 0
532
+
533
+
534
if __name__ == "__main__":
    # Propagate main()'s integer status code to the shell.
    raise SystemExit(main())