Sync docs from GitHub
Browse files- .env.template +57 -0
- .gitattributes +1 -0
- GCS_UPLOAD_GUIDE.md +323 -0
- LAMBDA_MANUAL_SETUP.md +317 -0
- QUICKSTART.md +288 -0
- README.md +615 -599
- aquarat2.png +3 -0
- launch_lambda.py +296 -0
- scripts/launch_hyperbolic_training.py +701 -0
- scripts/launch_lambda_training.py +535 -0
.env.template
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nanochatAquaRat Environment Template
|
| 2 |
+
# Copy this file to .env and fill in actual values before running any scripts.
|
| 3 |
+
|
| 4 |
+
# -----------------------------------------------------------------------------
|
| 5 |
+
# Cloud GPU Providers
|
| 6 |
+
# -----------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
# Lambda Labs automation (scripts/launch_lambda_training.py, launch_lambda.py)
|
| 9 |
+
# https://cloud.lambdalabs.com/api-keys
|
| 10 |
+
LAMBDA_API_KEY=your-lambda-api-key-here
|
| 11 |
+
|
| 12 |
+
# Hyperbolic Labs automation (scripts/launch_hyperbolic_training.py)
|
| 13 |
+
# https://app.hyperbolic.ai/settings/api-keys
|
| 14 |
+
HYPERBOLIC_API_KEY=your-hyperbolic-api-key-here
|
| 15 |
+
|
| 16 |
+
# Optional Hyperbolic defaults (override with CLI args if needed)
|
| 17 |
+
# HYPERBOLIC_REGION=us-east
|
| 18 |
+
# HYPERBOLIC_MAX_PRICE=6.00
|
| 19 |
+
|
| 20 |
+
# -----------------------------------------------------------------------------
|
| 21 |
+
# Experiment Tracking (Weights & Biases)
|
| 22 |
+
# -----------------------------------------------------------------------------
|
| 23 |
+
# Get your key at https://wandb.ai/authorize
|
| 24 |
+
WANDB_API_KEY=your-wandb-api-key-here
|
| 25 |
+
WANDB_PROJECT=nanochat-aquarat
|
| 26 |
+
WANDB_ENTITY=your-wandb-username-or-team-name
|
| 27 |
+
|
| 28 |
+
# Optional run metadata
|
| 29 |
+
WANDB_RUN=aquarat-$(date +%Y%m%d-%H%M%S)
|
| 30 |
+
WANDB_MODE=online # online | offline | disabled
|
| 31 |
+
|
| 32 |
+
# -----------------------------------------------------------------------------
|
| 33 |
+
# Google Cloud Storage Uploads (scripts/upload_to_gcs.sh)
|
| 34 |
+
# See GCS_UPLOAD_GUIDE.md for detailed instructions
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
# GCP project that owns your storage bucket
|
| 38 |
+
GCP_PROJECT_ID=your-gcp-project-id
|
| 39 |
+
|
| 40 |
+
# Default bucket for model artifacts (used by upload scripts and automation)
|
| 41 |
+
GCS_BUCKET=gs://your-model-bucket
|
| 42 |
+
|
| 43 |
+
# Service account credentials (recommended for automation). Point to the JSON
|
| 44 |
+
# key you download with `gcloud iam service-accounts keys create ...`
|
| 45 |
+
GOOGLE_APPLICATION_CREDENTIALS=/path/to/nanochat-gcs-key.json
|
| 46 |
+
|
| 47 |
+
# -----------------------------------------------------------------------------
|
| 48 |
+
# Optional cache/directories
|
| 49 |
+
# -----------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
# Override the default ~/.cache/nanochat location (used for checkpoints/data)
|
| 52 |
+
# NANOCHAT_BASE_DIR=/mnt/nanochat-cache
|
| 53 |
+
|
| 54 |
+
# If you pre-convert AQuA-RAT with scripts/prepare_aqua.py, set this so tasks
|
| 55 |
+
# and training scripts reuse the cached JSONL splits instead of downloading.
|
| 56 |
+
# AQUA_DATA_DIR=/mnt/datasets/aqua
|
| 57 |
+
|
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
aquarat2.png filter=lfs diff=lfs merge=lfs -text
|
GCS_UPLOAD_GUIDE.md
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Google Cloud Storage Upload Guide
|
| 2 |
+
|
| 3 |
+
After training your nanochat model on Lambda Labs, use this guide to upload all weights and artifacts to Google Cloud Storage.
|
| 4 |
+
|
| 5 |
+
## Quick Start
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
# After training completes, SSH to your Lambda instance
|
| 9 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 10 |
+
|
| 11 |
+
# Navigate to project directory
|
| 12 |
+
cd ~/nanochatAquaRat
|
| 13 |
+
|
| 14 |
+
# Run upload script
|
| 15 |
+
bash scripts/upload_to_gcs.sh --bucket gs://your-bucket-name
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
The script will:
|
| 19 |
+
1. Check/install gcloud CLI if needed
|
| 20 |
+
2. Verify authentication and bucket access
|
| 21 |
+
3. Show what will be uploaded and ask for confirmation
|
| 22 |
+
4. Upload all artifacts with progress
|
| 23 |
+
5. Ask if you want to terminate the Lambda instance
|
| 24 |
+
|
| 25 |
+
## Prerequisites
|
| 26 |
+
|
| 27 |
+
### 1. Create a GCS Bucket
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# From your local machine
|
| 31 |
+
gcloud storage buckets create gs://your-bucket-name \
|
| 32 |
+
--location=us-central1 \
|
| 33 |
+
--uniform-bucket-level-access
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Or create via console: https://console.cloud.google.com/storage/create-bucket
|
| 37 |
+
|
| 38 |
+
### 2. Set Up Authentication
|
| 39 |
+
|
| 40 |
+
#### Option A: Service Account (Recommended for Automation)
|
| 41 |
+
|
| 42 |
+
On your local machine:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
# Create service account
|
| 46 |
+
gcloud iam service-accounts create nanochat-uploader \
|
| 47 |
+
--display-name="Nanochat Model Uploader"
|
| 48 |
+
|
| 49 |
+
# Grant storage permissions
|
| 50 |
+
gcloud projects add-iam-policy-binding YOUR_PROJECT_ID \
|
| 51 |
+
--member="serviceAccount:nanochat-uploader@YOUR_PROJECT_ID.iam.gserviceaccount.com" \
|
| 52 |
+
  --role="roles/storage.objectAdmin"  # objectAdmin, not objectCreator: gsutil rsync also needs list/get/delete
|
| 53 |
+
|
| 54 |
+
# Create and download key
|
| 55 |
+
gcloud iam service-accounts keys create ~/nanochat-key.json \
|
| 56 |
+
--iam-account=nanochat-uploader@YOUR_PROJECT_ID.iam.gserviceaccount.com
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Copy key to Lambda instance:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
scp ~/nanochat-key.json ubuntu@<INSTANCE_IP>:~/
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
On Lambda instance:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
gcloud auth activate-service-account --key-file="$HOME/nanochat-key.json"
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
#### Option B: User Account (Simpler for Manual Use)
|
| 72 |
+
|
| 73 |
+
On Lambda instance:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
gcloud auth login
|
| 77 |
+
# Follow the prompts in your browser
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Usage
|
| 81 |
+
|
| 82 |
+
### Basic Upload
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
bash scripts/upload_to_gcs.sh --bucket gs://my-models
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### Custom Run Name
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
bash scripts/upload_to_gcs.sh \
|
| 92 |
+
--bucket gs://my-models \
|
| 93 |
+
--run-name depth20-experiment1
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Exclude Large Dataset Files
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
bash scripts/upload_to_gcs.sh \
|
| 100 |
+
--bucket gs://my-models \
|
| 101 |
+
--exclude-data
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### Dry Run (Preview Only)
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
bash scripts/upload_to_gcs.sh \
|
| 108 |
+
--bucket gs://my-models \
|
| 109 |
+
--dry-run
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Auto-Terminate After Upload
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
bash scripts/upload_to_gcs.sh \
|
| 116 |
+
--bucket gs://my-models \
|
| 117 |
+
--auto-terminate
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## What Gets Uploaded
|
| 121 |
+
|
| 122 |
+
From `~/.cache/nanochat/`:
|
| 123 |
+
|
| 124 |
+
| Directory | Contents | Typical Size |
|
| 125 |
+
|-----------|----------|--------------|
|
| 126 |
+
| `checkpoints/` | Model weights (.pt, .pkl files) | 500MB - 2GB |
|
| 127 |
+
| `report/` | Training reports and markdown summaries | 1-10MB |
|
| 128 |
+
| `tokenizer/` | BPE tokenizer files | 10-50MB |
|
| 129 |
+
| `eval_bundle/` | Evaluation datasets | 50-200MB |
|
| 130 |
+
| `aqua/` | AQuA-RAT dataset (optional) | 100-500MB |
|
| 131 |
+
| `mechanistic_interpretability/` | DeepMind interp tools | 10-100MB |
|
| 132 |
+
|
| 133 |
+
**Total**: Typically 1-5 GB per training run
|
| 134 |
+
|
| 135 |
+
## Upload Structure
|
| 136 |
+
|
| 137 |
+
Files are organized in GCS as:
|
| 138 |
+
|
| 139 |
+
```
|
| 140 |
+
gs://your-bucket/
|
| 141 |
+
└── runs/
|
| 142 |
+
├── aquarat-20251023-143022/
|
| 143 |
+
│ ├── checkpoints/
|
| 144 |
+
│ │ ├── base_final.pt
|
| 145 |
+
│ │ ├── mid_final.pt
|
| 146 |
+
│ │ ├── sft_final.pt
|
| 147 |
+
│ │ └── rl_final.pt
|
| 148 |
+
│ ├── report/
|
| 149 |
+
│ │ └── report.md
|
| 150 |
+
│ ├── tokenizer/
|
| 151 |
+
│ └── ...
|
| 152 |
+
└── depth20-experiment1/
|
| 153 |
+
└── ...
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
## Download Weights Later
|
| 157 |
+
|
| 158 |
+
### Download Entire Run
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
gsutil -m rsync -r \
|
| 162 |
+
gs://your-bucket/runs/aquarat-20251023-143022/ \
|
| 163 |
+
./local_checkpoints/
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
### Download Just Checkpoints
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
gsutil -m cp -r \
|
| 170 |
+
gs://your-bucket/runs/aquarat-20251023-143022/checkpoints/ \
|
| 171 |
+
./checkpoints/
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### Download Single File
|
| 175 |
+
|
| 176 |
+
```bash
|
| 177 |
+
gsutil cp \
|
| 178 |
+
gs://your-bucket/runs/aquarat-20251023-143022/checkpoints/rl_final.pt \
|
| 179 |
+
./rl_final.pt
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## Cost Considerations
|
| 183 |
+
|
| 184 |
+
### Storage Costs
|
| 185 |
+
|
| 186 |
+
- Standard storage: ~$0.02/GB/month
|
| 187 |
+
- Nearline storage (30+ days): ~$0.01/GB/month
|
| 188 |
+
- Coldline storage (90+ days): ~$0.004/GB/month
|
| 189 |
+
|
| 190 |
+
**Example**: 2GB model stored for 1 month = $0.04
|
| 191 |
+
|
| 192 |
+
### Network Egress
|
| 193 |
+
|
| 194 |
+
- Upload (ingress): **Free**
|
| 195 |
+
- Download to same region: **Free**
|
| 196 |
+
- Download to internet: ~$0.12/GB
|
| 197 |
+
|
| 198 |
+
**Tip**: Keep your GCS bucket in the same region as your compute for free transfers.
|
| 199 |
+
|
| 200 |
+
### Lifecycle Management
|
| 201 |
+
|
| 202 |
+
Auto-delete or move to cheaper storage after 90 days:
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
cat > lifecycle.json << EOF
|
| 206 |
+
{
|
| 207 |
+
"lifecycle": {
|
| 208 |
+
"rule": [
|
| 209 |
+
{
|
| 210 |
+
"action": {"type": "Delete"},
|
| 211 |
+
"condition": {"age": 90}
|
| 212 |
+
}
|
| 213 |
+
]
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
EOF
|
| 217 |
+
|
| 218 |
+
gsutil lifecycle set lifecycle.json gs://your-bucket
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
## Troubleshooting
|
| 222 |
+
|
| 223 |
+
### "gcloud: command not found"
|
| 224 |
+
|
| 225 |
+
The script auto-installs gcloud on Linux. If it fails:
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
curl https://sdk.cloud.google.com | bash
|
| 229 |
+
exec -l $SHELL
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
### "Permission denied" Error
|
| 233 |
+
|
| 234 |
+
Check your service account has sufficient storage permissions (e.g. `roles/storage.objectAdmin`, since `gsutil rsync` needs list/get/delete in addition to create):
|
| 235 |
+
|
| 236 |
+
```bash
|
| 237 |
+
gcloud projects get-iam-policy YOUR_PROJECT_ID \
|
| 238 |
+
--flatten="bindings[].members" \
|
| 239 |
+
--filter="bindings.members:serviceAccount:nanochat-uploader*"
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
### Upload Interrupted
|
| 243 |
+
|
| 244 |
+
The script uses `gsutil rsync`, so re-running will resume:
|
| 245 |
+
|
| 246 |
+
```bash
|
| 247 |
+
bash scripts/upload_to_gcs.sh --bucket gs://your-bucket
|
| 248 |
+
# Will skip already-uploaded files
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
### Verify Upload
|
| 252 |
+
|
| 253 |
+
```bash
|
| 254 |
+
# List all files in the run
|
| 255 |
+
gsutil ls -r gs://your-bucket/runs/your-run-name/
|
| 256 |
+
|
| 257 |
+
# Check specific checkpoints
|
| 258 |
+
gsutil ls gs://your-bucket/runs/your-run-name/checkpoints/
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## Integration with Lambda Launcher
|
| 262 |
+
|
| 263 |
+
You can add GCS credentials to the automated launcher:
|
| 264 |
+
|
| 265 |
+
```python
|
| 266 |
+
# In scripts/launch_lambda_training.py
|
| 267 |
+
# Add to the cloud-init user-data:
|
| 268 |
+
|
| 269 |
+
write_files:
|
| 270 |
+
- path: /home/ubuntu/.config/gcloud/application_default_credentials.json
|
| 271 |
+
content: |
|
| 272 |
+
{your service account key JSON}
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
Or pass as environment variable:
|
| 276 |
+
|
| 277 |
+
```bash
|
| 278 |
+
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
|
| 279 |
+
|
| 280 |
+
python scripts/launch_lambda_training.py \
|
| 281 |
+
--inject-env GOOGLE_APPLICATION_CREDENTIALS \
|
| 282 |
+
...
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
## Best Practices
|
| 286 |
+
|
| 287 |
+
1. **Name runs descriptively**: Use `--run-name depth20-lr1e4-batch32`
|
| 288 |
+
2. **Exclude data when iterating**: Use `--exclude-data` to save bandwidth
|
| 289 |
+
3. **Dry run first**: Always use `--dry-run` to preview
|
| 290 |
+
4. **Service accounts for automation**: Easier than user auth
|
| 291 |
+
5. **Regional buckets**: Match Lambda instance region when possible
|
| 292 |
+
6. **Lifecycle policies**: Auto-archive old models
|
| 293 |
+
7. **Download to Lambda**: If re-training, download previous checkpoints to Lambda first
|
| 294 |
+
|
| 295 |
+
## Security Notes
|
| 296 |
+
|
| 297 |
+
- Service account keys are sensitive - treat like passwords
|
| 298 |
+
- Use least-privilege IAM roles (don't grant `roles/owner`)
|
| 299 |
+
- Rotate service account keys regularly
|
| 300 |
+
- Consider Workload Identity if using GKE
|
| 301 |
+
- Don't commit keys to git (add to `.gitignore`)
|
| 302 |
+
|
| 303 |
+
## Support
|
| 304 |
+
|
| 305 |
+
- GCS Documentation: https://cloud.google.com/storage/docs
|
| 306 |
+
- gsutil Reference: https://cloud.google.com/storage/docs/gsutil
|
| 307 |
+
- IAM Permissions: https://cloud.google.com/storage/docs/access-control/iam-permissions
|
| 308 |
+
|
| 309 |
+
---
|
| 310 |
+
|
| 311 |
+
**Quick Reference**:
|
| 312 |
+
```bash
|
| 313 |
+
# Upload
|
| 314 |
+
bash scripts/upload_to_gcs.sh --bucket gs://my-bucket
|
| 315 |
+
|
| 316 |
+
# Download
|
| 317 |
+
gsutil -m cp -r gs://my-bucket/runs/NAME/checkpoints/ ./
|
| 318 |
+
|
| 319 |
+
# List runs
|
| 320 |
+
gsutil ls gs://my-bucket/runs/
|
| 321 |
+
|
| 322 |
+
# Delete old run
|
| 323 |
+
gsutil -m rm -r gs://my-bucket/runs/old-run/
|
LAMBDA_MANUAL_SETUP.md
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Manual Setup Guide for Lambda Labs
|
| 2 |
+
|
| 3 |
+
This guide walks you through manually launching and configuring a Lambda Labs GPU instance for training the nanochatAquaRat model with RL on AQuA-RAT.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
1. **Lambda Labs Account**: Sign up at https://cloud.lambdalabs.com
|
| 8 |
+
2. **SSH Key**: Add your SSH public key to Lambda Labs
|
| 9 |
+
3. **W&B Account**: Sign up at https://wandb.ai and get your API key
|
| 10 |
+
4. **Sufficient Credits**: Ensure you have enough credits (~$24/hour on 8x H100, so roughly $192 for an 8-hour run)
|
| 11 |
+
|
| 12 |
+
## Step 1: Add SSH Key to Lambda Labs
|
| 13 |
+
|
| 14 |
+
1. Go to https://cloud.lambdalabs.com/ssh-keys
|
| 15 |
+
2. Click "Add SSH Key"
|
| 16 |
+
3. Paste your public SSH key (from `~/.ssh/id_rsa.pub` or `~/.ssh/id_ed25519.pub`)
|
| 17 |
+
4. Give it a name (e.g., "my-laptop")
|
| 18 |
+
5. Click "Add SSH Key"
|
| 19 |
+
|
| 20 |
+
## Step 2: Launch Instance via Web Dashboard
|
| 21 |
+
|
| 22 |
+
1. Navigate to https://cloud.lambdalabs.com/instances
|
| 23 |
+
2. Click **"Launch instance"**
|
| 24 |
+
3. Configure your instance:
|
| 25 |
+
- **Instance type**: Select `gpu_8x_h100_sxm5` (8x NVIDIA H100 80GB SXM5)
|
| 26 |
+
- For testing: Use `gpu_1x_a10` or smaller
|
| 27 |
+
- **Region**: Choose a region with availability (e.g., `us-west-1`)
|
| 28 |
+
- **SSH Keys**: Select your SSH key
|
| 29 |
+
- **Filesystem**: (Optional) If you have persistent storage
|
| 30 |
+
4. Click **"Launch instance"**
|
| 31 |
+
5. Wait 1-2 minutes for the instance to boot
|
| 32 |
+
|
| 33 |
+
## Step 3: Note Instance Details
|
| 34 |
+
|
| 35 |
+
Once the instance is running, note:
|
| 36 |
+
- **Instance ID**: (e.g., `0123456789abcdef`)
|
| 37 |
+
- **IP Address**: (e.g., `123.45.67.89`)
|
| 38 |
+
- **SSH Command**: Shown in the web interface
|
| 39 |
+
|
| 40 |
+
## Step 4: Connect to Instance
|
| 41 |
+
|
| 42 |
+
Open your terminal and connect:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Example:
|
| 49 |
+
```bash
|
| 50 |
+
ssh ubuntu@123.45.67.89
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Step 5: Set Up Environment
|
| 54 |
+
|
| 55 |
+
Once connected, run these commands:
|
| 56 |
+
|
| 57 |
+
### 5.1 Create Environment File
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Create .env file with your credentials
|
| 61 |
+
cat > ~/.env << 'EOF'
|
| 62 |
+
WANDB_API_KEY=your-wandb-api-key-here
|
| 63 |
+
WANDB_PROJECT=nanochat-aquarat
|
| 64 |
+
WANDB_ENTITY=your-wandb-username-or-team
|
| 65 |
+
EOF
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Replace `your-wandb-api-key-here` with your actual W&B API key (get it from https://wandb.ai/authorize)
|
| 69 |
+
|
| 70 |
+
### 5.2 Clone Repository
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
cd ~
|
| 74 |
+
git clone https://github.com/HarleyCoops/nanochatAquaRat.git
|
| 75 |
+
cd nanochatAquaRat
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### 5.3 Copy Environment Variables
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
# Copy the .env file to the project directory
|
| 82 |
+
cp ~/.env .env
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Step 6: Start Training
|
| 86 |
+
|
| 87 |
+
You have two options for running the training:
|
| 88 |
+
|
| 89 |
+
### Option A: Run in Screen Session (Recommended)
|
| 90 |
+
|
| 91 |
+
This allows you to detach and let training continue even if you disconnect:
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Start a screen session
|
| 95 |
+
screen -S training
|
| 96 |
+
|
| 97 |
+
# Run the training script
|
| 98 |
+
bash run_aquarat_small.sh
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
**Screen Commands:**
|
| 102 |
+
- **Detach from screen**: Press `Ctrl+A` then `D`
|
| 103 |
+
- **Reattach to screen**: `screen -r training`
|
| 104 |
+
- **List all screen sessions**: `screen -ls`
|
| 105 |
+
- **Kill a screen session**: `screen -X -S training quit`
|
| 106 |
+
|
| 107 |
+
### Option B: Run Directly (Blocks Terminal)
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
# Run training directly (terminal will be blocked)
|
| 111 |
+
bash run_aquarat_small.sh 2>&1 | tee training.log
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
This saves output to `training.log` for later review.
|
| 115 |
+
|
| 116 |
+
## Step 7: Monitor Training
|
| 117 |
+
|
| 118 |
+
### Monitor via Terminal
|
| 119 |
+
|
| 120 |
+
If using screen:
|
| 121 |
+
```bash
|
| 122 |
+
# Reattach to see live output
|
| 123 |
+
screen -r training
|
| 124 |
+
|
| 125 |
+
# Or tail the report
|
| 126 |
+
tail -f ~/.cache/nanochat/report/report.md
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Monitor via Weights & Biases
|
| 130 |
+
|
| 131 |
+
1. Go to https://wandb.ai
|
| 132 |
+
2. Navigate to your project: `nanochat-aquarat`
|
| 133 |
+
3. View real-time metrics, losses, and generated samples
|
| 134 |
+
|
| 135 |
+
Key metrics to watch:
|
| 136 |
+
- `rl/acc` - Accuracy on AQuA-RAT
|
| 137 |
+
- `rl/mean_reward` - Average reward per sample
|
| 138 |
+
- `rl/kl_letter_mean` - KL divergence from initial policy
|
| 139 |
+
- `rl/letter_margin_mean` - Confidence in letter choices
|
| 140 |
+
- `attn/entropy_mean` - Attention mechanism entropy
|
| 141 |
+
|
| 142 |
+
## Step 8: Training Timeline
|
| 143 |
+
|
| 144 |
+
For the **small (depth=8) model**:
|
| 145 |
+
- **Base pretraining**: ~1-2 hours
|
| 146 |
+
- **Mid-training**: ~30 minutes
|
| 147 |
+
- **SFT**: ~30 minutes
|
| 148 |
+
- **RL**: ~30 minutes
|
| 149 |
+
- **Total**: ~3-4 hours
|
| 150 |
+
|
| 151 |
+
For the **d-20 model** (561M params):
|
| 152 |
+
- **Base pretraining**: ~3-4 hours
|
| 153 |
+
- **Mid-training**: ~1 hour
|
| 154 |
+
- **SFT**: ~1 hour
|
| 155 |
+
- **RL**: ~1 hour
|
| 156 |
+
- **Total**: ~6-8 hours
|
| 157 |
+
|
| 158 |
+
## Step 9: Check Results
|
| 159 |
+
|
| 160 |
+
After training completes:
|
| 161 |
+
|
| 162 |
+
```bash
|
| 163 |
+
# View the final report
|
| 164 |
+
cat ~/.cache/nanochat/report/report.md
|
| 165 |
+
|
| 166 |
+
# Check RL checkpoint
|
| 167 |
+
ls -lh ~/.cache/nanochat/checkpoints/
|
| 168 |
+
|
| 169 |
+
# View evaluation results
|
| 170 |
+
ls ~/.cache/nanochat/evals/
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
## Step 10: Download Artifacts (Optional)
|
| 174 |
+
|
| 175 |
+
If you want to save the trained model locally:
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
# From your local machine (not on the Lambda instance):
|
| 179 |
+
# Download checkpoints
|
| 180 |
+
scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/checkpoints ./local_checkpoints/
|
| 181 |
+
|
| 182 |
+
# Download reports
|
| 183 |
+
scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/report ./local_reports/
|
| 184 |
+
|
| 185 |
+
# Download training logs
|
| 186 |
+
scp ubuntu@<INSTANCE_IP>:~/nanochatAquaRat/training.log ./training.log
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## Step 11: Terminate Instance
|
| 190 |
+
|
| 191 |
+
**IMPORTANT**: Remember to terminate your instance when done to avoid charges!
|
| 192 |
+
|
| 193 |
+
### Via Web Dashboard:
|
| 194 |
+
1. Go to https://cloud.lambdalabs.com/instances
|
| 195 |
+
2. Find your instance
|
| 196 |
+
3. Click the **"..."** menu
|
| 197 |
+
4. Select **"Terminate"**
|
| 198 |
+
5. Confirm termination
|
| 199 |
+
|
| 200 |
+
### Via SSH (before disconnecting):
|
| 201 |
+
```bash
|
| 202 |
+
# NOTE: shutting down only powers the VM off — Lambda continues billing until you terminate the instance via the dashboard or API
|
| 203 |
+
sudo shutdown -h now
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
## Troubleshooting
|
| 207 |
+
|
| 208 |
+
### Issue: "Out of memory" Error
|
| 209 |
+
|
| 210 |
+
**Solution**: Reduce batch size in the training script
|
| 211 |
+
```bash
|
| 212 |
+
# Edit run_aquarat_small.sh and add these flags to the torchrun commands:
|
| 213 |
+
--device_batch_size=2 # Reduce from default
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
### Issue: W&B Not Logging
|
| 217 |
+
|
| 218 |
+
**Solution**: Check your API key
|
| 219 |
+
```bash
|
| 220 |
+
# Test W&B login
|
| 221 |
+
wandb login
|
| 222 |
+
|
| 223 |
+
# Verify environment variable
|
| 224 |
+
echo $WANDB_API_KEY
|
| 225 |
+
|
| 226 |
+
# Re-run with explicit login
|
| 227 |
+
export WANDB_API_KEY=your-key-here
|
| 228 |
+
bash run_aquarat_small.sh
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
### Issue: Screen Session Lost
|
| 232 |
+
|
| 233 |
+
**Solution**: Reattach to screen
|
| 234 |
+
```bash
|
| 235 |
+
# List all screen sessions
|
| 236 |
+
screen -ls
|
| 237 |
+
|
| 238 |
+
# Reattach to the training session
|
| 239 |
+
screen -r training
|
| 240 |
+
|
| 241 |
+
# If screen says "Detached", force attach
|
| 242 |
+
screen -d -r training
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
### Issue: Dataset Download Slow
|
| 246 |
+
|
| 247 |
+
**Solution**: The script downloads data in parallel. Wait for completion or reduce number of shards.
|
| 248 |
+
|
| 249 |
+
### Issue: SSH Connection Drops
|
| 250 |
+
|
| 251 |
+
**Solution**: Use `screen` or `tmux` to keep processes running
|
| 252 |
+
```bash
|
| 253 |
+
# If you didn't use screen initially and got disconnected:
|
| 254 |
+
# Reconnect and check if the process is still running
|
| 255 |
+
ps aux | grep python
|
| 256 |
+
|
| 257 |
+
# If running, you can monitor the log files:
|
| 258 |
+
tail -f ~/.cache/nanochat/report/report.md
|
| 259 |
+
tail -f ~/nanochatAquaRat/training.log
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
## Cost Estimation
|
| 263 |
+
|
| 264 |
+
**8x H100 SXM5** pricing (as of this writing):
|
| 265 |
+
- ~$3.00/hour per GPU
|
| 266 |
+
- 8 GPUs = $24/hour
|
| 267 |
+
- Small model (4 hours) = ~$96
|
| 268 |
+
- d-20 model (8 hours) = ~$192
|
| 269 |
+
|
| 270 |
+
**Budget-friendly testing options:**
|
| 271 |
+
- 1x A10 (24GB): ~$0.60/hour - Good for testing pipeline
|
| 272 |
+
- 1x A6000 (48GB): ~$0.80/hour - Can run small model
|
| 273 |
+
- 2x A100 (40GB): ~$2.20/hour - Can run d-20 with reduced batch size
|
| 274 |
+
|
| 275 |
+
## Quick Reference Commands
|
| 276 |
+
|
| 277 |
+
```bash
|
| 278 |
+
# SSH to instance
|
| 279 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 280 |
+
|
| 281 |
+
# Start training in screen
|
| 282 |
+
screen -S training
|
| 283 |
+
bash run_aquarat_small.sh
|
| 284 |
+
|
| 285 |
+
# Detach from screen
|
| 286 |
+
Ctrl+A then D
|
| 287 |
+
|
| 288 |
+
# Reattach to screen
|
| 289 |
+
screen -r training
|
| 290 |
+
|
| 291 |
+
# Monitor W&B
|
| 292 |
+
Open: https://wandb.ai
|
| 293 |
+
|
| 294 |
+
# View live report
|
| 295 |
+
tail -f ~/.cache/nanochat/report/report.md
|
| 296 |
+
|
| 297 |
+
# Check GPU usage
|
| 298 |
+
nvidia-smi
|
| 299 |
+
|
| 300 |
+
# Terminate instance (via dashboard)
|
| 301 |
+
https://cloud.lambdalabs.com/instances
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
## Support
|
| 305 |
+
|
| 306 |
+
- **Lambda Labs Support**: https://lambdalabs.com/support
|
| 307 |
+
- **W&B Support**: https://docs.wandb.ai
|
| 308 |
+
- **nanochat Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
|
| 309 |
+
|
| 310 |
+
## Next Steps
|
| 311 |
+
|
| 312 |
+
After your model is trained:
|
| 313 |
+
1. Download checkpoints for inference
|
| 314 |
+
2. Use the web interface: `python -m scripts.chat_web`
|
| 315 |
+
3. Test via CLI: `python -m scripts.chat_cli`
|
| 316 |
+
4. Share your results on W&B
|
| 317 |
+
5. Fine-tune on additional datasets if desired
|
QUICKSTART.md
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Lambda Labs Training Quickstart
|
| 2 |
+
|
| 3 |
+
This project provides two ways to run model training on Lambda Labs:
|
| 4 |
+
|
| 5 |
+
1. **Automated API Script** (Recommended) - Fully automated deployment
|
| 6 |
+
2. **Manual Setup** - Step-by-step web dashboard approach
|
| 7 |
+
|
| 8 |
+
## Prerequisites
|
| 9 |
+
|
| 10 |
+
Both methods require:
|
| 11 |
+
- Lambda Labs account with API key
|
| 12 |
+
- SSH key added to Lambda Labs
|
| 13 |
+
- Weights & Biases account with API key
|
| 14 |
+
- Sufficient credits (~$24/hour for 8x H100)
|
| 15 |
+
|
| 16 |
+
## Method 1: Automated API Script (Recommended)
|
| 17 |
+
|
| 18 |
+
### Setup
|
| 19 |
+
|
| 20 |
+
1. **Set environment variables:**
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
export LAMBDA_API_KEY='your-lambda-api-key'
|
| 24 |
+
export WANDB_API_KEY='your-wandb-api-key'
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
Get your Lambda API key from: https://cloud.lambdalabs.com/api-keys
|
| 28 |
+
Get your W&B API key from: https://wandb.ai/authorize
|
| 29 |
+
|
| 30 |
+
2. **Install dependencies:**
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
pip install lambda-cloud-client
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Usage
|
| 37 |
+
|
| 38 |
+
**Check available instances:**
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
python launch_lambda.py --list-types
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
**Launch and start training:**
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
# Launch 8x H100 instance (recommended for d-20 model)
|
| 48 |
+
python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --region us-west-1
|
| 49 |
+
|
| 50 |
+
# Launch smaller instance for testing (depth-8 model)
|
| 51 |
+
python launch_lambda.py --instance-type gpu_1x_a100 --region us-west-1
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
**Just launch without deploying:**
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --no-deploy
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
The script will:
|
| 61 |
+
1. ✓ Launch the instance
|
| 62 |
+
2. ✓ Wait for it to be ready
|
| 63 |
+
3. ✓ Deploy the code
|
| 64 |
+
4. ✓ Start training in a screen session
|
| 65 |
+
5. ✓ Provide connection details
|
| 66 |
+
|
| 67 |
+
### Monitor Training
|
| 68 |
+
|
| 69 |
+
After launching, SSH to the instance:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
Then attach to the screen session:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
screen -r training
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
Or view logs:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
tail -f ~/nanochatAquaRat/training.log
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## Method 2: Manual Setup
|
| 88 |
+
|
| 89 |
+
For detailed step-by-step instructions, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md)
|
| 90 |
+
|
| 91 |
+
**Quick summary:**
|
| 92 |
+
1. Go to https://cloud.lambdalabs.com/instances
|
| 93 |
+
2. Launch instance manually
|
| 94 |
+
3. SSH to instance
|
| 95 |
+
4. Clone repo and set up .env
|
| 96 |
+
5. Run `bash run_aquarat_small.sh`
|
| 97 |
+
|
| 98 |
+
## Training Configuration
|
| 99 |
+
|
| 100 |
+
The `run_aquarat_small.sh` script trains a **depth-8 (smaller) model** which takes approximately **3-4 hours** on 8x H100.
|
| 101 |
+
|
| 102 |
+
### What Gets Trained:
|
| 103 |
+
|
| 104 |
+
1. **Base Model** (depth-8, ~60M params)
|
| 105 |
+
- Pretrained on limited corpus (24 shards)
|
| 106 |
+
- Faster iteration for testing
|
| 107 |
+
|
| 108 |
+
2. **Mid-Training**
|
| 109 |
+
- Conversation format adaptation
|
| 110 |
+
- Tool use capabilities
|
| 111 |
+
|
| 112 |
+
3. **Supervised Fine-Tuning (SFT)**
|
| 113 |
+
- Fine-tuned on AQuA-RAT dataset
|
| 114 |
+
- Multiple-choice math reasoning
|
| 115 |
+
|
| 116 |
+
4. **Reinforcement Learning (RL)**
|
| 117 |
+
- GRPO-style RL on AQuA-RAT
|
| 118 |
+
- KL divergence tracking
|
| 119 |
+
- Letter-choice logit margin analysis
|
| 120 |
+
- Attention mechanism logging
|
| 121 |
+
|
| 122 |
+
### W&B Metrics Logged:
|
| 123 |
+
|
| 124 |
+
- `rl/acc` - Answer accuracy
|
| 125 |
+
- `rl/mean_reward` - Average reward
|
| 126 |
+
- `rl/kl_letter_mean` - Policy drift (letter-level)
|
| 127 |
+
- `rl/kl_sequence_mean` - Policy drift (sequence-level)
|
| 128 |
+
- `rl/letter_margin_mean` - Confidence in answers
|
| 129 |
+
- `attn/entropy_mean` - Attention patterns
|
| 130 |
+
|
| 131 |
+
## Model Sizes Available
|
| 132 |
+
|
| 133 |
+
You can modify `run_aquarat_small.sh` to change the model depth:
|
| 134 |
+
|
| 135 |
+
| Depth | Params | Training Time | Recommended Instance |
|
| 136 |
+
|-------|--------|---------------|---------------------|
|
| 137 |
+
| 8 | ~60M | 3-4 hours | 1x A100 / 2x A100 |
|
| 138 |
+
| 12 | ~180M | 4-5 hours | 4x A100 |
|
| 139 |
+
| 20 | ~561M | 6-8 hours | 8x H100 |
|
| 140 |
+
| 26 | ~1.1B | 10-12 hours | 8x H100 |
|
| 141 |
+
|
| 142 |
+
To change depth, edit the `--depth` parameter in `run_aquarat_small.sh`:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Cost Estimates
|
| 149 |
+
|
| 150 |
+
Based on Lambda Labs pricing:
|
| 151 |
+
|
| 152 |
+
| Instance Type | GPUs | Cost/Hour | Small (4h) | d-20 (8h) |
|
| 153 |
+
|------------------|--------------|-----------|------------|-----------|
|
| 154 |
+
| gpu_8x_h100_sxm5 | 8x H100 80GB | ~$24.00 | ~$96 | ~$192 |
|
| 155 |
+
| gpu_4x_a100 | 4x A100 40GB | ~$8.80 | ~$35 | ~$70 |
|
| 156 |
+
| gpu_2x_a100 | 2x A100 40GB | ~$4.40 | ~$18 | ~$35 |
|
| 157 |
+
| gpu_1x_a100 | 1x A100 40GB | ~$2.20 | ~$9 | ~$18 |
|
| 158 |
+
|
| 159 |
+
## Monitoring Options
|
| 160 |
+
|
| 161 |
+
### 1. SSH + Screen
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 165 |
+
screen -r training
|
| 166 |
+
# Ctrl+A then D to detach
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
### 2. Weights & Biases
|
| 170 |
+
|
| 171 |
+
Dashboard: https://wandb.ai
|
| 172 |
+
|
| 173 |
+
Real-time metrics, attention heatmaps, sample completions
|
| 174 |
+
|
| 175 |
+
### 3. Log Files
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
# Training log
|
| 179 |
+
tail -f ~/nanochatAquaRat/training.log
|
| 180 |
+
|
| 181 |
+
# Progress report
|
| 182 |
+
tail -f ~/.cache/nanochat/report/report.md
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## After Training
|
| 186 |
+
|
| 187 |
+
### Download Checkpoints
|
| 188 |
+
|
| 189 |
+
From your local machine:
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
scp -r ubuntu@<INSTANCE_IP>:~/.cache/nanochat/checkpoints ./checkpoints/
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
### Run Inference
|
| 196 |
+
|
| 197 |
+
On the Lambda instance:
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
# Web interface
|
| 201 |
+
python -m scripts.chat_web
|
| 202 |
+
|
| 203 |
+
# CLI interface
|
| 204 |
+
python -m scripts.chat_cli -p "What is 25 * 37?"
|
| 205 |
+
|
| 206 |
+
# Evaluate on test set
|
| 207 |
+
python -m scripts.chat_eval -- -i rl -a AQUA
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
### Don't Forget to Terminate!
|
| 211 |
+
|
| 212 |
+
**Via Dashboard:**
|
| 213 |
+
https://cloud.lambdalabs.com/instances → Terminate
|
| 214 |
+
|
| 215 |
+
**Via CLI:**
|
| 216 |
+
```bash
|
| 217 |
+
sudo shutdown -h now
# NOTE: shutting down only powers off the VM — Lambda Labs keeps billing
# until you terminate the instance from the dashboard or API.
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Troubleshooting
|
| 221 |
+
|
| 222 |
+
### Issue: API Key Not Working
|
| 223 |
+
|
| 224 |
+
```bash
|
| 225 |
+
# Verify keys are set
|
| 226 |
+
echo $LAMBDA_API_KEY
|
| 227 |
+
echo $WANDB_API_KEY
|
| 228 |
+
|
| 229 |
+
# Re-export if needed
|
| 230 |
+
export LAMBDA_API_KEY='your-key'
|
| 231 |
+
export WANDB_API_KEY='your-key'
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
### Issue: No Available Instances
|
| 235 |
+
|
| 236 |
+
Lambda Labs instances can be in high demand. Try:
|
| 237 |
+
- Different regions (`--region us-east-1`)
|
| 238 |
+
- Smaller instance types (`gpu_1x_a100`)
|
| 239 |
+
- Check availability: `python launch_lambda.py --list-types`
|
| 240 |
+
|
| 241 |
+
### Issue: Out of Memory
|
| 242 |
+
|
| 243 |
+
Edit `run_aquarat_small.sh` and reduce batch size:
|
| 244 |
+
|
| 245 |
+
```bash
|
| 246 |
+
# Add to torchrun commands:
|
| 247 |
+
--device_batch_size=2
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
### Issue: Training Stuck
|
| 251 |
+
|
| 252 |
+
Check GPU utilization:
|
| 253 |
+
|
| 254 |
+
```bash
|
| 255 |
+
nvidia-smi
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
If GPUs are idle, check for errors:
|
| 259 |
+
|
| 260 |
+
```bash
|
| 261 |
+
tail -100 ~/nanochatAquaRat/training.log
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
## Files in This Repository
|
| 265 |
+
|
| 266 |
+
- `launch_lambda.py` - Automated Lambda Labs launcher
|
| 267 |
+
- `run_aquarat_small.sh` - Training script (depth-8 model)
|
| 268 |
+
- `LAMBDA_MANUAL_SETUP.md` - Detailed manual setup guide
|
| 269 |
+
- `QUICKSTART.md` - This file
|
| 270 |
+
- `.env.template` - Environment variable template
|
| 271 |
+
|
| 272 |
+
## Support
|
| 273 |
+
|
| 274 |
+
- **Lambda Labs**: https://lambdalabs.com/support
|
| 275 |
+
- **Weights & Biases**: https://docs.wandb.ai
|
| 276 |
+
- **Project Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
|
| 277 |
+
|
| 278 |
+
## Next Steps
|
| 279 |
+
|
| 280 |
+
1. ✓ Set up Lambda Labs account and API key
|
| 281 |
+
2. ✓ Set up Weights & Biases account
|
| 282 |
+
3. ✓ Choose your method (API script or manual)
|
| 283 |
+
4. ✓ Launch instance and start training
|
| 284 |
+
5. ✓ Monitor via W&B dashboard
|
| 285 |
+
6. ✓ Download checkpoints when complete
|
| 286 |
+
7. ✓ Terminate instance to stop charges
|
| 287 |
+
|
| 288 |
+
Happy training!
|
README.md
CHANGED
|
@@ -13,440 +13,440 @@ tags:
|
|
| 13 |
---
|
| 14 |
|
| 15 |
<div align="center">
|
| 16 |
-
|
| 17 |
-

|
| 18 |
-
|
| 19 |
-
# nanochatAquaRat
|
| 20 |
-
|
| 21 |
-
**Training Language Models with Reinforcement Learning on Mathematical Reasoning**
|
| 22 |
-
|
| 23 |
-
[](https://github.com/HarleyCoops/nanochatAquaRat)
|
| 24 |
-
[](LICENSE)
|
| 25 |
-
[](https://www.python.org/downloads/)
|
| 26 |
-
|
| 27 |
-
A modified version of [nanochat](https://github.com/karpathy/nanochat) trained with reinforcement learning on the [DeepMind AQuA-RAT dataset](https://huggingface.co/datasets/deepmind/aqua_rat) for algebraic reasoning and multiple-choice problem solving.
|
| 28 |
-
|
| 29 |
-
[Quick Start](#quick-start) • [Dataset](#dataset-structure) • [Modifications](#modifications-from-base-nanochat) • [Training](#training-pipeline) • [Results](#results)
|
| 30 |
-
|
| 31 |
-
</div>
|
| 32 |
-
|
| 33 |
-
---
|
| 34 |
-
|
| 35 |
-
## Table of Contents
|
| 36 |
-
|
| 37 |
-
- [Overview](#overview)
|
| 38 |
-
- [The Base: nanochat Framework](#the-base-nanochat-framework)
|
| 39 |
-
- [Dataset Structure](#dataset-structure)
|
| 40 |
-
- [Modifications from Base nanochat](#modifications-from-base-nanochat)
|
| 41 |
-
- [Training Pipeline](#training-pipeline)
|
| 42 |
-
- [Quick Start](#quick-start)
|
| 43 |
-
- [File Structure](#file-structure)
|
| 44 |
-
- [Monitoring & Visualization](#monitoring--visualization)
|
| 45 |
-
- [Results](#results)
|
| 46 |
-
|
| 47 |
-
---
|
| 48 |
-
|
| 49 |
-
## Overview
|
| 50 |
-
|
| 51 |
-
This project adapts the **nanochat** training framework (originally designed for GSM8K numerical reasoning) to work with **AQuA-RAT** (Algebra Question Answering with Rationales), a dataset of ~97,000 algebraic word problems with multiple-choice answers (A-E) and natural language solution rationales.
|
| 52 |
-
|
| 53 |
-
### Why This Matters
|
| 54 |
-
|
| 55 |
-
- **Domain Transfer**: Demonstrates how to adapt a mathematical reasoning pipeline from free-form numeric answers to multiple-choice format
|
| 56 |
-
- **RL on Math**: Implements GRPO-style reinforcement learning with reward shaping for categorical outputs
|
| 57 |
-
- **Mechanistic Interpretability**: Integrates attention analysis during training to understand model reasoning patterns
|
| 58 |
-
- **Production-Ready**: Includes automated Lambda Labs and Hyperbolic Labs deployment helpers for cloud GPU training
|
| 59 |
-
|
| 60 |
-
### Key Results
|
| 61 |
-
|
| 62 |
-
| Model | Parameters | Training Time | AQuA-RAT Dev Accuracy |
|
| 63 |
-
|-------|------------|---------------|----------------------|
|
| 64 |
-
| depth-8 | ~60M | 3-4 hours | 30-50% |
|
| 65 |
-
| depth-20 | ~561M | 6-8 hours | 40-60% |
|
| 66 |
-
|
| 67 |
-
---
|
| 68 |
-
|
| 69 |
-
## The Base: nanochat Framework
|
| 70 |
-
|
| 71 |
-
**nanochat** is a minimalist yet complete pipeline for training transformer language models from scratch, created by Andrej Karpathy. It implements:
|
| 72 |
-
|
| 73 |
-
- **Custom tokenizer**: BPE tokenizer written in Rust for performance
|
| 74 |
-
- **Training stages**: Pretraining → Mid-training → SFT → RL
|
| 75 |
-
- **Evaluation suite**: CORE benchmarks and task-specific metrics
|
| 76 |
-
- **Optimizations**: Memory-efficient training, gradient accumulation, distributed training
|
| 77 |
-
|
| 78 |
-
**Original focus**: Training on GSM8K (Grade School Math 8K) with free-form numeric answers.
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
---
|
| 82 |
-
|
| 83 |
-
## Dataset Structure
|
| 84 |
-
|
| 85 |
-
### AQuA-RAT Format
|
| 86 |
-
|
| 87 |
-
The [DeepMind AQuA-RAT dataset](https://github.com/deepmind/AQuA) contains algebraic reasoning problems in JSON format:
|
| 88 |
-
|
| 89 |
-
```json
|
| 90 |
-
{
|
| 91 |
-
"question": "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?",
|
| 92 |
-
"options": [
|
| 93 |
-
"A) 53 km",
|
| 94 |
-
"B) 55 km",
|
| 95 |
-
"C) 52 km",
|
| 96 |
-
"D) 60 km",
|
| 97 |
-
"E) 50 km"
|
| 98 |
-
],
|
| 99 |
-
"rationale": "The distance that the person traveled = 20 * 2.5 = 50 km. Answer: E",
|
| 100 |
-
"correct": "E"
|
| 101 |
-
}
|
| 102 |
-
```
|
| 103 |
-
|
| 104 |
-
**Dataset splits**:
|
| 105 |
-
- Training: 97,467 problems
|
| 106 |
-
- Development: 254 problems
|
| 107 |
-
- Test: 254 problems
|
| 108 |
-
|
| 109 |
-
**Key characteristics**:
|
| 110 |
-
- Multiple-choice (A-E) format
|
| 111 |
-
- Algebraic word problems
|
| 112 |
-
- Natural language rationales
|
| 113 |
-
- Topics: arithmetic, algebra, geometry, probability
|
| 114 |
-
|
| 115 |
-
### Comparison: GSM8K vs AQuA-RAT
|
| 116 |
-
|
| 117 |
-
| Aspect | GSM8K (Original) | AQuA-RAT (This Project) |
|
| 118 |
-
|--------|------------------|-------------------------|
|
| 119 |
-
| **Format** | Free-form numeric | Multiple choice (A-E) |
|
| 120 |
-
| **Answer** | Single number | Letter choice |
|
| 121 |
-
| **Size** | 8,500 problems | 97,975 problems |
|
| 122 |
-
| **Difficulty** | Elementary school | High school algebra |
|
| 123 |
-
| **Rationale** | Step-by-step | Natural language |
|
| 124 |
-
| **Evaluation** | Exact match on number | Categorical accuracy |
|
| 125 |
-
|
| 126 |
-
---
|
| 127 |
-
|
| 128 |
-
## Modifications from Base nanochat
|
| 129 |
-
|
| 130 |
-
To adapt nanochat from GSM8K to AQuA-RAT, we modified the following components:
|
| 131 |
-
|
| 132 |
-
### 1. Dataset Loader (`scripts/prepare_aqua.py`)
|
| 133 |
-
|
| 134 |
-
**Created new file** to download and format AQuA-RAT:
|
| 135 |
-
|
| 136 |
-
```python
|
| 137 |
-
# New file: scripts/prepare_aqua.py
|
| 138 |
-
### 1. Dataset Preparation (`scripts/prepare_aqua.py`)
|
| 139 |
-
|
| 140 |
-
- Uses `datasets.load_dataset("deepmind/aqua_rat")` and optionally caps split sizes.
|
| 141 |
-
- Emits JSONL files (`train.jsonl`, `validation.jsonl`, `test.jsonl`) compatible with
|
| 142 |
-
the conversation schema used throughout nanochat.
|
| 143 |
-
- Defaults to `~/.cache/nanochat/aqua`, but accepts `--output_dir` overrides so
|
| 144 |
-
launchers can bundle their own artifact.
|
| 145 |
-
|
| 146 |
-
```python
|
| 147 |
-
def format_example(row):
|
| 148 |
-
options = row["options"]
|
| 149 |
-
assistant_content = [
|
| 150 |
-
{"type": "text", "text": row["rationale"].strip()},
|
| 151 |
-
{"type": "text", "text": f"Answer: {row['correct'].strip().upper()}"},
|
| 152 |
-
]
|
| 153 |
-
return {
|
| 154 |
-
"messages": [
|
| 155 |
-
{"role": "user", "content": _render_user_prompt(row["question"], options)},
|
| 156 |
-
{"role": "assistant", "content": assistant_content},
|
| 157 |
-
],
|
| 158 |
-
"letters": letters,
|
| 159 |
-
"answer_letter": correct,
|
| 160 |
-
}
|
| 161 |
-
```
|
| 162 |
-
|
| 163 |
-
### 2. Task Module (`tasks/aqua.py`)
|
| 164 |
-
|
| 165 |
-
- Accepts optional `data_dir` (or `AQUA_DATA_DIR` / `NANOCHAT_AQUA_DIR`) so the task
|
| 166 |
-
can read the cached JSONL; otherwise falls back to Hugging Face.
|
| 167 |
-
- Provides `_render_user_prompt` to format the question/options using the common
|
| 168 |
-
multiple-choice helper and `_extract_letter` to score completions.
|
| 169 |
-
- Returns conversations whose assistant messages include both the rationale and a
|
| 170 |
-
final `Answer: <LETTER>` line for SFT, while `evaluate()` only cares about the letter.
|
| 171 |
-
|
| 172 |
-
```python
|
| 173 |
-
def _extract_letter(text, default=None):
|
| 174 |
-
answer_match = re.search(r"answer\s*[:\-]\s*([A-E])", text, flags=re.IGNORECASE)
|
| 175 |
-
if answer_match:
|
| 176 |
-
return answer_match.group(1).upper()
|
| 177 |
-
match = LETTER_RE.search(text)
|
| 178 |
-
return match.group(1).upper() if match else default
|
| 179 |
-
```
|
| 180 |
-
|
| 181 |
-
**Key differences from GSM8K**:
|
| 182 |
-
- Numeric extraction → Letter extraction
|
| 183 |
-
- Free-form answer → Fixed choices A-E
|
| 184 |
-
- Exact number match → Categorical match
|
| 185 |
-
|
| 186 |
-
### 3. RL Training (`scripts/chat_rl.py`)
|
| 187 |
-
|
| 188 |
-
**Modified** to support both GSM8K and AQuA-RAT:
|
| 189 |
-
|
| 190 |
-
Key updates:
|
| 191 |
-
|
| 192 |
-
- `train_task` / `val_task` now instantiate `AQUA(...)` instead of `GSM8K(...)`.
|
| 193 |
-
- Rewards reuse the task's `evaluate()` helper so any completion containing
|
| 194 |
-
“Answer: X” (or the first bare letter) is scored correctly.
|
| 195 |
-
- The validation helper became `run_aqua_eval`, still reporting pass@k accuracy
|
| 196 |
-
across sampled completions.
|
| 197 |
-
- CLI overrides remain the same because the script continues to rely on the
|
| 198 |
-
nanochat configurator (`--run`, `--temperature`, `--max_new_tokens`, …).
|
| 199 |
-
|
| 200 |
-
### 4. Evaluation (`scripts/chat_eval.py`)
|
| 201 |
-
|
| 202 |
-
- Registered `'AQUA'` in the task registry so `-a AQUA` just works.
|
| 203 |
-
- Added a 20% random-guess baseline when aggregating the ChatCORE metric.
|
| 204 |
-
- The categorical evaluation path reuses `run_categorical_eval`, clamping logits
|
| 205 |
-
to the available letters before scoring.
|
| 206 |
-
|
| 207 |
-
### 5. Training Script (`run_aquarat_small.sh`)
|
| 208 |
-
|
| 209 |
-
**What changed vs upstream nanochat**:
|
| 210 |
-
|
| 211 |
-
```bash
|
| 212 |
-
# (Optional) Cache the dataset locally as JSONL
|
| 213 |
-
python -m scripts.prepare_aqua --output_dir "$NANOCHAT_BASE_DIR/aqua"
|
| 214 |
-
|
| 215 |
-
# Mid-training now samples from the AQuA mixture
|
| 216 |
-
torchrun -m scripts.mid_train -- --run=demo --num_iterations=200
|
| 217 |
-
|
| 218 |
-
# SFT stage emphasises AQuA problems
|
| 219 |
-
torchrun -m scripts.sft_train -- --run=demo --aqua_train_examples=20000
|
| 220 |
-
|
| 221 |
-
# RL fine-tuning rewards the correct letter on AQuA-RAT
|
| 222 |
-
torchrun -m scripts.chat_rl -- --run=demo --temperature=0.7 --max_new_tokens=64
|
| 223 |
-
```
|
| 224 |
-
|
| 225 |
-
- **`tasks/aqua.py`** loads AQuA-RAT either from Hugging Face or the cached JSONL
|
| 226 |
-
splits, formats questions as conversations, and scores completions by letter.
|
| 227 |
-
- **`scripts/mid_train.py`** extends the original Reasoning+Chat mixture with a
|
| 228 |
-
50k slice of AQuA so the model sees multiple-choice algebra earlier.
|
| 229 |
-
- **`scripts/chat_sft.py`** replaces the GSM8K component with AQuA, keeping ARC,
|
| 230 |
-
SmolTalk, and identity prompts for general chat coverage.
|
| 231 |
-
- **`scripts/chat_rl.py`** retools the GRPO loop to sample, reward, and evaluate
|
| 232 |
-
AQuA answers (categorical accuracy instead of GSM8K free-form math).
|
| 233 |
-
- **`scripts/chat_eval.py`** registers the new AQuA task so `chat_eval` can report
|
| 234 |
-
categorical accuracy alongside ARC/MMLU/GSM8K/HumanEval.
|
| 235 |
-
|
| 236 |
-
---
|
| 237 |
-
|
| 238 |
-
## Training Pipeline
|
| 239 |
-
|
| 240 |
-
### Stage 1: Base Pretraining (50-60% of time)
|
| 241 |
-
|
| 242 |
-
**What happens**: Model learns language from scratch on FineWeb corpus
|
| 243 |
-
|
| 244 |
-
```bash
|
| 245 |
-
torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=8
|
| 246 |
-
```
|
| 247 |
-
|
| 248 |
-
**Duration**: 1.5-2 hours on 8x H100
|
| 249 |
-
**Output**: Base checkpoint with general language understanding
|
| 250 |
-
**Metrics**: Validation loss, CORE benchmark scores
|
| 251 |
-
|
| 252 |
-
### Stage 2: Mid-Training (12-15% of time)
|
| 253 |
-
|
| 254 |
-
**What happens**: Teach conversation format and special tokens
|
| 255 |
-
|
| 256 |
-
```bash
|
| 257 |
-
torchrun --nproc_per_node=8 -m scripts.mid_train
|
| 258 |
-
```
|
| 259 |
-
|
| 260 |
-
**Duration**: 30 minutes
|
| 261 |
-
**Output**: Conversational checkpoint
|
| 262 |
-
**Metrics**: Format adherence, tool use capability
|
| 263 |
-
|
| 264 |
-
### Stage 3: Supervised Fine-Tuning (12-15% of time)
|
| 265 |
-
|
| 266 |
-
**What happens**: Fine-tune on AQuA-RAT with ground-truth solutions
|
| 267 |
-
|
| 268 |
-
```bash
|
| 269 |
-
torchrun --nproc_per_node=8 -m scripts.sft_train -- \
|
| 270 |
-
--aqua_train_examples=20000 \
|
| 271 |
-
--aqua_val_examples=254
|
| 272 |
-
```
|
| 273 |
-
|
| 274 |
-
**Duration**: 30 minutes
|
| 275 |
-
**Output**: AQuA-tuned checkpoint
|
| 276 |
-
**Metrics**: Dev set accuracy (categorical)
|
| 277 |
-
|
| 278 |
-
### Stage 4: Reinforcement Learning (12-15% of time)
|
| 279 |
-
|
| 280 |
-
**What happens**: Policy gradient learning with GRPO algorithm
|
| 281 |
-
|
| 282 |
-
```bash
|
| 283 |
-
torchrun --nproc_per_node=1 -m scripts.chat_rl -- \
|
| 284 |
-
--temperature=0.7 \
|
| 285 |
-
--max_new_tokens=64
|
| 286 |
-
```
|
| 287 |
-
|
| 288 |
-
**Duration**: 30 minutes
|
| 289 |
-
**Algorithm**: Group Relative Policy Optimization (GRPO)
|
| 290 |
-
**Reward**: +1.0 for correct letter, +0.1 for valid letter format
|
| 291 |
-
**Output**: RL-optimized checkpoint
|
| 292 |
-
|
| 293 |
-
**Logged metrics**:
|
| 294 |
-
- `rl/acc` - Accuracy on training samples
|
| 295 |
-
- `rl/mean_reward` - Average reward per generation
|
| 296 |
-
- `rl/kl_letter_mean` - KL divergence at decision point
|
| 297 |
-
- `rl/kl_sequence_mean` - Full sequence KL
|
| 298 |
-
- `rl/letter_margin_mean` - Confidence (logit gap)
|
| 299 |
-
- `attn/entropy_mean` - Attention mechanism patterns
|
| 300 |
-
|
| 301 |
-
---
|
| 302 |
-
|
| 303 |
-
## Quick Start
|
| 304 |
-
|
| 305 |
-
### Repo Setup & Rust Toolchain
|
| 306 |
-
|
| 307 |
-
- Clone with submodules so the `rustbpe` tokenizer sources are present:
|
| 308 |
-
```bash
|
| 309 |
-
git clone --recurse-submodules https://github.com/HarleyCoops/nanochatAquaRat.git
|
| 310 |
-
```
|
| 311 |
-
For existing clones run `git submodule update --init --recursive` before building.
|
| 312 |
-
- Install Rust (needed for the tokenizer build). On Linux/macOS follow [https://rustup.rs](https://rustup.rs). On Windows, after installing rustup, ensure the toolchain is MSVC x86\_64 and the cargo bin directory is on `PATH`:
|
| 313 |
-
```powershell
|
| 314 |
-
$env:Path += ";$env:USERPROFILE\.cargo\bin"
|
| 315 |
-
setx PATH "$env:Path"
|
| 316 |
-
setx CARGO_HOME "$env:USERPROFILE\.cargo"
|
| 317 |
-
setx RUSTUP_HOME "$env:USERPROFILE\.rustup"
|
| 318 |
-
rustup set default-host x86_64-pc-windows-msvc
|
| 319 |
-
rustup default stable-x86_64-pc-windows-msvc
|
| 320 |
-
cargo --version
|
| 321 |
-
rustup --version
|
| 322 |
-
```
|
| 323 |
-
- Build the tokenizer once per machine:
|
| 324 |
-
```bash
|
| 325 |
-
uv run maturin develop
|
| 326 |
-
```
|
| 327 |
-
|
| 328 |
-
### Option 1: Lambda Labs Cloud (Automated)
|
| 329 |
-
|
| 330 |
-
Use the automation helper for one-command deployment:
|
| 331 |
-
|
| 332 |
-
```bash
|
| 333 |
-
# Set credentials
|
| 334 |
-
export LAMBDA_API_KEY='your-lambda-api-key'
|
| 335 |
-
export WANDB_API_KEY='your-wandb-api-key'
|
| 336 |
-
|
| 337 |
-
# Launch with auto-start
|
| 338 |
-
python scripts/launch_lambda_training.py \
|
| 339 |
-
--ssh-key-name your_lambda_ssh_key \
|
| 340 |
-
--instance-type gpu_8x_h100_sxm5 \
|
| 341 |
-
--region us-west-1 \
|
| 342 |
-
--auto-start \
|
| 343 |
-
--inject-env WANDB_API_KEY
|
| 344 |
-
```
|
| 345 |
-
|
| 346 |
-
The script provisions the instance, clones this repository, sets up environment variables, and starts training in a tmux session.
|
| 347 |
-
|
| 348 |
-
**Monitor training**:
|
| 349 |
-
```bash
|
| 350 |
-
# SSH to instance
|
| 351 |
-
ssh ubuntu@<INSTANCE_IP>
|
| 352 |
-
|
| 353 |
-
# Attach to tmux session
|
| 354 |
-
tmux attach -t nanochat-train
|
| 355 |
-
|
| 356 |
-
# Or view logs
|
| 357 |
-
tail -f ~/nanochatAquaRat/training.log
|
| 358 |
-
```
|
| 359 |
-
|
| 360 |
-
### Option 2: Hyperbolic Labs Cloud (Automated)
|
| 361 |
-
|
| 362 |
-
Spin up on-demand GPUs via Hyperbolic's marketplace API:
|
| 363 |
-
|
| 364 |
-
```bash
|
| 365 |
-
# Set credentials
|
| 366 |
-
export HYPERBOLIC_API_KEY='your-hyperbolic-api-key'
|
| 367 |
-
export WANDB_API_KEY='your-wandb-api-key'
|
| 368 |
-
|
| 369 |
-
# Launch with auto-start
|
| 370 |
-
python scripts/launch_hyperbolic_training.py \
|
| 371 |
-
--gpu-count 1 \
|
| 372 |
-
--region us-east \
|
| 373 |
-
--auto-start \
|
| 374 |
-
--inject-env WANDB_API_KEY
|
| 375 |
-
```
|
| 376 |
-
|
| 377 |
-
The launcher discovers an available node (respecting `--region`, `--supplier`, or `--max-price` filters), provisions it, copies your `.env`, and optionally starts training in tmux. Use `--list` to inspect available marketplace inventory without launching.
|
| 378 |
-
|
| 379 |
-
### Option 3: Lambda Labs Cloud (Manual)
|
| 380 |
-
|
| 381 |
-
For step-by-step control, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md).
|
| 382 |
-
|
| 383 |
-
**Quick summary**:
|
| 384 |
-
1. Launch instance at https://cloud.lambdalabs.com/instances
|
| 385 |
-
2. SSH to instance: `ssh ubuntu@<IP>`
|
| 386 |
-
3. Clone repo: `git clone <repo-url> && cd nanochatAquaRat`
|
| 387 |
-
4. Set up credentials: `echo "WANDB_API_KEY=..." > .env`
|
| 388 |
-
5. Run training: `bash run_aquarat_small.sh`
|
| 389 |
-
|
| 390 |
-
### Option 4: Hyperbolic VM (Manual)
|
| 391 |
-
|
| 392 |
-
For marketplace nodes without automation access, follow this lightweight bootstrap:
|
| 393 |
-
|
| 394 |
-
1. Provision a GPU VM from the Hyperbolic console and copy the SSH command (including `-p <port>` and username).
|
| 395 |
-
2. SSH in and install prerequisites:
|
| 396 |
-
```bash
|
| 397 |
-
sudo apt-get update
|
| 398 |
-
sudo apt-get install -y git curl unzip build-essential python3 python3-venv tmux
|
| 399 |
-
git clone https://github.com/HarleyCoops/nanochatAquaRat.git
|
| 400 |
-
cd nanochatAquaRat
|
| 401 |
-
```
|
| 402 |
-
3. Create `.env` with the required keys (WANDB, GCS bucket, AQUA path) and upload your GCP service-account JSON to the VM, e.g. `scp -P <port> C:\path\to\credentials.json user@<ip>:/home/user/gcp-sa.json`.
|
| 403 |
-
4. Install tooling and build the tokenizer:
|
| 404 |
-
```bash
|
| 405 |
-
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 406 |
-
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
|
| 407 |
-
source "$HOME/.cargo/env"
|
| 408 |
-
export PATH="$HOME/.local/bin:$PATH"
|
| 409 |
-
uv venv && uv sync --extra gpu
|
| 410 |
-
source .venv/bin/activate
|
| 411 |
-
uv run maturin develop
|
| 412 |
-
uv run python -m scripts.tok_train
|
| 413 |
-
```
|
| 414 |
-
5. Install the Google Cloud SDK, authenticate, and stage the cached AQuA splits (or regenerate them):
|
| 415 |
-
```bash
|
| 416 |
-
curl -sSL https://sdk.cloud.google.com | bash
|
| 417 |
-
source "$HOME/.bashrc"
|
| 418 |
-
gcloud auth login --no-launch-browser
|
| 419 |
-
gcloud config set project <your-project-id>
|
| 420 |
-
gcloud storage cp gs://nanochat-aquarat-datasets/datasets/aqua/aqua_cache.zip .
|
| 421 |
-
unzip -o aqua_cache.zip -d ~/aqua_cache
|
| 422 |
-
export AQUA_DATA_DIR=$HOME/aqua_cache
|
| 423 |
-
```
|
| 424 |
-
6. Fetch the identity conversation bundle (required for SFT) and the evaluation bundle once so CORE metrics don’t fail:
|
| 425 |
-
```bash
|
| 426 |
-
cd ~/.cache/nanochat
|
| 427 |
-
curl -L -o identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
| 428 |
-
curl -L -o eval_bundle.zip https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
| 429 |
-
unzip -q eval_bundle.zip && rm eval_bundle.zip
|
| 430 |
-
cd ~/nanochatAquaRat
|
| 431 |
-
```
|
| 432 |
-
7. Launch the desired script, e.g. `CUDA_VISIBLE_DEVICES=0 bash run_aquarat_lite.sh` or the full `run_aquarat_small.sh`.
|
| 433 |
-
8. Monitor training via tmux/W&B and terminate the VM from Hyperbolic when the run finishes to stop billing.
|
| 434 |
-
|
| 435 |
-
### Alternative: Simplified Launcher Script
|
| 436 |
-
|
| 437 |
-
A simplified launcher is also available:
|
| 438 |
-
|
| 439 |
-
```bash
|
| 440 |
-
export LAMBDA_API_KEY='your-key'
|
| 441 |
-
export WANDB_API_KEY='your-key'
|
| 442 |
-
|
| 443 |
-
python launch_lambda.py \
|
| 444 |
-
--instance-type gpu_8x_h100_sxm5 \
|
| 445 |
-
--region us-west-1
|
| 446 |
-
```
|
| 447 |
-
|
| 448 |
-
See [QUICKSTART.md](QUICKSTART.md) for details.
|
| 449 |
-
|
| 450 |
### Option 5: Local/Custom Setup
|
| 451 |
|
| 452 |
```bash
|
|
@@ -464,175 +464,191 @@ bash run_aquarat_small.sh
|
|
| 464 |
- 40GB+ GPU memory per GPU
|
| 465 |
- ~100GB disk space
|
| 466 |
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
## File Structure
|
| 470 |
|
| 471 |
-
|
| 472 |
-
nanochatAquaRat/
|
| 473 |
-
├── nanochat/… # Vendored upstream nanochat package
|
| 474 |
-
├── scripts/
|
| 475 |
-
│ ├── base_train.py # Base pretraining stage
|
| 476 |
-
│ ├── mid_train.py # Mid-training (now includes AQuA)
|
| 477 |
-
│ ├── chat_sft.py # Chat SFT pipeline
|
| 478 |
-
│ ├── sft_train.py # Shim so `-m scripts.sft_train` still works
|
| 479 |
-
│ ├── chat_rl.py # Reinforcement learning on AQuA-RAT
|
| 480 |
-
│ ├── chat_eval.py # Evaluation harness (adds AQuA task)
|
| 481 |
-
│ ├── prepare_aqua.py # AQuA-RAT JSONL exporter
|
| 482 |
-
│ ├── launch_lambda_training.py # Lambda Labs automation
|
| 483 |
-
│ ├── launch_hyperbolic_training.py # Hyperbolic Labs automation
|
| 484 |
-
│ └── upload_to_gcs.sh # Artifact helper
|
| 485 |
-
├── tasks/
|
| 486 |
-
│ ├── aqua.py # AQuA-RAT task implementation
|
| 487 |
-
│ ├── arc.py / gsm8k.py / mmlu.py # Other reasoning tasks
|
| 488 |
-
│ └── …
|
| 489 |
-
├── run_aquarat_small.sh # End-to-end orchestration
|
| 490 |
-
├── pyproject.toml / uv.lock # Environment definitions
|
| 491 |
-
└── README.md
|
| 492 |
-
```
|
| 493 |
-
### Summary of Code Changes
|
| 494 |
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
| `launch_lambda.py` / `scripts/launch_lambda_training.py` | EXISTING | Lambda Labs support retained |
|
| 507 |
|
| 508 |
---
|
| 509 |
|
| 510 |
-
##
|
| 511 |
-
|
| 512 |
-
All metrics stream to [Weights & Biases](https://wandb.ai) in real-time:
|
| 513 |
-
|
| 514 |
-
**Training Metrics**:
|
| 515 |
-
- Loss curves (pretraining, SFT, RL)
|
| 516 |
-
- Learning rate schedules
|
| 517 |
-
- Gradient norms
|
| 518 |
-
|
| 519 |
-
**RL Metrics**:
|
| 520 |
-
- Policy performance (accuracy, rewards)
|
| 521 |
-
- KL divergence from initial policy
|
| 522 |
-
- Letter-choice distributions (A-E)
|
| 523 |
-
- Confidence margins
|
| 524 |
-
|
| 525 |
-
**Interpretability**:
|
| 526 |
-
- Attention heatmaps per layer
|
| 527 |
-
- Entropy evolution across training
|
| 528 |
-
- Token-level attention weights
|
| 529 |
|
| 530 |
-
Example W&B dashboard:
|
| 531 |
-
```
|
| 532 |
-
rl/acc ━━━━━━━━━━ 0.45
|
| 533 |
-
rl/kl_letter_mean ━━━━━━━━━━ 0.12
|
| 534 |
-
rl/letter_margin_mean ━━━━━━━━━━ 2.34
|
| 535 |
-
attn/entropy_mean ━━━━━━━━━━ 3.21
|
| 536 |
```
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
#
|
| 541 |
-
|
| 542 |
-
#
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
-
|
| 587 |
-
-
|
| 588 |
-
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
-
|
| 592 |
-
-
|
| 593 |
-
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
-
|
| 628 |
-
-
|
| 629 |
-
-
|
| 630 |
-
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
-
|
| 638 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
---
|
| 14 |
|
| 15 |
<div align="center">
|
| 16 |
+
|
| 17 |
+

|
| 18 |
+
|
| 19 |
+
# nanochatAquaRat
|
| 20 |
+
|
| 21 |
+
**Training Language Models with Reinforcement Learning on Mathematical Reasoning**
|
| 22 |
+
|
| 23 |
+
[](https://github.com/HarleyCoops/nanochatAquaRat)
|
| 24 |
+
[](LICENSE)
|
| 25 |
+
[](https://www.python.org/downloads/)
|
| 26 |
+
|
| 27 |
+
A modified version of [nanochat](https://github.com/karpathy/nanochat) trained with reinforcement learning on the [DeepMind AQuA-RAT dataset](https://huggingface.co/datasets/deepmind/aqua_rat) for algebraic reasoning and multiple-choice problem solving.
|
| 28 |
+
|
| 29 |
+
[Quick Start](#quick-start) • [Dataset](#dataset-structure) • [Modifications](#modifications-from-base-nanochat) • [Training](#training-pipeline) • [Results](#results)
|
| 30 |
+
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## Table of Contents
|
| 36 |
+
|
| 37 |
+
- [Overview](#overview)
|
| 38 |
+
- [The Base: nanochat Framework](#the-base-nanochat-framework)
|
| 39 |
+
- [Dataset Structure](#dataset-structure)
|
| 40 |
+
- [Modifications from Base nanochat](#modifications-from-base-nanochat)
|
| 41 |
+
- [Training Pipeline](#training-pipeline)
|
| 42 |
+
- [Quick Start](#quick-start)
|
| 43 |
+
- [File Structure](#file-structure)
|
| 44 |
+
- [Monitoring & Visualization](#monitoring--visualization)
|
| 45 |
+
- [Results](#results)
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## Overview
|
| 50 |
+
|
| 51 |
+
This project adapts the **nanochat** training framework (originally designed for GSM8K numerical reasoning) to work with **AQuA-RAT** (Algebra Question Answering with Rationales), a dataset of ~97,000 algebraic word problems with multiple-choice answers (A-E) and natural language solution rationales.
|
| 52 |
+
|
| 53 |
+
### Why This Matters
|
| 54 |
+
|
| 55 |
+
- **Domain Transfer**: Demonstrates how to adapt a mathematical reasoning pipeline from free-form numeric answers to multiple-choice format
|
| 56 |
+
- **RL on Math**: Implements GRPO-style reinforcement learning with reward shaping for categorical outputs
|
| 57 |
+
- **Mechanistic Interpretability**: Integrates attention analysis during training to understand model reasoning patterns
|
| 58 |
+
- **Production-Ready**: Includes automated Lambda Labs and Hyperbolic Labs deployment helpers for cloud GPU training
|
| 59 |
+
|
| 60 |
+
### Key Results
|
| 61 |
+
|
| 62 |
+
| Model | Parameters | Training Time | AQuA-RAT Dev Accuracy |
|
| 63 |
+
|-------|------------|---------------|----------------------|
|
| 64 |
+
| depth-8 | ~60M | 3-4 hours | 30-50% |
|
| 65 |
+
| depth-20 | ~561M | 6-8 hours | 40-60% |
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## The Base: nanochat Framework
|
| 70 |
+
|
| 71 |
+
**nanochat** is a minimalist yet complete pipeline for training transformer language models from scratch, created by Andrej Karpathy. It implements:
|
| 72 |
+
|
| 73 |
+
- **Custom tokenizer**: BPE tokenizer written in Rust for performance
|
| 74 |
+
- **Training stages**: Pretraining → Mid-training → SFT → RL
|
| 75 |
+
- **Evaluation suite**: CORE benchmarks and task-specific metrics
|
| 76 |
+
- **Optimizations**: Memory-efficient training, gradient accumulation, distributed training
|
| 77 |
+
|
| 78 |
+
**Original focus**: Training on GSM8K (Grade School Math 8K) with free-form numeric answers.
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## Dataset Structure
|
| 84 |
+
|
| 85 |
+
### AQuA-RAT Format
|
| 86 |
+
|
| 87 |
+
The [DeepMind AQuA-RAT dataset](https://github.com/deepmind/AQuA) contains algebraic reasoning problems in JSON format:
|
| 88 |
+
|
| 89 |
+
```json
|
| 90 |
+
{
|
| 91 |
+
"question": "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?",
|
| 92 |
+
"options": [
|
| 93 |
+
"A) 53 km",
|
| 94 |
+
"B) 55 km",
|
| 95 |
+
"C) 52 km",
|
| 96 |
+
"D) 60 km",
|
| 97 |
+
"E) 50 km"
|
| 98 |
+
],
|
| 99 |
+
"rationale": "The distance that the person traveled = 20 * 2.5 = 50 km. Answer: E",
|
| 100 |
+
"correct": "E"
|
| 101 |
+
}
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
**Dataset splits**:
|
| 105 |
+
- Training: 97,467 problems
|
| 106 |
+
- Development: 254 problems
|
| 107 |
+
- Test: 254 problems
|
| 108 |
+
|
| 109 |
+
**Key characteristics**:
|
| 110 |
+
- Multiple-choice (A-E) format
|
| 111 |
+
- Algebraic word problems
|
| 112 |
+
- Natural language rationales
|
| 113 |
+
- Topics: arithmetic, algebra, geometry, probability
|
| 114 |
+
|
| 115 |
+
### Comparison: GSM8K vs AQuA-RAT
|
| 116 |
+
|
| 117 |
+
| Aspect | GSM8K (Original) | AQuA-RAT (This Project) |
|
| 118 |
+
|--------|------------------|-------------------------|
|
| 119 |
+
| **Format** | Free-form numeric | Multiple choice (A-E) |
|
| 120 |
+
| **Answer** | Single number | Letter choice |
|
| 121 |
+
| **Size** | 8,500 problems | 97,700 problems |
|
| 122 |
+
| **Difficulty** | Elementary school | High school algebra |
|
| 123 |
+
| **Rationale** | Step-by-step | Natural language |
|
| 124 |
+
| **Evaluation** | Exact match on number | Categorical accuracy |
|
| 125 |
+
|
| 126 |
+
---
|
| 127 |
+
|
| 128 |
+
## Modifications from Base nanochat
|
| 129 |
+
|
| 130 |
+
To adapt nanochat from GSM8K to AQuA-RAT, we modified the following components:
|
| 131 |
+
|
| 132 |
+
### 1. Dataset Loader (`scripts/prepare_aqua.py`)
|
| 133 |
+
|
| 134 |
+
**Created new file** to download and format AQuA-RAT:
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
# New file: scripts/prepare_aqua.py
|
| 138 |
+
### 1. Dataset Preparation (`scripts/prepare_aqua.py`)
|
| 139 |
+
|
| 140 |
+
- Uses `datasets.load_dataset("deepmind/aqua_rat")` and optionally caps split sizes.
|
| 141 |
+
- Emits JSONL files (`train.jsonl`, `validation.jsonl`, `test.jsonl`) compatible with
|
| 142 |
+
the conversation schema used throughout nanochat.
|
| 143 |
+
- Defaults to `~/.cache/nanochat/aqua`, but accepts `--output_dir` overrides so
|
| 144 |
+
launchers can bundle their own artifact.
|
| 145 |
+
|
| 146 |
+
```python
|
| 147 |
+
def format_example(row):
|
| 148 |
+
options = row["options"]
|
| 149 |
+
assistant_content = [
|
| 150 |
+
{"type": "text", "text": row["rationale"].strip()},
|
| 151 |
+
{"type": "text", "text": f"Answer: {row['correct'].strip().upper()}"},
|
| 152 |
+
]
|
| 153 |
+
return {
|
| 154 |
+
"messages": [
|
| 155 |
+
{"role": "user", "content": _render_user_prompt(row["question"], options)},
|
| 156 |
+
{"role": "assistant", "content": assistant_content},
|
| 157 |
+
],
|
| 158 |
+
"letters": letters,
|
| 159 |
+
"answer_letter": correct,
|
| 160 |
+
}
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### 2. Task Module (`tasks/aqua.py`)
|
| 164 |
+
|
| 165 |
+
- Accepts optional `data_dir` (or `AQUA_DATA_DIR` / `NANOCHAT_AQUA_DIR`) so the task
|
| 166 |
+
can read the cached JSONL; otherwise falls back to Hugging Face.
|
| 167 |
+
- Provides `_render_user_prompt` to format the question/options using the common
|
| 168 |
+
multiple-choice helper and `_extract_letter` to score completions.
|
| 169 |
+
- Returns conversations whose assistant messages include both the rationale and a
|
| 170 |
+
final `Answer: <LETTER>` line for SFT, while `evaluate()` only cares about the letter.
|
| 171 |
+
|
| 172 |
+
```python
|
| 173 |
+
def _extract_letter(text, default=None):
|
| 174 |
+
answer_match = re.search(r"answer\s*[:\-]\s*([A-E])", text, flags=re.IGNORECASE)
|
| 175 |
+
if answer_match:
|
| 176 |
+
return answer_match.group(1).upper()
|
| 177 |
+
match = LETTER_RE.search(text)
|
| 178 |
+
return match.group(1).upper() if match else default
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
**Key differences from GSM8K**:
|
| 182 |
+
- Numeric extraction → Letter extraction
|
| 183 |
+
- Free-form answer → Fixed choices A-E
|
| 184 |
+
- Exact number match → Categorical match
|
| 185 |
+
|
| 186 |
+
### 3. RL Training (`scripts/chat_rl.py`)
|
| 187 |
+
|
| 188 |
+
**Modified** to support both GSM8K and AQuA-RAT:
|
| 189 |
+
|
| 190 |
+
Key updates:
|
| 191 |
+
|
| 192 |
+
- `train_task` / `val_task` now instantiate `AQUA(...)` instead of `GSM8K(...)`.
|
| 193 |
+
- Rewards reuse the task's `evaluate()` helper so any completion containing
|
| 194 |
+
“Answer: X” (or the first bare letter) is scored correctly.
|
| 195 |
+
- The validation helper became `run_aqua_eval`, still reporting pass@k accuracy
|
| 196 |
+
across sampled completions.
|
| 197 |
+
- CLI overrides remain the same because the script continues to rely on the
|
| 198 |
+
nanochat configurator (`--run`, `--temperature`, `--max_new_tokens`, …).
|
| 199 |
+
|
| 200 |
+
### 4. Evaluation (`scripts/chat_eval.py`)
|
| 201 |
+
|
| 202 |
+
- Registered `'AQUA'` in the task registry so `-a AQUA` just works.
|
| 203 |
+
- Added a 20% random-guess baseline when aggregating the ChatCORE metric.
|
| 204 |
+
- The categorical evaluation path reuses `run_categorical_eval`, clamping logits
|
| 205 |
+
to the available letters before scoring.
|
| 206 |
+
|
| 207 |
+
### 5. Training Script (`run_aquarat_small.sh`)
|
| 208 |
+
|
| 209 |
+
**What changed vs upstream nanochat**:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
# (Optional) Cache the dataset locally as JSONL
|
| 213 |
+
python -m scripts.prepare_aqua --output_dir "$NANOCHAT_BASE_DIR/aqua"
|
| 214 |
+
|
| 215 |
+
# Mid-training now samples from the AQuA mixture
|
| 216 |
+
torchrun -m scripts.mid_train -- --run=demo --num_iterations=200
|
| 217 |
+
|
| 218 |
+
# SFT stage emphasises AQuA problems
|
| 219 |
+
torchrun -m scripts.sft_train -- --run=demo --aqua_train_examples=20000
|
| 220 |
+
|
| 221 |
+
# RL fine-tuning rewards the correct letter on AQuA-RAT
|
| 222 |
+
torchrun -m scripts.chat_rl -- --run=demo --temperature=0.7 --max_new_tokens=64
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
- **`tasks/aqua.py`** loads AQuA-RAT either from Hugging Face or the cached JSONL
|
| 226 |
+
splits, formats questions as conversations, and scores completions by letter.
|
| 227 |
+
- **`scripts/mid_train.py`** extends the original Reasoning+Chat mixture with a
|
| 228 |
+
50k slice of AQuA so the model sees multiple-choice algebra earlier.
|
| 229 |
+
- **`scripts/chat_sft.py`** replaces the GSM8K component with AQuA, keeping ARC,
|
| 230 |
+
SmolTalk, and identity prompts for general chat coverage.
|
| 231 |
+
- **`scripts/chat_rl.py`** retools the GRPO loop to sample, reward, and evaluate
|
| 232 |
+
AQuA answers (categorical accuracy instead of GSM8K free-form math).
|
| 233 |
+
- **`scripts/chat_eval.py`** registers the new AQuA task so `chat_eval` can report
|
| 234 |
+
categorical accuracy alongside ARC/MMLU/GSM8K/HumanEval.
|
| 235 |
+
|
| 236 |
+
---
|
| 237 |
+
|
| 238 |
+
## Training Pipeline
|
| 239 |
+
|
| 240 |
+
### Stage 1: Base Pretraining (50-60% of time)
|
| 241 |
+
|
| 242 |
+
**What happens**: Model learns language from scratch on FineWeb corpus
|
| 243 |
+
|
| 244 |
+
```bash
|
| 245 |
+
torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=8
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**Duration**: 1.5-2 hours on 8x H100
|
| 249 |
+
**Output**: Base checkpoint with general language understanding
|
| 250 |
+
**Metrics**: Validation loss, CORE benchmark scores
|
| 251 |
+
|
| 252 |
+
### Stage 2: Mid-Training (12-15% of time)
|
| 253 |
+
|
| 254 |
+
**What happens**: Teach conversation format and special tokens
|
| 255 |
+
|
| 256 |
+
```bash
|
| 257 |
+
torchrun --nproc_per_node=8 -m scripts.mid_train
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
**Duration**: 30 minutes
|
| 261 |
+
**Output**: Conversational checkpoint
|
| 262 |
+
**Metrics**: Format adherence, tool use capability
|
| 263 |
+
|
| 264 |
+
### Stage 3: Supervised Fine-Tuning (12-15% of time)
|
| 265 |
+
|
| 266 |
+
**What happens**: Fine-tune on AQuA-RAT with ground-truth solutions
|
| 267 |
+
|
| 268 |
+
```bash
|
| 269 |
+
torchrun --nproc_per_node=8 -m scripts.sft_train -- \
|
| 270 |
+
--aqua_train_examples=20000 \
|
| 271 |
+
--aqua_val_examples=254
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
**Duration**: 30 minutes
|
| 275 |
+
**Output**: AQuA-tuned checkpoint
|
| 276 |
+
**Metrics**: Dev set accuracy (categorical)
|
| 277 |
+
|
| 278 |
+
### Stage 4: Reinforcement Learning (12-15% of time)
|
| 279 |
+
|
| 280 |
+
**What happens**: Policy gradient learning with GRPO algorithm
|
| 281 |
+
|
| 282 |
+
```bash
|
| 283 |
+
torchrun --nproc_per_node=1 -m scripts.chat_rl -- \
|
| 284 |
+
--temperature=0.7 \
|
| 285 |
+
--max_new_tokens=64
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
**Duration**: 30 minutes
|
| 289 |
+
**Algorithm**: Group Relative Policy Optimization (GRPO)
|
| 290 |
+
**Reward**: +1.0 for correct letter, +0.1 for valid letter format
|
| 291 |
+
**Output**: RL-optimized checkpoint
|
| 292 |
+
|
| 293 |
+
**Logged metrics**:
|
| 294 |
+
- `rl/acc` - Accuracy on training samples
|
| 295 |
+
- `rl/mean_reward` - Average reward per generation
|
| 296 |
+
- `rl/kl_letter_mean` - KL divergence at decision point
|
| 297 |
+
- `rl/kl_sequence_mean` - Full sequence KL
|
| 298 |
+
- `rl/letter_margin_mean` - Confidence (logit gap)
|
| 299 |
+
- `attn/entropy_mean` - Attention mechanism patterns
|
| 300 |
+
|
| 301 |
+
---
|
| 302 |
+
|
| 303 |
+
## Quick Start
|
| 304 |
+
|
| 305 |
+
### Repo Setup & Rust Toolchain
|
| 306 |
+
|
| 307 |
+
- Clone with submodules so the `rustbpe` tokenizer sources are present:
|
| 308 |
+
```bash
|
| 309 |
+
git clone --recurse-submodules https://github.com/HarleyCoops/nanochatAquaRat.git
|
| 310 |
+
```
|
| 311 |
+
For existing clones run `git submodule update --init --recursive` before building.
|
| 312 |
+
- Install Rust (needed for the tokenizer build). On Linux/macOS follow [https://rustup.rs](https://rustup.rs). On Windows, after installing rustup, ensure the toolchain is MSVC x86\_64 and the cargo bin directory is on `PATH`:
|
| 313 |
+
```powershell
|
| 314 |
+
$env:Path += ";$env:USERPROFILE\.cargo\bin"
|
| 315 |
+
setx PATH "$env:Path"
|
| 316 |
+
setx CARGO_HOME "$env:USERPROFILE\.cargo"
|
| 317 |
+
setx RUSTUP_HOME "$env:USERPROFILE\.rustup"
|
| 318 |
+
rustup set default-host x86_64-pc-windows-msvc
|
| 319 |
+
rustup default stable-x86_64-pc-windows-msvc
|
| 320 |
+
cargo --version
|
| 321 |
+
rustup --version
|
| 322 |
+
```
|
| 323 |
+
- Build the tokenizer once per machine:
|
| 324 |
+
```bash
|
| 325 |
+
uv run maturin develop
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
### Option 1: Lambda Labs Cloud (Automated)
|
| 329 |
+
|
| 330 |
+
Use the automation helper for one-command deployment:
|
| 331 |
+
|
| 332 |
+
```bash
|
| 333 |
+
# Set credentials
|
| 334 |
+
export LAMBDA_API_KEY='your-lambda-api-key'
|
| 335 |
+
export WANDB_API_KEY='your-wandb-api-key'
|
| 336 |
+
|
| 337 |
+
# Launch with auto-start
|
| 338 |
+
python scripts/launch_lambda_training.py \
|
| 339 |
+
--ssh-key-name your_lambda_ssh_key \
|
| 340 |
+
--instance-type gpu_8x_h100_sxm5 \
|
| 341 |
+
--region us-west-1 \
|
| 342 |
+
--auto-start \
|
| 343 |
+
--inject-env WANDB_API_KEY
|
| 344 |
+
```
|
| 345 |
+
|
| 346 |
+
The script provisions the instance, clones this repository, sets up environment variables, and starts training in a tmux session.
|
| 347 |
+
|
| 348 |
+
**Monitor training**:
|
| 349 |
+
```bash
|
| 350 |
+
# SSH to instance
|
| 351 |
+
ssh ubuntu@<INSTANCE_IP>
|
| 352 |
+
|
| 353 |
+
# Attach to tmux session
|
| 354 |
+
tmux attach -t nanochat-train
|
| 355 |
+
|
| 356 |
+
# Or view logs
|
| 357 |
+
tail -f ~/nanochatAquaRat/training.log
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Option 2: Hyperbolic Labs Cloud (Automated)
|
| 361 |
+
|
| 362 |
+
Spin up on-demand GPUs via Hyperbolic's marketplace API:
|
| 363 |
+
|
| 364 |
+
```bash
|
| 365 |
+
# Set credentials
|
| 366 |
+
export HYPERBOLIC_API_KEY='your-hyperbolic-api-key'
|
| 367 |
+
export WANDB_API_KEY='your-wandb-api-key'
|
| 368 |
+
|
| 369 |
+
# Launch with auto-start
|
| 370 |
+
python scripts/launch_hyperbolic_training.py \
|
| 371 |
+
--gpu-count 1 \
|
| 372 |
+
--region us-east \
|
| 373 |
+
--auto-start \
|
| 374 |
+
--inject-env WANDB_API_KEY
|
| 375 |
+
```
|
| 376 |
+
|
| 377 |
+
The launcher discovers an available node (respecting `--region`, `--supplier`, or `--max-price` filters), provisions it, copies your `.env`, and optionally starts training in tmux. Use `--list` to inspect available marketplace inventory without launching.
|
| 378 |
+
|
| 379 |
+
### Option 3: Lambda Labs Cloud (Manual)
|
| 380 |
+
|
| 381 |
+
For step-by-step control, see [LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md).
|
| 382 |
+
|
| 383 |
+
**Quick summary**:
|
| 384 |
+
1. Launch instance at https://cloud.lambdalabs.com/instances
|
| 385 |
+
2. SSH to instance: `ssh ubuntu@<IP>`
|
| 386 |
+
3. Clone repo: `git clone <repo-url> && cd nanochatAquaRat`
|
| 387 |
+
4. Set up credentials: `echo "WANDB_API_KEY=..." > .env`
|
| 388 |
+
5. Run training: `bash run_aquarat_small.sh`
|
| 389 |
+
|
| 390 |
+
### Option 4: Hyperbolic VM (Manual)
|
| 391 |
+
|
| 392 |
+
For marketplace nodes without automation access, follow this lightweight bootstrap:
|
| 393 |
+
|
| 394 |
+
1. Provision a GPU VM from the Hyperbolic console and copy the SSH command (including `-p <port>` and username).
|
| 395 |
+
2. SSH in and install prerequisites:
|
| 396 |
+
```bash
|
| 397 |
+
sudo apt-get update
|
| 398 |
+
sudo apt-get install -y git curl unzip build-essential python3 python3-venv tmux
|
| 399 |
+
git clone https://github.com/HarleyCoops/nanochatAquaRat.git
|
| 400 |
+
cd nanochatAquaRat
|
| 401 |
+
```
|
| 402 |
+
3. Create `.env` with the required keys (WANDB, GCS bucket, AQUA path) and upload your GCP service-account JSON to the VM, e.g. `scp -P <port> C:\path\to\credentials.json user@<ip>:/home/user/gcp-sa.json`.
|
| 403 |
+
4. Install tooling and build the tokenizer:
|
| 404 |
+
```bash
|
| 405 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 406 |
+
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
|
| 407 |
+
source "$HOME/.cargo/env"
|
| 408 |
+
export PATH="$HOME/.local/bin:$PATH"
|
| 409 |
+
uv venv && uv sync --extra gpu
|
| 410 |
+
source .venv/bin/activate
|
| 411 |
+
uv run maturin develop
|
| 412 |
+
uv run python -m scripts.tok_train
|
| 413 |
+
```
|
| 414 |
+
5. Install the Google Cloud SDK, authenticate, and stage the cached AQuA splits (or regenerate them):
|
| 415 |
+
```bash
|
| 416 |
+
curl -sSL https://sdk.cloud.google.com | bash
|
| 417 |
+
source "$HOME/.bashrc"
|
| 418 |
+
gcloud auth login --no-launch-browser
|
| 419 |
+
gcloud config set project <your-project-id>
|
| 420 |
+
gcloud storage cp gs://nanochat-aquarat-datasets/datasets/aqua/aqua_cache.zip .
|
| 421 |
+
unzip -o aqua_cache.zip -d ~/aqua_cache
|
| 422 |
+
export AQUA_DATA_DIR=$HOME/aqua_cache
|
| 423 |
+
```
|
| 424 |
+
6. Fetch the identity conversation bundle (required for SFT) and the evaluation bundle once so CORE metrics don’t fail:
|
| 425 |
+
```bash
|
| 426 |
+
cd ~/.cache/nanochat
|
| 427 |
+
curl -L -o identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
| 428 |
+
curl -L -o eval_bundle.zip https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
| 429 |
+
unzip -q eval_bundle.zip && rm eval_bundle.zip
|
| 430 |
+
cd ~/nanochatAquaRat
|
| 431 |
+
```
|
| 432 |
+
7. Launch the desired script, e.g. `CUDA_VISIBLE_DEVICES=0 bash run_aquarat_lite.sh` or the full `run_aquarat_small.sh`.
|
| 433 |
+
8. Monitor training via tmux/W&B and terminate the VM from Hyperbolic when the run finishes to stop billing.
|
| 434 |
+
|
| 435 |
+
### Option 4: Alternative Launcher Script
|
| 436 |
+
|
| 437 |
+
A simplified launcher is also available:
|
| 438 |
+
|
| 439 |
+
```bash
|
| 440 |
+
export LAMBDA_API_KEY='your-key'
|
| 441 |
+
export WANDB_API_KEY='your-key'
|
| 442 |
+
|
| 443 |
+
python launch_lambda.py \
|
| 444 |
+
--instance-type gpu_8x_h100_sxm5 \
|
| 445 |
+
--region us-west-1
|
| 446 |
+
```
|
| 447 |
+
|
| 448 |
+
See [QUICKSTART.md](QUICKSTART.md) for details.
|
| 449 |
+
|
| 450 |
### Option 5: Local/Custom Setup
|
| 451 |
|
| 452 |
```bash
|
|
|
|
| 464 |
- 40GB+ GPU memory per GPU
|
| 465 |
- ~100GB disk space
|
| 466 |
|
| 467 |
+
## Hugging Face Sync
|
|
|
|
|
|
|
| 468 |
|
| 469 |
+
Keep the GitHub docs mirrored with the Hugging Face model card:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
+
1. Edit `README.md` (and any linked docs) as usual.
|
| 472 |
+
2. Stage the release payload locally:
|
| 473 |
+
```bash
|
| 474 |
+
uv run python -m scripts.sync_hf_repo --no-push
|
| 475 |
+
```
|
| 476 |
+
This copies every README dependency into `hf_release/`. The script warns if a referenced file such as `LICENSE` is missing.
|
| 477 |
+
3. Push the staged contents to Hugging Face once you are satisfied:
|
| 478 |
+
```bash
|
| 479 |
+
uv run python -m scripts.sync_hf_repo --repo-id HarleyCooper/nanochatAquaRat
|
| 480 |
+
```
|
| 481 |
+
The command requires prior `huggingface-cli login` (or an `HF_TOKEN` env var). Use `--dry-run` to review operations without copying or uploading.
|
|
|
|
| 482 |
|
| 483 |
---
|
| 484 |
|
| 485 |
+
## File Structure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
```
|
| 488 |
+
nanochatAquaRat/
|
| 489 |
+
├── nanochat/… # Vendored upstream nanochat package
|
| 490 |
+
├── scripts/
|
| 491 |
+
│ ├── base_train.py # Base pretraining stage
|
| 492 |
+
│ ├── mid_train.py # Mid-training (now includes AQuA)
|
| 493 |
+
│ ├── chat_sft.py # Chat SFT pipeline
|
| 494 |
+
│ ├── sft_train.py # Shim so `-m scripts.sft_train` still works
|
| 495 |
+
│ ├── chat_rl.py # Reinforcement learning on AQuA-RAT
|
| 496 |
+
│ ├── chat_eval.py # Evaluation harness (adds AQuA task)
|
| 497 |
+
│ ├── prepare_aqua.py # AQuA-RAT JSONL exporter
|
| 498 |
+
│ ├── launch_lambda_training.py # Lambda Labs automation
|
| 499 |
+
│ ├── launch_hyperbolic_training.py # Hyperbolic Labs automation
|
| 500 |
+
│ └── upload_to_gcs.sh # Artifact helper
|
| 501 |
+
├── tasks/
|
| 502 |
+
│ ├── aqua.py # AQuA-RAT task implementation
|
| 503 |
+
│ ├── arc.py / gsm8k.py / mmlu.py # Other reasoning tasks
|
| 504 |
+
│ └── …
|
| 505 |
+
├── run_aquarat_small.sh # End-to-end orchestration
|
| 506 |
+
├── pyproject.toml / uv.lock # Environment definitions
|
| 507 |
+
└── README.md
|
| 508 |
+
```
|
| 509 |
+
### Summary of Code Changes
|
| 510 |
+
|
| 511 |
+
| File | Type | Description |
|
| 512 |
+
|------|------|-------------|
|
| 513 |
+
| `tasks/aqua.py` | NEW | Conversation + evaluation wrapper for AQuA-RAT |
|
| 514 |
+
| `scripts/prepare_aqua.py` | NEW | Materializes train/validation/test JSONL splits for offline use |
|
| 515 |
+
| `scripts/mid_train.py` | MODIFIED | Adds AQuA to the mid-training mixture |
|
| 516 |
+
| `scripts/chat_sft.py` | MODIFIED | SFT mixture now includes AQuA controls |
|
| 517 |
+
| `scripts/sft_train.py` | NEW | Thin compatibility shim around `chat_sft` |
|
| 518 |
+
| `scripts/chat_rl.py` | MODIFIED | RL loop retargeted from GSM8K to AQuA-RAT |
|
| 519 |
+
| `scripts/chat_eval.py` | MODIFIED | Registers AQuA for categorical evaluation |
|
| 520 |
+
| `run_aquarat_small.sh` | MODIFIED | Pipeline glue aligned with AQuA staging |
|
| 521 |
+
| `scripts/launch_hyperbolic_training.py` | NEW | Hyperbolic Labs automation helper |
|
| 522 |
+
| `launch_lambda.py` / `scripts/launch_lambda_training.py` | EXISTING | Lambda Labs support retained |
|
| 523 |
+
|
| 524 |
+
---
|
| 525 |
+
|
| 526 |
+
## Monitoring & Visualization
|
| 527 |
+
|
| 528 |
+
All metrics stream to [Weights & Biases](https://wandb.ai) in real-time:
|
| 529 |
+
|
| 530 |
+
**Training Metrics**:
|
| 531 |
+
- Loss curves (pretraining, SFT, RL)
|
| 532 |
+
- Learning rate schedules
|
| 533 |
+
- Gradient norms
|
| 534 |
+
|
| 535 |
+
**RL Metrics**:
|
| 536 |
+
- Policy performance (accuracy, rewards)
|
| 537 |
+
- KL divergence from initial policy
|
| 538 |
+
- Letter-choice distributions (A-E)
|
| 539 |
+
- Confidence margins
|
| 540 |
+
|
| 541 |
+
**Interpretability**:
|
| 542 |
+
- Attention heatmaps per layer
|
| 543 |
+
- Entropy evolution across training
|
| 544 |
+
- Token-level attention weights
|
| 545 |
+
|
| 546 |
+
Example W&B dashboard:
|
| 547 |
+
```
|
| 548 |
+
rl/acc ━━━━━━━━━━ 0.45
|
| 549 |
+
rl/kl_letter_mean ━━━━━━━━━━ 0.12
|
| 550 |
+
rl/letter_margin_mean ━━━━━━━━━━ 2.34
|
| 551 |
+
attn/entropy_mean ━━━━━━━━━━ 3.21
|
| 552 |
+
```
|
| 553 |
+
|
| 554 |
+
---
|
| 555 |
+
|
| 556 |
+
## Results
|
| 557 |
+
|
| 558 |
+
### Model Configurations
|
| 559 |
+
|
| 560 |
+
| Depth | Parameters | Training Time | Best Instance Type | Estimated Cost |
|
| 561 |
+
|-------|------------|---------------|-------------------|----------------|
|
| 562 |
+
| 8 | ~60M | 3-4 hours | 1-2x A100 | ~$18-35 |
|
| 563 |
+
| 12 | ~180M | 4-5 hours | 4x A100 | ~$35-45 |
|
| 564 |
+
| 20 | ~561M | 6-8 hours | 8x H100 | ~$144-192 |
|
| 565 |
+
| 26 | ~1.1B | 10-12 hours | 8x H100 | ~$240-288 |
|
| 566 |
+
|
| 567 |
+
To change model depth, edit the `--depth` parameter in `run_aquarat_small.sh`.
|
| 568 |
+
|
| 569 |
+
### Expected Performance
|
| 570 |
+
|
| 571 |
+
**After SFT** (before RL):
|
| 572 |
+
- Dev accuracy: 20-30% (depth-8), 30-40% (depth-20)
|
| 573 |
+
- Basic problem-solving capability
|
| 574 |
+
- Some format errors (invalid letters)
|
| 575 |
+
|
| 576 |
+
**After RL**:
|
| 577 |
+
- Dev accuracy: 30-50% (depth-8), 40-60% (depth-20)
|
| 578 |
+
- Improved reasoning coherence
|
| 579 |
+
- Better multiple-choice selection confidence
|
| 580 |
+
- Reduced format errors
|
| 581 |
+
- Stable attention patterns
|
| 582 |
+
|
| 583 |
+
### Cost Management
|
| 584 |
+
|
| 585 |
+
Lambda Labs pricing (8x H100 SXM5 @ ~$24/hour):
|
| 586 |
+
|
| 587 |
+
| Model | Training Time | Total Cost |
|
| 588 |
+
|-------|---------------|------------|
|
| 589 |
+
| depth-8 (60M) | 3-4 hours | ~$96 |
|
| 590 |
+
| depth-20 (561M) | 6-8 hours | ~$192 |
|
| 591 |
+
|
| 592 |
+
Budget options:
|
| 593 |
+
- Test pipeline: 1x A10 @ $0.60/hr
|
| 594 |
+
- Small model: 2x A100 @ $4.40/hr
|
| 595 |
+
- Production: 8x H100 @ $24/hr
|
| 596 |
+
|
| 597 |
+
---
|
| 598 |
+
|
| 599 |
+
## Important Notes
|
| 600 |
+
|
| 601 |
+
### For Lambda Labs Users
|
| 602 |
+
- **Always terminate instances** after training to avoid charges
|
| 603 |
+
- Monitor spending in the Lambda Labs dashboard
|
| 604 |
+
- Check instance availability before launching (high demand periods)
|
| 605 |
+
|
| 606 |
+
### Known Limitations
|
| 607 |
+
- RL on AQuA-RAT is experimental; results may vary
|
| 608 |
+
- Attention logging adds ~5-10% overhead
|
| 609 |
+
- KL computation can be expensive with large batch sizes
|
| 610 |
+
- Smaller models (<100M params) may struggle with complex reasoning
|
| 611 |
+
|
| 612 |
+
---
|
| 613 |
+
|
| 614 |
+
## Documentation
|
| 615 |
+
|
| 616 |
+
- **[scripts/launch_lambda_training.py](scripts/launch_lambda_training.py)** - Full-featured automation
|
| 617 |
+
- **[scripts/launch_hyperbolic_training.py](scripts/launch_hyperbolic_training.py)** - Hyperbolic marketplace automation
|
| 618 |
+
- **[launch_lambda.py](launch_lambda.py)** - Simplified launcher
|
| 619 |
+
- **[QUICKSTART.md](QUICKSTART.md)** - Fast track guide
|
| 620 |
+
- **[LAMBDA_MANUAL_SETUP.md](LAMBDA_MANUAL_SETUP.md)** - Manual setup walkthrough
|
| 621 |
+
- **[GCS_UPLOAD_GUIDE.md](GCS_UPLOAD_GUIDE.md)** - Upload weights to Google Cloud Storage
|
| 622 |
+
- **[.env.template](.env.template)** - Environment configuration
|
| 623 |
+
|
| 624 |
+
---
|
| 625 |
+
|
| 626 |
+
## Contributing
|
| 627 |
+
|
| 628 |
+
This project is based on the nanochat framework. For issues specific to:
|
| 629 |
+
- **AQuA-RAT training**: Open an issue in this repository
|
| 630 |
+
- **Base nanochat framework**: Refer to the upstream nanochat project
|
| 631 |
+
- **Lambda Labs deployment**: See documentation above
|
| 632 |
+
|
| 633 |
+
---
|
| 634 |
+
|
| 635 |
+
## License
|
| 636 |
+
|
| 637 |
+
This project inherits the license from the base nanochat project.
|
| 638 |
+
|
| 639 |
+
---
|
| 640 |
+
|
| 641 |
+
## Acknowledgments
|
| 642 |
+
|
| 643 |
+
- **Andrej Karpathy** - nanochat framework
|
| 644 |
+
- **DeepMind** - AQuA-RAT dataset and mechanistic interpretability tools
|
| 645 |
+
- **Lambda Labs** - Cloud GPU infrastructure
|
| 646 |
+
- **Weights & Biases** - Experiment tracking and visualization
|
| 647 |
+
|
| 648 |
+
---
|
| 649 |
+
|
| 650 |
+
## Support
|
| 651 |
+
|
| 652 |
+
- **Lambda Labs Support**: https://lambdalabs.com/support
|
| 653 |
+
- **Weights & Biases Docs**: https://docs.wandb.ai
|
| 654 |
+
- **Project Issues**: https://github.com/HarleyCoops/nanochatAquaRat/issues
|
aquarat2.png
ADDED
|
Git LFS Details
|
launch_lambda.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Lambda Labs GPU Instance Launcher for AQuA-RAT Training
|
| 4 |
+
|
| 5 |
+
This script automates launching an 8x H100 GPU instance on Lambda Labs
|
| 6 |
+
and deploying the nanochatAquaRat training pipeline.
|
| 7 |
+
|
| 8 |
+
Prerequisites:
|
| 9 |
+
1. Lambda Labs API key (set as LAMBDA_API_KEY environment variable)
|
| 10 |
+
2. Your SSH public key added to Lambda Labs account
|
| 11 |
+
3. W&B API key for logging (set as WANDB_API_KEY environment variable)
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python launch_lambda.py --instance-type gpu_8x_h100_sxm5 --region us-west-1
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
import argparse
|
| 21 |
+
import subprocess
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import lambda_cloud_client
|
| 26 |
+
from lambda_cloud_client.rest import ApiException
|
| 27 |
+
except ImportError:
|
| 28 |
+
print("Installing lambda-cloud-client...")
|
| 29 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "lambda-cloud-client"])
|
| 30 |
+
import lambda_cloud_client
|
| 31 |
+
from lambda_cloud_client.rest import ApiException
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def check_env_vars():
    """Verify the credentials needed for a launch are present in the environment.

    Prints guidance and exits the process with status 1 when either
    LAMBDA_API_KEY or WANDB_API_KEY is missing; returns None otherwise.
    """
    required_vars = {
        'LAMBDA_API_KEY': 'Lambda Labs API key',
        'WANDB_API_KEY': 'Weights & Biases API key'
    }

    # Collect a human-readable bullet for every variable that is absent.
    missing = [
        f" - {var} ({description})"
        for var, description in required_vars.items()
        if not os.getenv(var)
    ]

    if missing:
        print("ERROR: Missing required environment variables:")
        print("\n".join(missing))
        print("\nSet them with:")
        print(" export LAMBDA_API_KEY='your-lambda-api-key'")
        print(" export WANDB_API_KEY='your-wandb-api-key'")
        sys.exit(1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_api_client():
    """Build a Lambda Cloud API client authenticated via LAMBDA_API_KEY."""
    configuration = lambda_cloud_client.Configuration(
        host="https://cloud.lambdalabs.com/api/v1",
        access_token=os.getenv('LAMBDA_API_KEY'),
    )
    return lambda_cloud_client.ApiClient(configuration)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def list_available_instance_types(api_client):
    """Print every instance type that currently has capacity and return the data.

    Args:
        api_client: A configured lambda_cloud_client.ApiClient.

    Returns:
        The instance-type mapping from the API response.
    """
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    # Keep the try body minimal: only the API call can raise ApiException.
    try:
        response = api_instance.instance_types()
    except ApiException as e:
        print(f"Error fetching instance types: {e}")
        sys.exit(1)

    print("\nAvailable Instance Types:")
    print("-" * 80)

    for type_name, details in response.data.items():
        itype = details.instance_type
        # Skip types with no capacity anywhere.
        if not itype.regions_with_capacity_available:
            continue
        print(f"\n{type_name}:")
        print(f" GPUs: {itype.specs.gpus}")
        print(f" GPU Memory: {itype.specs.memory_gbs} GB")
        print(f" Price: ${itype.specs.price_cents_per_hour / 100}/hour")
        print(f" Available regions: {', '.join(itype.regions_with_capacity_available)}")

    return response.data
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def launch_instance(api_client, instance_type, region, name="nanochat-aquarat-training"):
    """Launch one Lambda Labs GPU instance and return its instance ID.

    All SSH keys registered on the account are attached so the operator can
    log in with whichever key they hold locally. Exits the process on any
    API error or when the launch returns no instance ID.
    """
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    # Fetch the account's SSH keys; the instance is unreachable without one.
    try:
        ssh_keys_response = api_instance.list_ssh_keys()
    except ApiException as e:
        print(f"Error fetching SSH keys: {e}")
        sys.exit(1)

    if not ssh_keys_response.data:
        print("ERROR: No SSH keys found in your Lambda Labs account.")
        print("Please add an SSH key at: https://cloud.lambdalabs.com/ssh-keys")
        sys.exit(1)

    ssh_key_names = [key.name for key in ssh_keys_response.data]
    print(f"Using SSH keys: {', '.join(ssh_key_names)}")

    launch_request = lambda_cloud_client.LaunchInstanceRequest(
        region_name=region,
        instance_type_name=instance_type,
        ssh_key_names=ssh_key_names,
        name=name,
        quantity=1,
    )

    print(f"\nLaunching {instance_type} instance in {region}...")

    try:
        response = api_instance.launch_instance(launch_request)
    except ApiException as e:
        print(f"Error launching instance: {e}")
        sys.exit(1)

    # Guard clause: the API succeeded but returned no ID.
    if not (response.data and response.data.instance_ids):
        print("ERROR: Instance launch failed")
        sys.exit(1)

    instance_id = response.data.instance_ids[0]
    print(f"✓ Instance launched successfully!")
    print(f" Instance ID: {instance_id}")
    return instance_id
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def wait_for_instance(api_client, instance_id, timeout=300):
    """Poll the instance every 10s until it reports "active", then return it.

    Exits the process if the instance does not become active within
    `timeout` seconds. API errors during polling are logged and retried.
    """
    api_instance = lambda_cloud_client.DefaultApi(api_client)

    print("\nWaiting for instance to be ready...")
    deadline = time.time() + timeout

    while time.time() < deadline:
        try:
            instance = api_instance.get_instance(instance_id).data
        except ApiException as e:
            # Transient API error: report and retry on the next poll.
            print(f"Error checking instance status: {e}")
            time.sleep(10)
            continue

        if instance.status == "active":
            print(f"✓ Instance is ready!")
            print(f" IP Address: {instance.ip}")
            print(f" SSH Command: ssh ubuntu@{instance.ip}")
            return instance

        print(f" Status: {instance.status}... waiting")
        time.sleep(10)

    print("ERROR: Timeout waiting for instance to be ready")
    sys.exit(1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def generate_startup_script():
    """Build the bash bootstrap script executed on the freshly launched instance.

    The script clones the repository (if needed), writes the W&B credentials
    into a `.env` file inside the repo, and starts the training run inside a
    detached `screen` session.

    Returns:
        str: The complete bash script text.

    Bug fixes versus the previous version:
      * `.env` was written into `/home/ubuntu/nanochatAquaRat` *before* the
        repository was cloned, which failed under `set -euo pipefail` on a
        fresh instance because the directory did not exist yet. The clone now
        happens first.
      * `${WANDB_ENTITY:-}` sat inside a quoted heredoc (`<< 'EOF'`), so it
        was never expanded and the literal text landed in `.env`. The entity
        is now resolved locally in Python.
    """
    wandb_key = os.getenv('WANDB_API_KEY')
    wandb_entity = os.getenv('WANDB_ENTITY', '')

    script = f"""#!/bin/bash
set -euo pipefail

# Clone repository if not exists (must exist before .env can be written into it)
cd /home/ubuntu
if [ ! -d "nanochatAquaRat" ]; then
    git clone https://github.com/HarleyCoops/nanochatAquaRat.git
fi

cd nanochatAquaRat

# Create .env file with credentials
cat > .env << 'EOF'
WANDB_API_KEY={wandb_key}
WANDB_PROJECT=nanochat-aquarat
WANDB_ENTITY={wandb_entity}
EOF

# Make script executable
chmod +x run_aquarat_small.sh

# Run training in screen session
screen -dmS training bash -c './run_aquarat_small.sh 2>&1 | tee training.log'

echo "Training started in screen session 'training'"
echo "To attach: screen -r training"
echo "To detach: Ctrl+A then D"
echo "To view log: tail -f training.log"
"""

    return script
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def deploy_and_run(instance_ip):
    """Copy the bootstrap script to the instance over SCP and execute it via SSH.

    Raises:
        subprocess.CalledProcessError: if either the scp or the ssh step fails.
    """
    print("\nDeploying code and starting training...")

    # Materialize the generated bootstrap script at a local temp path.
    script_path = Path("/tmp/lambda_startup.sh")
    script_path.write_text(generate_startup_script())

    print(" Copying startup script...")
    scp_cmd = [
        "scp", "-o", "StrictHostKeyChecking=no",
        str(script_path),
        f"ubuntu@{instance_ip}:/tmp/startup.sh",
    ]
    subprocess.run(scp_cmd, check=True)

    print(" Starting training...")
    ssh_cmd = [
        "ssh", "-o", "StrictHostKeyChecking=no",
        f"ubuntu@{instance_ip}",
        "bash /tmp/startup.sh",
    ]
    subprocess.run(ssh_cmd, check=True)

    print("\n" + "=" * 80)
    print("✓ Training deployment complete!")
    print("=" * 80)
    print("\nTo monitor your training:")
    print(f" 1. SSH: ssh ubuntu@{instance_ip}")
    print(f" 2. Attach to screen: screen -r training")
    print(f" 3. View log: tail -f ~/nanochatAquaRat/training.log")
    print(f" 4. W&B Dashboard: https://wandb.ai")
    print("\nTo detach from screen: Ctrl+A then D")
    print("\nRemember to terminate the instance when done to avoid charges!")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main():
    """CLI entry point: launch a Lambda Labs instance and start AQuA-RAT training.

    Parses command-line options, validates credentials, launches the requested
    instance, waits for it to come up, and (unless --no-deploy is given)
    pushes the bootstrap script and starts the training run.
    """
    parser = argparse.ArgumentParser(description="Launch Lambda Labs instance for AQuA-RAT training")
    parser.add_argument("--instance-type", default="gpu_8x_h100_sxm5",
                        help="Instance type (default: gpu_8x_h100_sxm5)")
    parser.add_argument("--region", default="us-west-1",
                        help="Region to launch in (default: us-west-1)")
    parser.add_argument("--name", default="nanochat-aquarat-training",
                        help="Instance name (default: nanochat-aquarat-training)")
    parser.add_argument("--list-types", action="store_true",
                        help="List available instance types and exit")
    parser.add_argument("--no-deploy", action="store_true",
                        help="Launch instance but don't deploy code")

    args = parser.parse_args()

    print("=" * 80)
    print("Lambda Labs GPU Instance Launcher for AQuA-RAT Training")
    print("=" * 80)

    # Fail fast if API credentials are missing.
    check_env_vars()

    api_client = get_api_client()

    if args.list_types:
        list_available_instance_types(api_client)
        return

    instance_id = launch_instance(api_client, args.instance_type, args.region, args.name)

    instance = wait_for_instance(api_client, instance_id)

    if not args.no_deploy:
        time.sleep(5)  # Give SSH a moment to be fully ready
        try:
            deploy_and_run(instance.ip)
        except subprocess.CalledProcessError as e:
            # Deployment is best-effort: the instance is already running, so
            # tell the user how to finish the setup manually instead of dying.
            print(f"\nWarning: Deployment encountered an error: {e}")
            print(f"You can manually SSH to the instance and run the training:")
            print(f" ssh ubuntu@{instance.ip}")
            print(f" cd nanochatAquaRat && bash run_aquarat_small.sh")

    print("\n" + "=" * 80)
    print("Instance Information")
    print("=" * 80)
    print(f"Instance ID: {instance_id}")
    print(f"IP Address: {instance.ip}")
    print(f"Status: {instance.status}")
    # BUG FIX: the previous message advertised a `--terminate` flag that this
    # script never implemented; direct users to the dashboard instead.
    print("\nTo terminate this instance:")
    print(" Visit https://cloud.lambdalabs.com/instances and terminate it there.")


if __name__ == "__main__":
    main()
|
scripts/launch_hyperbolic_training.py
ADDED
|
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Automation helper to launch a Hyperbolic Labs marketplace instance and kick off the
|
| 4 |
+
nanochat AQuA-RAT training run.
|
| 5 |
+
|
| 6 |
+
The workflow mirrors `launch_lambda_training.py`, but uses Hyperbolic's REST API.
|
| 7 |
+
|
| 8 |
+
Example:
|
| 9 |
+
|
| 10 |
+
python scripts/launch_hyperbolic_training.py \\
|
| 11 |
+
--gpu-count 1 \\
|
| 12 |
+
--region us-east \\
|
| 13 |
+
--max-price 4.5 \\
|
| 14 |
+
--auto-start \\
|
| 15 |
+
--inject-env WANDB_API_KEY
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import shlex
|
| 24 |
+
import subprocess
|
| 25 |
+
import sys
|
| 26 |
+
import tempfile
|
| 27 |
+
import textwrap
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
| 31 |
+
|
| 32 |
+
import requests
|
| 33 |
+
|
| 34 |
+
API_BASE = "https://api.hyperbolic.xyz"
|
| 35 |
+
MARKETPLACE_BASE = f"{API_BASE}/v1/marketplace"
|
| 36 |
+
READY_STATUSES = {
|
| 37 |
+
"ready",
|
| 38 |
+
"running",
|
| 39 |
+
"instance_running",
|
| 40 |
+
"node_ready",
|
| 41 |
+
"active",
|
| 42 |
+
"online",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def log(msg: str) -> None:
    """Print an informational message to stdout with a [info] prefix."""
    print(f"[info] {msg}")


def warn(msg: str) -> None:
    """Print a warning message to stderr with a [warn] prefix."""
    print(f"[warn] {msg}", file=sys.stderr)


def error(msg: str) -> None:
    """Print an error message to stderr with a [error] prefix."""
    print(f"[error] {msg}", file=sys.stderr)


def shell_quote(value: str) -> str:
    """Return *value* quoted for safe interpolation into a shell command line."""
    return shlex.quote(value)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def collect_env_pairs(cli_pairs: Sequence[str], inject_names: Sequence[str]) -> List[Tuple[str, str]]:
    """Merge KEY=VALUE pairs with env vars pulled from the local environment.

    Explicit ``--env`` entries are processed first, then ``--inject-env``
    names are looked up locally; a later entry overrides an earlier one for
    the same key.

    Raises:
        ValueError: on a malformed KEY=VALUE entry, an empty key name, or an
            injected name that is not set in the local environment.
    """
    merged: Dict[str, str] = {}

    for entry in cli_pairs:
        if "=" not in entry:
            raise ValueError(f"--env expects KEY=VALUE entries, got '{entry}'")
        # Split on the first '=' only, so values may themselves contain '='.
        name, _, value = entry.partition("=")
        name = name.strip()
        if not name:
            raise ValueError(f"Environment key is empty in '{entry}'")
        merged[name] = value

    for name in inject_names:
        if not name:
            raise ValueError("Encountered empty --inject-env name")
        if name not in os.environ:
            raise ValueError(f"--inject-env requested '{name}' but it is not set locally")
        merged[name] = os.environ[name]

    return list(merged.items())
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_bootstrap_script(
    repo_dir: str,
    run_script: str,
    branch: str,
    repo_url: Optional[str],
    env_file_remote: str,
    auto_start: bool,
    tmux_session: str,
) -> str:
    """Compose the bash script executed on the instance to prepare the training run.

    The script defines a few shell variables, clones/updates the repository
    (when a URL is given), installs the uploaded env file as `.env`, and —
    when `auto_start` is enabled — launches the run script inside tmux,
    falling back to nohup when tmux is not installed.
    """
    header = [
        "#!/usr/bin/env bash",
        "set -euxo pipefail",
        f'REPO_DIR="$HOME/{repo_dir}"',
        f'RUN_SCRIPT="{run_script}"',
        f'ENV_FILE="{env_file_remote}"',
        f'AUTO_START="{1 if auto_start else 0}"',
    ]

    if repo_url:
        checkout = [
            f"REPO_URL={shlex.quote(repo_url)}",
            'if [ ! -d "$REPO_DIR/.git" ]; then',
            ' rm -rf "$REPO_DIR"',
            ' git clone "$REPO_URL" "$REPO_DIR"',
            "fi",
            'cd "$REPO_DIR"',
            "git fetch --all --prune",
            f"git switch {shlex.quote(branch)}",
            "git pull --ff-only || true",
        ]
    else:
        checkout = [
            'mkdir -p "$REPO_DIR"',
            'cd "$REPO_DIR"',
        ]

    setup = [
        'if [ -f "$ENV_FILE" ]; then',
        ' cp "$ENV_FILE" .env',
        "fi",
        'if [ -f "$RUN_SCRIPT" ]; then',
        ' chmod +x "$RUN_SCRIPT"',
        "else",
        ' echo "Run script $RUN_SCRIPT not found; auto-start will be skipped." >&2',
        ' AUTO_START="0"',
        "fi",
    ]

    launch: List[str] = []
    if auto_start:
        via_tmux = (
            f'tmux new -d -s {shlex.quote(tmux_session)} '
            '"cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\""'
        )
        via_nohup = (
            'nohup bash -lc "cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\"" '
            '> "$HOME/nanochat-train.log" 2>&1 &'
        )
        launch = [
            'if [ "$AUTO_START" = "1" ]; then',
            " if command -v tmux >/dev/null 2>&1; then",
            f" {via_tmux}",
            " else",
            f" {via_nohup}",
            " fi",
            "fi",
        ]

    return "\n".join(header + checkout + setup + launch) + "\n"
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class HyperbolicClient:
    """Thin wrapper around the Hyperbolic Labs marketplace REST API."""

    def __init__(self, api_key: Optional[str]):
        """Store the API key, rejecting a missing or empty value up front."""
        if not api_key:
            raise ValueError("Hyperbolic API key is required. Pass --api-key or set HYPERBOLIC_API_KEY.")
        self.api_key = api_key

    def _headers(self, with_auth: bool = True) -> Dict[str, str]:
        """Build request headers, optionally including the bearer token."""
        headers = {"Content-Type": "application/json"}
        if with_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def list_marketplace(self) -> List[Dict[str, Any]]:
        """Return all marketplace nodes (this endpoint needs no auth).

        The response envelope varies across API versions, so several shapes
        are tolerated: {"instances": [...]}, {"nodes"/"data": [...]}, or a
        bare list.
        """
        response = requests.post(
            MARKETPLACE_BASE,
            headers=self._headers(with_auth=False),
            json={"filters": {}},
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        # BUG FIX: the previous code called payload.get(...) before checking
        # whether payload was a list, so a bare-list response raised
        # AttributeError instead of reaching the list-handling branch.
        if isinstance(payload, dict):
            instances = payload.get("instances")
            if instances is None:
                instances = payload.get("nodes") or payload.get("data") or []
        elif isinstance(payload, list):
            instances = payload
        else:
            instances = []
        return instances

    def create_instance(
        self,
        cluster_name: str,
        node_name: str,
        gpu_count: int,
        image: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Rent `gpu_count` GPUs on the given cluster/node and return the raw response."""
        payload: Dict[str, Any] = {
            "cluster_name": cluster_name,
            "node_name": node_name,
            "gpu_count": gpu_count,
        }
        if image:
            payload["image"] = image
        response = requests.post(
            f"{MARKETPLACE_BASE}/instances/create",
            headers=self._headers(),
            json=payload,
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def list_instances(self) -> List[Dict[str, Any]]:
        """Return the account's current instances, tolerating envelope variations."""
        response = requests.get(
            f"{MARKETPLACE_BASE}/instances",
            headers=self._headers(),
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        if isinstance(payload, dict):
            return payload.get("instances") or payload.get("data") or []
        if isinstance(payload, list):
            return payload
        return []

    def terminate_instance(self, instance_id: str) -> Dict[str, Any]:
        """Terminate the instance with the given ID and return the raw response."""
        response = requests.post(
            f"{MARKETPLACE_BASE}/instances/terminate",
            headers=self._headers(),
            json={"id": instance_id},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def get_balance(self) -> Optional[float]:
        """Return the current account balance, or None when it cannot be fetched."""
        try:
            response = requests.get(
                f"{API_BASE}/billing/get_current_balance",
                headers=self._headers(),
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()
            if isinstance(data, dict):
                # BUG FIX: the old `or`-chain skipped a legitimate zero
                # balance (0 is falsy); check each key for presence instead.
                for key in ("balance", "amount", "credits"):
                    value = data.get(key)
                    if value is not None:
                        return float(value)
        except (requests.HTTPError, ValueError, TypeError):
            pass
        warn("Unable to fetch current account balance.")
        return None
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def summarize_node(node: Dict[str, Any]) -> str:
    """Render a one-line human-readable summary of a marketplace node."""
    # Describe each GPU as "MODEL (RAM GB)" when both fields are present.
    hardware = node.get("hardware") or {}
    descriptions = []
    for gpu in hardware.get("gpus") or []:
        model = gpu.get("model")
        ram = gpu.get("ram") or gpu.get("memory") or gpu.get("vram")
        if model and ram:
            descriptions.append(f"{model} ({ram} GB)")
        elif model:
            descriptions.append(model)
    gpu_text = ", ".join(descriptions) if descriptions else "Unknown GPUs"

    amount = ((node.get("pricing") or {}).get("price") or {}).get("amount")
    price_str = f"${amount:.2f}/hr" if isinstance(amount, (int, float)) else "n/a"
    region = ((node.get("location") or {}).get("region")) or "unknown region"
    cluster = node.get("cluster_name") or "unknown cluster"
    available = (node.get("gpus_total") or 0) - (node.get("gpus_reserved") or 0)
    supplier = node.get("supplier_id") or "unknown supplier"

    return (
        f"{node.get('id', '<unknown>')} | {cluster} | {region} | "
        f"{available}/{node.get('gpus_total', '?')} GPUs free | "
        f"{gpu_text} | {price_str} | supplier: {supplier}"
    )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def filter_nodes(
    nodes: Iterable[Dict[str, Any]],
    gpu_count: int,
    region: Optional[str],
    supplier: Optional[str],
    max_price: Optional[float],
) -> List[Dict[str, Any]]:
    """Select nodes with enough free GPUs that match the region, supplier and
    price constraints, sorted cheapest-first.

    Region and supplier matching is case-insensitive substring matching; a
    node with no numeric price always passes the max-price filter.
    """
    want_region = region.lower() if region else None
    want_supplier = supplier.lower() if supplier else None

    matches: List[Dict[str, Any]] = []
    for node in nodes:
        free = (node.get("gpus_total") or 0) - (node.get("gpus_reserved") or 0)
        if free < gpu_count:
            continue

        if want_region:
            node_region = ((node.get("location") or {}).get("region") or "").lower()
            if want_region not in node_region:
                continue

        if want_supplier:
            node_supplier = (node.get("supplier_id") or "").lower()
            if want_supplier not in node_supplier:
                continue

        amount = ((node.get("pricing") or {}).get("price") or {}).get("amount")
        if max_price is not None and isinstance(amount, (int, float)) and amount > max_price:
            continue

        matches.append(node)

    # Unpriced nodes sort last.
    matches.sort(
        key=lambda n: ((n.get("pricing") or {}).get("price") or {}).get("amount", float("inf"))
    )
    return matches
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def extract_instance_id(payload: Dict[str, Any], before_ids: Sequence[str], client: HyperbolicClient) -> str:
    """Pull the new instance's ID out of a create-instance response.

    Tries the usual id keys at the top level first, then inside a nested
    "instance"/"data" object; failing that, diffs the account's current
    instance list against `before_ids` to find the newcomer.

    Raises:
        RuntimeError: when no ID can be determined either way.
    """
    id_keys = ("id", "instance_id", "instanceId")

    def _ids_from(obj: Any) -> List[str]:
        # Collect non-empty string values for any recognized id key, in order.
        if not isinstance(obj, dict):
            return []
        return [obj[k] for k in id_keys if isinstance(obj.get(k), str) and obj.get(k)]

    found = _ids_from(payload) + _ids_from(payload.get("instance") or payload.get("data"))
    if found:
        return found[0]

    # Fall back to diffing current instances.
    time.sleep(3)
    fresh = {str(inst.get("id")) for inst in client.list_instances() if inst.get("id")}
    new_ids = fresh.difference(before_ids)
    if new_ids:
        return new_ids.pop()
    raise RuntimeError("Unable to determine instance ID from API response.")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def extract_ip(instance: Dict[str, Any]) -> Optional[str]:
    """Return the first usable IP address found in the instance record, or None.

    Checks top-level keys first, then the nested "network" object, then each
    entry of "ip_addresses" — preserving that priority order.
    """
    net = instance.get("network") or {}
    ordered = [
        instance.get("public_ip"),
        instance.get("ip_address"),
        instance.get("ip"),
        instance.get("ipv4"),
        net.get("public_ip"),
        net.get("ip"),
        net.get("ipv4"),
    ]

    for entry in instance.get("ip_addresses") or []:
        ordered += [
            entry.get("public_ip"),
            entry.get("ip"),
            entry.get("ipv4"),
            entry.get("address"),
        ]

    return next((c for c in ordered if isinstance(c, str) and c), None)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def extract_ssh_port(instance: Dict[str, Any]) -> int:
    """Return the SSH port advertised by the instance record, defaulting to 22.

    Accepts either an int or an all-digit string; checks the top-level
    "ssh_port", then network.ssh_port, then ssh.port, in that order.
    """
    net = instance.get("network") or {}
    for candidate in (
        instance.get("ssh_port"),
        net.get("ssh_port"),
        (instance.get("ssh") or {}).get("port"),
    ):
        if isinstance(candidate, int):
            return candidate
        if isinstance(candidate, str) and candidate.isdigit():
            return int(candidate)
    return 22
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def extract_status(instance: Dict[str, Any]) -> str:
    """Return the first string-valued status field, or "" when none is present."""
    return next(
        (
            instance[key]
            for key in ("status", "instance_status", "state")
            if isinstance(instance.get(key), str)
        ),
        "",
    )
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def wait_for_instance(
    client: HyperbolicClient,
    instance_id: str,
    poll_seconds: int,
    max_wait_minutes: int,
) -> Dict[str, Any]:
    """Poll the API until *instance_id* reaches a ready status with an IP.

    Returns the matching instance record.  Raises TimeoutError when the
    instance is not ready within *max_wait_minutes*.
    """
    log(f"Waiting for instance {instance_id} to become ready...")
    give_up_at = time.time() + max_wait_minutes * 60
    while time.time() < give_up_at:
        for record in client.list_instances():
            # The API is inconsistent about the ID field name, so match on all.
            known_ids = {
                str(record.get(field))
                for field in ("id", "instance_id", "instanceId")
            }
            if instance_id not in known_ids:
                continue

            state = extract_status(record).lower()
            address = extract_ip(record)
            if state in READY_STATUSES and address:
                log(f"Instance is ready: status={state}, ip={address}")
                return record

            log(f"  status={state or '<unknown>'}; waiting for ready state...")
        time.sleep(poll_seconds)

    raise TimeoutError(f"Timed out waiting for instance {instance_id} to become ready.")
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def build_env_content(pairs: Sequence[Tuple[str, str]]) -> str:
    """Render KEY=VALUE pairs as dotenv-style text (trailing newline when non-empty)."""
    if not pairs:
        return ""
    rendered = [f"{name}={val}" for name, val in pairs]
    return "\n".join(rendered) + "\n"
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def scp(
    local_path: Path,
    remote_path: str,
    ssh_user: str,
    host: str,
    port: int,
    ssh_key: Optional[str],
) -> None:
    """Copy *local_path* to *remote_path* on the host via scp.

    Raises subprocess.CalledProcessError when the transfer fails.
    """
    argv = ["scp", "-o", "StrictHostKeyChecking=no"]
    if ssh_key:
        argv += ["-i", ssh_key]
    if port != 22:
        # scp (unlike ssh) takes the port as capital -P.
        argv += ["-P", str(port)]
    argv += [str(local_path), f"{ssh_user}@{host}:{remote_path}"]
    subprocess.run(argv, check=True)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def ssh_command(
    ssh_user: str,
    host: str,
    port: int,
    ssh_key: Optional[str],
    *command: str,
) -> None:
    """Run *command* on the remote host via ssh.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    argv = ["ssh", "-o", "StrictHostKeyChecking=no"]
    if ssh_key:
        argv += ["-i", ssh_key]
    if port != 22:
        argv += ["-p", str(port)]
    argv.append(f"{ssh_user}@{host}")
    if command:
        # NOTE(review): pieces are joined verbatim into one remote command
        # string; callers must pre-quote anything with shell metacharacters.
        argv.append(" ".join(command))
    subprocess.run(argv, check=True)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def deploy_to_instance(
    instance: Dict[str, Any],
    bootstrap_script: str,
    env_pairs: Sequence[Tuple[str, str]],
    env_file_remote: str,
    ssh_user: str,
    ssh_key: Optional[str],
) -> None:
    """Upload the bootstrap script (and optional env file), then run it remotely.

    Raises RuntimeError when the instance has no public IP yet, and
    subprocess.CalledProcessError when any scp/ssh step fails.
    """
    host = extract_ip(instance)
    if not host:
        raise RuntimeError("Instance does not report a public IP address yet.")
    port = extract_ssh_port(instance)
    log(f"Deploying bootstrap assets to {ssh_user}@{host}:{port} ...")

    with tempfile.TemporaryDirectory() as workdir:
        staging = Path(workdir)

        # Stage and push the bootstrap script to a fixed remote path.
        script_file = staging / "bootstrap.sh"
        script_file.write_text(bootstrap_script)
        scp(script_file, "/tmp/nanochat_bootstrap.sh", ssh_user, host, port, ssh_key)

        if env_pairs:
            env_file = staging / "nanochat.env"
            env_file.write_text(build_env_content(env_pairs))
            remote_parent = str(Path(env_file_remote).parent)
            ssh_command(ssh_user, host, port, ssh_key, f"mkdir -p {shell_quote(remote_parent)}")
            scp(env_file, env_file_remote, ssh_user, host, port, ssh_key)

        log("Executing remote bootstrap script...")
        ssh_command(ssh_user, host, port, ssh_key, "bash /tmp/nanochat_bootstrap.sh")
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for the Hyperbolic launch helper."""
    ap = argparse.ArgumentParser(description="Launch Hyperbolic Labs instance for AQuA-RAT training")

    # Marketplace / node selection
    ap.add_argument("--api-key", default=os.environ.get("HYPERBOLIC_API_KEY"), help="Hyperbolic API key")
    ap.add_argument("--gpu-count", type=int, default=1, help="Number of GPUs to request (default: 1)")
    ap.add_argument("--region", help="Preferred region substring (case-insensitive)")
    ap.add_argument("--supplier", help="Preferred supplier substring (case-insensitive)")
    ap.add_argument("--max-price", type=float, help="Maximum hourly price in USD")
    ap.add_argument("--node-name", help="Specify node name explicitly")
    ap.add_argument("--cluster-name", help="Cluster name when using --node-name")
    ap.add_argument("--list", action="store_true", help="List available marketplace nodes and exit")

    # Repository / run configuration
    ap.add_argument("--repo-url", help="Repository URL to clone (defaults to git remote origin)")
    ap.add_argument("--branch", default="main", help="Branch to checkout on the instance (default: main)")
    ap.add_argument("--run-script", default="run_aquarat_small.sh", help="Script to execute on the instance")
    ap.add_argument("--repo-dir", default="nanochatAquaRat", help="Directory name for the repo on the instance")

    # Deployment behaviour
    ap.add_argument("--auto-start", action="store_true", help="Automatically run the training script")
    ap.add_argument("--tmux-session", default="training", help="tmux session name when auto-start is enabled")
    ap.add_argument("--ssh-user", default="ubuntu", help="SSH username for the instance (default: ubuntu)")
    ap.add_argument("--ssh-key", help="Path to SSH private key for scp/ssh")
    ap.add_argument("--no-deploy", action="store_true", help="Skip deployment after the instance is ready")

    # Remote environment file
    ap.add_argument("--env", action="append", default=[], help="Environment variable in KEY=VALUE form")
    ap.add_argument("--inject-env", action="append", default=[], help="Environment variable name to copy from local env")

    # Polling
    ap.add_argument("--poll-seconds", type=int, default=20, help="Polling interval while waiting (default: 20)")
    ap.add_argument("--max-wait-minutes", type=int, default=25, help="Maximum minutes to wait for ready state")

    return ap.parse_args()
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def guess_repo_url() -> Optional[str]:
    """Return the git `origin` remote URL of the current repo, or None."""
    try:
        proc = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # No git binary, or not inside a repo with an origin remote.
        return None
    remote = proc.stdout.strip()
    return remote if remote else None
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def main() -> int:
    """Drive the full launch workflow: select a node, create the instance,
    wait for it to become ready, and (unless disabled) deploy and start
    the training run.  Returns a process exit code (0 on success).
    """
    args = parse_args()

    # Validate KEY=VALUE / injected env vars before touching the API.
    try:
        env_pairs = collect_env_pairs(args.env, args.inject_env)
    except ValueError as exc:
        error(str(exc))
        return 1

    if not args.api_key:
        error("Hyperbolic API key not provided. Use --api-key or set HYPERBOLIC_API_KEY.")
        return 1

    client = HyperbolicClient(args.api_key)

    try:
        nodes = client.list_marketplace()
    except requests.HTTPError as exc:
        error(f"Failed to list marketplace nodes: {exc}")
        return 1

    # --list short-circuits: print the catalog and exit.
    if args.list:
        log("Available marketplace nodes:")
        for node in nodes:
            print(summarize_node(node))
        return 0

    selected_node: Optional[Dict[str, Any]] = None

    # Node selection: explicit --node-name, or first match of the filters.
    if args.node_name:
        for node in nodes:
            if node.get("id") == args.node_name:
                selected_node = node
                break
        if not selected_node:
            error(f"Node '{args.node_name}' not found in marketplace list.")
            return 1
        if not args.cluster_name:
            args.cluster_name = selected_node.get("cluster_name")
    else:
        filtered_nodes = filter_nodes(nodes, args.gpu_count, args.region, args.supplier, args.max_price)
        if not filtered_nodes:
            error("No marketplace nodes match the specified constraints.")
            return 1
        selected_node = filtered_nodes[0]
        log("Selected node:")
        print("  " + summarize_node(selected_node))
        # Record the chosen node back onto args so later steps use it.
        args.cluster_name = selected_node.get("cluster_name")
        args.node_name = selected_node.get("id")

    if not args.cluster_name:
        error("Cluster name is required; unable to determine cluster for the selected node.")
        return 1

    # Repo URL is optional: without it the bootstrap reuses any repo on disk.
    repo_url = args.repo_url or guess_repo_url()
    if repo_url:
        log(f"Using repository: {repo_url}")
    else:
        warn("Could not determine repository URL. Auto-start will clone existing repo on instance if present.")

    balance = client.get_balance()
    if balance is not None:
        log(f"Current Hyperbolic balance: ${balance:.2f}")

    # Snapshot existing instance IDs so the new one can be found by diffing.
    before_instances = client.list_instances()
    before_ids = {str(inst.get("id")) for inst in before_instances if inst.get("id")}
    log(f"Launching instance on cluster '{args.cluster_name}' node '{args.node_name}' "
        f"with {args.gpu_count} GPU(s)...")

    try:
        create_response = client.create_instance(
            cluster_name=args.cluster_name,
            node_name=args.node_name,
            gpu_count=args.gpu_count,
        )
    except requests.HTTPError as exc:
        error(f"Failed to launch instance: {exc}")
        # Best effort: surface the API's error body when available.
        try:
            warn(f"Response payload: {exc.response.text}")  # type: ignore[attr-defined]
        except Exception:
            pass
        return 1

    instance_id = extract_instance_id(create_response, before_ids, client)
    log(f"Instance request acknowledged with id={instance_id}")

    try:
        instance = wait_for_instance(
            client=client,
            instance_id=instance_id,
            poll_seconds=args.poll_seconds,
            max_wait_minutes=args.max_wait_minutes,
        )
    except TimeoutError as exc:
        error(str(exc))
        return 1

    ip = extract_ip(instance)
    port = extract_ssh_port(instance)
    ssh_user = args.ssh_user

    # Print a ready-made SSH command for the user.
    log("Instance ready. Connection details:")
    if ip:
        ssh_parts = ["ssh", "-o", "StrictHostKeyChecking=no"]
        if args.ssh_key:
            ssh_parts.extend(["-i", args.ssh_key])
        if port != 22:
            ssh_parts.extend(["-p", str(port)])
        ssh_parts.append(f"{ssh_user}@{ip}")
        print("  SSH:", " ".join(ssh_parts))
    else:
        warn("Instance IP not available; SSH command cannot be constructed.")

    if args.no_deploy:
        log("Skipping deployment (--no-deploy supplied).")
        return 0

    # Build the remote bootstrap assets and push them over scp/ssh.
    env_file_remote = f"/home/{ssh_user}/nanochat_aquarat.env"
    bootstrap_script = build_bootstrap_script(
        repo_dir=args.repo_dir,
        run_script=args.run_script,
        branch=args.branch,
        repo_url=repo_url,
        env_file_remote=env_file_remote,
        auto_start=args.auto_start,
        tmux_session=args.tmux_session,
    )

    try:
        deploy_to_instance(
            instance=instance,
            bootstrap_script=bootstrap_script,
            env_pairs=env_pairs,
            env_file_remote=env_file_remote,
            ssh_user=ssh_user,
            ssh_key=args.ssh_key,
        )
    except subprocess.CalledProcessError as exc:
        # Deployment failure is not fatal to the instance; tell the user how to recover.
        error(f"Deployment failed: {exc}")
        warn("You can manually SSH to the instance and run the training script.")
        return 1

    log("Deployment complete.")
    if args.auto_start:
        log("Training should now be running on the instance.")
    else:
        log("Auto-start disabled; after SSHing in, run the configured script manually.")

    return 0
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
if __name__ == "__main__":
    # Script entry point: propagate main()'s exit code to the shell.
    raise SystemExit(main())
|
| 701 |
+
|
scripts/launch_lambda_training.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Automation helper to launch a Lambda Labs instance and kick off the nanochat AQuA-RAT run.
|
| 4 |
+
|
| 5 |
+
Usage example:
|
| 6 |
+
|
| 7 |
+
python scripts/launch_lambda_training.py \
|
| 8 |
+
--ssh-key-name my-key \
|
| 9 |
+
--region us-west-1 \
|
| 10 |
+
--instance-type gpu_1x_a10 \
|
| 11 |
+
--repo-url https://github.com/your-org/nanochatAquaRat.git \
|
| 12 |
+
--auto-start \
|
| 13 |
+
--inject-env WANDB_API_KEY
|
| 14 |
+
|
| 15 |
+
By default the script will create cloud-init user-data that installs basic tooling,
|
| 16 |
+
clones the repository, copies an `.env` file when provided, and (optionally) runs
|
| 17 |
+
`run_aquarat_small.sh` inside a detached tmux session. The Lambda Cloud API key is
|
| 18 |
+
read from the `LAMBDA_API_KEY` environment variable or `--api-key`.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import shlex
|
| 27 |
+
import subprocess
|
| 28 |
+
import sys
|
| 29 |
+
import textwrap
|
| 30 |
+
import time
|
| 31 |
+
from collections import OrderedDict
|
| 32 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 33 |
+
|
| 34 |
+
import requests
|
| 35 |
+
|
| 36 |
+
API_BASE = "https://cloud.lambda.ai/api/v1"
|
| 37 |
+
DEFAULT_PACKAGES = ["git", "curl", "tmux", "build-essential"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log(msg: str) -> None:
    """Write an informational line, prefixed with [info], to stdout."""
    sys.stdout.write(f"[info] {msg}\n")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def warn(msg: str) -> None:
    """Write a warning line, prefixed with [warn], to stderr."""
    sys.stderr.write(f"[warn] {msg}\n")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def error(msg: str) -> None:
    """Write an error line, prefixed with [error], to stderr."""
    sys.stderr.write(f"[error] {msg}\n")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def shell_quote(value: str) -> str:
    """Return a shell-escaped string for safe embedding inside scripts.

    Thin named wrapper over shlex.quote, kept for readability at call sites.
    """
    return shlex.quote(value)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def guess_repo_url() -> Optional[str]:
    """Attempt to infer the git remote URL for the current repository."""
    try:
        result = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # git missing, or no origin remote configured here.
        return None

    origin = result.stdout.strip()
    return origin or None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def collect_env_pairs(cli_pairs: Sequence[str], inject_names: Sequence[str]) -> List[Tuple[str, str]]:
    """Merge CLI KEY=VALUE entries with variables copied from the local environment.

    Later occurrences of the same key win (injected variables override CLI
    pairs of the same name).  Raises ValueError on malformed input or when
    an injected variable is not set locally.
    """
    merged: "OrderedDict[str, str]" = OrderedDict()

    for entry in cli_pairs:
        if "=" not in entry:
            raise ValueError(f"--env expects KEY=VALUE entries, got '{entry}'")
        name, _, value = entry.partition("=")
        name = name.strip()
        if not name:
            raise ValueError(f"Environment key is empty in '{entry}'")
        merged[name] = value

    for var in inject_names:
        if not var:
            raise ValueError("Encountered empty --inject-env name")
        if var not in os.environ:
            raise ValueError(f"--inject-env requested '{var}' but it is not set locally")
        merged[var] = os.environ[var]

    return list(merged.items())
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def build_bootstrap_script(
    repo_dir: str,
    run_script: str,
    branch: str,
    repo_url: Optional[str],
    env_file_remote: str,
    auto_start: bool,
    tmux_session: str,
) -> str:
    """Compose the bash script executed on the instance to prepare the training run."""
    # Header: strict mode plus the variables the rest of the script reads.
    script: List[str] = [
        "#!/usr/bin/env bash",
        "set -euxo pipefail",
        f'REPO_DIR="$HOME/{repo_dir}"',
        f'RUN_SCRIPT="{run_script}"',
        f'ENV_FILE="{env_file_remote}"',
        f'AUTO_START="{1 if auto_start else 0}"',
    ]

    if repo_url:
        # Clone on first boot, then fast-forward the requested branch.
        script.append(f"REPO_URL={shell_quote(repo_url)}")
        script += [
            'if [ ! -d "$REPO_DIR/.git" ]; then',
            '  rm -rf "$REPO_DIR"',
            '  git clone "$REPO_URL" "$REPO_DIR"',
            "fi",
            'cd "$REPO_DIR"',
            "git fetch --all --prune",
            f"git switch {shell_quote(branch)}",
            "git pull --ff-only || true",
        ]
    else:
        # No URL known: just make sure the working directory exists.
        script += [
            'mkdir -p "$REPO_DIR"',
            'cd "$REPO_DIR"',
        ]

    # Copy the uploaded env file into place and sanity-check the run script.
    script += [
        'if [ -f "$ENV_FILE" ]; then',
        '  cp "$ENV_FILE" .env',
        "fi",
        'if [ -f "$RUN_SCRIPT" ]; then',
        '  chmod +x "$RUN_SCRIPT"',
        "else",
        '  echo "Run script $RUN_SCRIPT not found; auto-start will be skipped." >&2',
        '  AUTO_START="0"',
        "fi",
    ]

    if auto_start:
        # Prefer tmux (keeps an attachable session); fall back to nohup.
        launch_tmux = (
            f'tmux new -d -s {shell_quote(tmux_session)} '
            '"cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\""'
        )
        launch_nohup = (
            'nohup bash -lc "cd \\"$REPO_DIR\\" && bash \\"$RUN_SCRIPT\\"" '
            '> "$HOME/nanochat-train.log" 2>&1 &'
        )
        script += [
            'if [ "$AUTO_START" = "1" ]; then',
            "  if command -v tmux >/dev/null 2>&1; then",
            f"    {launch_tmux}",
            "  else",
            f"    {launch_nohup}",
            "  fi",
            "fi",
        ]

    return "\n".join(script) + "\n"
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def build_user_data(
    packages: Sequence[str],
    bootstrap_script: str,
    env_pairs: Sequence[Tuple[str, str]],
    env_file_remote: str,
) -> str:
    """Render cloud-init user-data with package installs, env file, and bootstrap script."""
    out: List[str] = ["#cloud-config", "package_update: true", "package_upgrade: false"]

    if packages:
        out.append("packages:")
        out.extend(f"  - {pkg}" for pkg in packages)

    out.append("write_files:")
    if env_pairs:
        # NOTE(review): values are embedded verbatim; a value containing a
        # newline would break the YAML block scalar — confirm callers never
        # pass multi-line values.
        dotenv = "\n".join(f"{key}={value}" for key, value in env_pairs) + "\n"
        out += [
            f"  - path: {env_file_remote}",
            "    owner: ubuntu:ubuntu",
            "    permissions: '0640'",
            "    content: |",
            textwrap.indent(dotenv, "      "),
        ]

    # The bootstrap script is always written and made executable.
    out += [
        "  - path: /home/ubuntu/bootstrap_nanochat.sh",
        "    owner: ubuntu:ubuntu",
        "    permissions: '0755'",
        "    content: |",
        textwrap.indent(bootstrap_script, "      "),
    ]

    # Run the bootstrap as the ubuntu user once cloud-init finishes writing files.
    out += [
        "runcmd:",
        "  - \"su - ubuntu -c '/home/ubuntu/bootstrap_nanochat.sh'\"",
    ]

    return "\n".join(out) + "\n"
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class LambdaClient:
    """Minimal wrapper around the Lambda Cloud REST API."""

    def __init__(self, api_key: str) -> None:
        """Create an authenticated session; raises ValueError when no key is given."""
        if not api_key:
            raise ValueError("Lambda Cloud API key not provided")

        self.session = requests.Session()
        # All Lambda Cloud endpoints take a bearer token and speak JSON.
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Accept": "application/json",
                "Content-Type": "application/json",
            }
        )

    def launch_instances(self, payload: Dict[str, object]) -> List[str]:
        """POST a launch request and return the list of new instance IDs."""
        response = self.session.post(
            f"{API_BASE}/instance-operations/launch",
            data=json.dumps(payload),
            timeout=60,
        )
        if response.status_code >= 400:
            # Surface the raw response body; Lambda puts the error detail there.
            raise requests.HTTPError(response.text, response=response)
        data = response.json()
        instance_ids = data["data"]["instance_ids"]
        return instance_ids

    def get_instance(self, instance_id: str) -> Optional[Dict[str, object]]:
        """Fetch one instance record; returns None when the ID is unknown (404)."""
        response = self.session.get(
            f"{API_BASE}/instances/{instance_id}",
            timeout=30,
        )
        if response.status_code == 404:
            return None
        if response.status_code >= 400:
            raise requests.HTTPError(response.text, response=response)
        return response.json()["data"]

    def wait_for_instance(
        self,
        instance_id: str,
        poll_seconds: int,
        max_wait_minutes: int,
    ) -> Dict[str, object]:
        """Poll until the instance is 'active' and return its record.

        Raises RuntimeError when the instance reaches a terminal status
        (terminated/terminating/preempted) and TimeoutError when it is not
        active within *max_wait_minutes*.
        """
        deadline = time.time() + max_wait_minutes * 60
        last_status = "unknown"
        while time.time() < deadline:
            instance = self.get_instance(instance_id)
            if not instance:
                # The record can lag right after launch; keep polling.
                time.sleep(poll_seconds)
                continue

            status = str(instance.get("status", "unknown"))
            if status != last_status:
                # Only log transitions to keep the output readable.
                log(f"Instance {instance_id} status: {status}")
                last_status = status

            if status == "active":
                return instance
            if status in {"terminated", "terminating", "preempted"}:
                raise RuntimeError(f"Instance {instance_id} entered terminal status '{status}'")
            if status == "unhealthy":
                warn(f"Instance {instance_id} reported unhealthy; continuing to poll")

            time.sleep(poll_seconds)

        raise TimeoutError(f"Timed out waiting for instance {instance_id} to become active")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build and evaluate the CLI for the Lambda launch helper.

    *argv* defaults to sys.argv when None (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Launch a Lambda Labs instance and prepare the nanochat training run."
    )
    add = parser.add_argument

    # Credentials and instance placement
    add("--api-key", default=os.getenv("LAMBDA_API_KEY"), help="Lambda Cloud API key (default: read from LAMBDA_API_KEY).")
    add("--region", default="us-west-1", help="Lambda Cloud region name.")
    add("--instance-type", default="gpu_1x_a10", help="Instance type to launch (see Lambda Cloud docs for the catalog).")
    add("--ssh-key-name", required=True, help="Name of the SSH key already registered with Lambda Cloud.")
    add("--quantity", type=int, default=1, help="Number of instances to launch (auto-start only supports 1).")
    add("--name", help="Friendly name to assign to the instance(s).")

    # Repository and run configuration
    add("--repo-url", help="Git URL for the nanochat repository (default: auto-detect from current repo).")
    add("--branch", default="main", help="Git branch to checkout on the instance.")
    add("--repo-dir", default="nanochatAquaRat", help="Directory name to clone the repository into on the instance.")
    add("--run-script", default="run_aquarat_small.sh", help="Relative path to the training launch script inside the repo.")
    add("--tmux-session", default="nanochat-train", help="tmux session name used when --auto-start is active.")
    add("--auto-start", action="store_true", help="Kick off the training script automatically after provisioning completes.")

    # Remote environment file
    # NOTE(review): argparse appends into the shared default list across
    # repeated parses — harmless here since we parse once per process.
    add("--env", action="append", default=[], help="Additional KEY=VALUE pairs to write into the remote .env file (repeatable).")
    add("--inject-env", action="append", default=[], help="Names of local environment variables whose values should populate the remote .env.")
    add("--env-file-name", default=".env.lambda", help="Filename (relative to /home/ubuntu) for the generated environment file.")

    # Image selection (mutually exclusive, validated in main)
    add("--image-id", help="Optional image ID to use instead of the default Lambda Stack image.")
    add("--image-family", help="Optional image family name to use instead of the default image.")

    # Waiting / user-data behaviour
    add("--max-wait-minutes", type=int, default=25, help="Maximum minutes to wait for the instance to become active.")
    add("--poll-seconds", type=int, default=20, help="Polling interval while waiting for the instance to become active.")
    add("--skip-wait", action="store_true", help="Exit after requesting the launch without waiting for active status.")
    add("--no-user-data", action="store_true", help="Skip sending user-data; instance boots with the stock image configuration.")
    add("--print-user-data", action="store_true", help="Print the generated user-data to stdout before launching.")

    return parser.parse_args(argv)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _validation_problem(args) -> Optional[str]:
    """Return an error message for invalid CLI argument combinations, or None."""
    if args.image_id and args.image_family:
        return "Provide only one of --image-id or --image-family."
    if args.auto_start and args.quantity != 1:
        return "--auto-start currently supports a single instance (set --quantity=1)."
    if not args.api_key:
        return "Lambda Cloud API key not provided. Use --api-key or set LAMBDA_API_KEY."
    return None


def _resolve_repo_url(args) -> Optional[str]:
    """Return the repo URL from args, or try to auto-detect it (logging the outcome)."""
    repo_url = args.repo_url
    if not repo_url:
        repo_url = guess_repo_url()
        if repo_url:
            log(f"Discovered repository URL: {repo_url}")
        else:
            warn(
                "Could not auto-detect repository URL; "
                "pass --repo-url if you want the instance to clone the repo automatically."
            )
    return repo_url


def _build_payload(args, user_data: Optional[str]) -> Dict[str, object]:
    """Assemble the Lambda Cloud launch request payload from CLI args."""
    payload: Dict[str, object] = {
        "region_name": args.region,
        "instance_type_name": args.instance_type,
        "ssh_key_names": [args.ssh_key_name],
    }
    if args.quantity:
        payload["quantity"] = args.quantity
    if args.name:
        payload["name"] = args.name
    if user_data:
        payload["user_data"] = user_data
    # --image-id and --image-family are mutually exclusive (validated earlier).
    if args.image_id:
        payload["image"] = {"id": args.image_id}
    elif args.image_family:
        payload["image"] = {"family": args.image_family}
    return payload


def _wait_for_instances(client, instance_ids, args) -> Optional[List[Dict[str, object]]]:
    """Poll each requested instance until active; return instance records or None on failure."""
    instances: List[Dict[str, object]] = []
    for instance_id in instance_ids:
        try:
            instance = client.wait_for_instance(
                instance_id=instance_id,
                poll_seconds=args.poll_seconds,
                max_wait_minutes=args.max_wait_minutes,
            )
        except (RuntimeError, TimeoutError, requests.HTTPError) as exc:
            error(f"Failed while waiting for instance {instance_id}: {exc}")
            return None
        instances.append(instance)
    return instances


def _report_instances(instances, args) -> None:
    """Log each active instance's IP, an SSH hint, and next-step instructions."""
    for instance in instances:
        ip = instance.get("ip") or "<pending>"
        name = instance.get("name") or instance.get("id")
        log(f"Instance {name} is active with public IP {ip}")
        if ip and ip != "<pending>":
            log(
                f"SSH command: ssh -i /path/to/key.pem ubuntu@{ip}"
            )

    if args.auto_start:
        log(
            "Auto-start enabled. Training is running inside tmux; attach with "
            f"`ssh ...` then `tmux attach -t {args.tmux_session}`."
        )
    else:
        log(
            "Auto-start disabled. After SSH'ing in, run "
            f"`cd ~/{args.repo_dir} && bash {args.run_script}`."
        )


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Launch one or more Lambda Cloud instances configured for training.

    Validates CLI arguments, builds the cloud-init user-data (bootstrap script
    plus remote .env file), requests the launch through the Lambda API, and --
    unless --skip-wait is given -- waits for each instance to become active and
    prints connection instructions.

    Args:
        argv: Optional argument vector; defaults to sys.argv[1:] via parse_args.

    Returns:
        Process exit code: 0 on success, 1 on validation or API failure.
    """
    args = parse_args(argv)

    problem = _validation_problem(args)
    if problem:
        error(problem)
        return 1

    repo_url = _resolve_repo_url(args)

    try:
        env_pairs = collect_env_pairs(args.env, args.inject_env)
    except ValueError as exc:
        # collect_env_pairs raises ValueError for malformed KEY=VALUE input.
        error(str(exc))
        return 1

    # Remote path where the generated environment file will be written.
    env_file_remote = f"/home/ubuntu/{args.env_file_name}"

    bootstrap_script = build_bootstrap_script(
        repo_dir=args.repo_dir,
        run_script=args.run_script,
        branch=args.branch,
        repo_url=repo_url,
        env_file_remote=env_file_remote,
        auto_start=args.auto_start,
        tmux_session=args.tmux_session,
    )

    user_data: Optional[str] = None
    if not args.no_user_data:
        user_data = build_user_data(
            packages=DEFAULT_PACKAGES,
            bootstrap_script=bootstrap_script,
            env_pairs=env_pairs,
            env_file_remote=env_file_remote,
        )
        if args.print_user_data:
            print(user_data)
    else:
        if args.print_user_data:
            print("# user-data disabled (--no-user-data)")

    payload = _build_payload(args, user_data)

    client = LambdaClient(args.api_key)

    log(
        "Requesting instance launch "
        f"(region={args.region}, type={args.instance_type}, quantity={args.quantity})"
    )

    try:
        instance_ids = client.launch_instances(payload)
    except requests.HTTPError as exc:
        error(f"Instance launch failed: {exc}")
        if exc.response is not None:
            warn(f"Response content: {exc.response.text}")
        return 1

    log(f"Requested instance IDs: {', '.join(instance_ids)}")

    if args.skip_wait:
        log("Skipping wait (--skip-wait supplied).")
        return 0

    instances = _wait_for_instances(client, instance_ids, args)
    if instances is None:
        return 1

    _report_instances(instances, args)

    return 0
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
if __name__ == "__main__":
    # Script entry point: propagate main()'s exit code to the shell.
    raise SystemExit(main())
|