Merge branch 'main' of https://github.com/MRiabov/wire-seg-hr-impl
Browse files
train.py
CHANGED
|
@@ -339,6 +339,42 @@ def main():
|
|
| 339 |
step += 1
|
| 340 |
pbar.update(1)
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
print("[WireSegHR][train] Done.")
|
| 343 |
|
| 344 |
|
|
|
|
| 339 |
step += 1
|
| 340 |
pbar.update(1)
|
| 341 |
|
| 342 |
+
# Save a final checkpoint upon completion
|
| 343 |
+
_save_checkpoint(
|
| 344 |
+
os.path.join(out_dir, f"ckpt_{iters}.pt"), step, model, optim, scaler, best_f1
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# Final test evaluation
|
| 348 |
+
if dset_test is not None:
|
| 349 |
+
torch.cuda.empty_cache()
|
| 350 |
+
model.eval()
|
| 351 |
+
print(
|
| 352 |
+
f"[WireSegHR][train] Final test starting... test_size={len(dset_test)} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
|
| 353 |
+
flush=True,
|
| 354 |
+
)
|
| 355 |
+
test_stats = validate(
|
| 356 |
+
model,
|
| 357 |
+
dset_test,
|
| 358 |
+
coarse_test,
|
| 359 |
+
device,
|
| 360 |
+
amp_enabled,
|
| 361 |
+
amp_dtype,
|
| 362 |
+
prob_thresh,
|
| 363 |
+
mm_enable,
|
| 364 |
+
mm_kernel,
|
| 365 |
+
eval_patch_size,
|
| 366 |
+
overlap,
|
| 367 |
+
eval_fine_batch,
|
| 368 |
+
len(dset_test),
|
| 369 |
+
)
|
| 370 |
+
print(
|
| 371 |
+
f"[Test Final][Fine] IoU={test_stats['iou']:.4f} F1={test_stats['f1']:.4f} P={test_stats['precision']:.4f} R={test_stats['recall']:.4f}"
|
| 372 |
+
)
|
| 373 |
+
print(
|
| 374 |
+
f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
|
| 375 |
+
)
|
| 376 |
+
model.train()
|
| 377 |
+
|
| 378 |
print("[WireSegHR][train] Done.")
|
| 379 |
|
| 380 |
|