MRiabov commited on
Commit
8cd5198
·
2 Parent(s): 854831b feec877

Merge branch 'main' of https://github.com/MRiabov/wire-seg-hr-impl

Browse files
Files changed (1) hide show
  1. train.py +36 -0
train.py CHANGED
@@ -339,6 +339,42 @@ def main():
339
  step += 1
340
  pbar.update(1)
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  print("[WireSegHR][train] Done.")
343
 
344
 
 
339
  step += 1
340
  pbar.update(1)
341
 
342
+ # Save a final checkpoint upon completion
343
+ _save_checkpoint(
344
+ os.path.join(out_dir, f"ckpt_{iters}.pt"), step, model, optim, scaler, best_f1
345
+ )
346
+
347
+ # Final test evaluation
348
+ if dset_test is not None:
349
+ torch.cuda.empty_cache()
350
+ model.eval()
351
+ print(
352
+ f"[WireSegHR][train] Final test starting... test_size={len(dset_test)} patch={eval_patch_size} overlap={overlap} stride={eval_patch_size - overlap} fine_batch={eval_fine_batch}",
353
+ flush=True,
354
+ )
355
+ test_stats = validate(
356
+ model,
357
+ dset_test,
358
+ coarse_test,
359
+ device,
360
+ amp_enabled,
361
+ amp_dtype,
362
+ prob_thresh,
363
+ mm_enable,
364
+ mm_kernel,
365
+ eval_patch_size,
366
+ overlap,
367
+ eval_fine_batch,
368
+ len(dset_test),
369
+ )
370
+ print(
371
+ f"[Test Final][Fine] IoU={test_stats['iou']:.4f} F1={test_stats['f1']:.4f} P={test_stats['precision']:.4f} R={test_stats['recall']:.4f}"
372
+ )
373
+ print(
374
+ f"[Test Final][Coarse] IoU={test_stats['iou_coarse']:.4f} F1={test_stats['f1_coarse']:.4f} P={test_stats['precision_coarse']:.4f} R={test_stats['recall_coarse']:.4f}"
375
+ )
376
+ model.train()
377
+
378
  print("[WireSegHR][train] Done.")
379
 
380