justinkay committed on
Commit 3c56581 · 1 Parent(s): 991278f

Per-user subsampling

app.py CHANGED
@@ -74,61 +74,35 @@ print(f"Loaded {len(images_data)} images for the quiz")
 
 # Load image filenames list
 with open('images.txt', 'r') as f:
-    image_filenames = [line.strip() for line in f.readlines() if line.strip()]
+    full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]
 
-# Initialize CODA with subsampled dataset
+# Initialize full dataset (will be subsampled per-user)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load full dataset
 full_preds = torch.load("iwildcam_demo.pt").to(device)
 full_labels = torch.load("iwildcam_demo_labels.pt").to(device)
 
-# Subsample to balance classes
+# Pre-compute class indices for subsampling
 from collections import defaultdict
-class_to_indices = defaultdict(list)
+full_class_to_indices = defaultdict(list)
 for idx, label in enumerate(full_labels):
     class_idx = label.item()
-    class_to_indices[class_idx].append(idx)
+    full_class_to_indices[class_idx].append(idx)
 
 # Find minimum class size
-min_class_size = min(len(indices) for indices in class_to_indices.values())
-print(f"Subsampling to {min_class_size} images per class (total: {min_class_size * len(class_to_indices)} images)")
-
-# Randomly subsample each class
-subsampled_indices = []
-for class_idx in sorted(class_to_indices.keys()):
-    indices = class_to_indices[class_idx]
-    sampled = np.random.choice(indices, size=min_class_size, replace=False)
-    subsampled_indices.extend(sampled.tolist())
-
-# Sort indices to maintain order
-subsampled_indices.sort()
-
-# Create subsampled dataset
-subsampled_preds = full_preds[:, subsampled_indices, :]
-subsampled_labels = full_labels[subsampled_indices]
-image_filenames = [image_filenames[idx] for idx in subsampled_indices]
-
-# Create Dataset object with subsampled data
-dataset = Dataset.__new__(Dataset)
-dataset.preds = subsampled_preds
-dataset.labels = subsampled_labels
-dataset.device = device
+min_class_size = min(len(indices) for indices in full_class_to_indices.values())
+print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")
 
+# Loss function for oracle
 loss_fn = LOSS_FNS['acc']
-oracle = Oracle(dataset, loss_fn=loss_fn)
 
-# Create CODA selector with default parameters
-coda_selector = CODA(dataset)
-
-print(f"Initialized CODA with {dataset.preds.shape[1]} samples and {dataset.preds.shape[0]} models")
-
-# Global state
+# Global state (will be set per-user in start_demo)
 current_image_info = None
-# coda_selector already initialized above
-# oracle already initialized above
-# dataset already initialized above
-# image_filenames already initialized above
+coda_selector = None
+oracle = None
+dataset = None
+image_filenames = None
 iteration_count = 0
 
 def get_model_predictions(chosen_idx):
@@ -804,9 +778,34 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
 
     # Set up button interactions
    def start_demo():
-        global iteration_count, coda_selector
+        global iteration_count, coda_selector, dataset, oracle, image_filenames
+
         # Reset the demo state
         iteration_count = 0
+
+        # Subsample dataset for this user
+        subsampled_indices = []
+        for class_idx in sorted(full_class_to_indices.keys()):
+            indices = full_class_to_indices[class_idx]
+            sampled = np.random.choice(indices, size=min_class_size, replace=False)
+            subsampled_indices.extend(sampled.tolist())
+
+        # Sort indices to maintain order
+        subsampled_indices.sort()
+
+        # Create subsampled dataset for this user
+        subsampled_preds = full_preds[:, subsampled_indices, :]
+        subsampled_labels = full_labels[subsampled_indices]
+        image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
+
+        # Create Dataset object with subsampled data
+        dataset = Dataset.__new__(Dataset)
+        dataset.preds = subsampled_preds
+        dataset.labels = subsampled_labels
+        dataset.device = device
+
+        # Create oracle and CODA selector for this user
+        oracle = Oracle(dataset, loss_fn=loss_fn)
         coda_selector = CODA(dataset)
 
         image, status, predictions = get_next_coda_image()
@@ -817,9 +816,34 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
         return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)
 
    def start_over():
-        global iteration_count, coda_selector
+        global iteration_count, coda_selector, dataset, oracle, image_filenames
+
         # Reset the demo state
         iteration_count = 0
+
+        # Subsample dataset for this user (new random subsample)
+        subsampled_indices = []
+        for class_idx in sorted(full_class_to_indices.keys()):
+            indices = full_class_to_indices[class_idx]
+            sampled = np.random.choice(indices, size=min_class_size, replace=False)
+            subsampled_indices.extend(sampled.tolist())
+
+        # Sort indices to maintain order
+        subsampled_indices.sort()
+
+        # Create subsampled dataset for this user
+        subsampled_preds = full_preds[:, subsampled_indices, :]
+        subsampled_labels = full_labels[subsampled_indices]
+        image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
+
+        # Create Dataset object with subsampled data
+        dataset = Dataset.__new__(Dataset)
+        dataset.preds = subsampled_preds
+        dataset.labels = subsampled_labels
+        dataset.device = device
+
+        # Create oracle and CODA selector for this user
+        oracle = Oracle(dataset, loss_fn=loss_fn)
         coda_selector = CODA(dataset)
 
         # Reset all displays
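
Note that start_demo() and start_over() now repeat the same subsampling block. A minimal sketch of how that logic could be factored into one helper is shown below; it is illustrative only, assumes the module-level names introduced in this commit (full_preds, full_labels, full_image_filenames, full_class_to_indices, min_class_size, device, loss_fn) plus the Dataset, Oracle, and CODA classes from app.py, and the helper name make_user_state is hypothetical.

    import numpy as np

    def make_user_state():
        # Hypothetical helper mirroring the per-user subsampling added in this
        # commit; full_preds, full_labels, full_image_filenames,
        # full_class_to_indices, min_class_size, device, loss_fn, Dataset,
        # Oracle, and CODA are assumed to exist at module level as in app.py.
        subsampled_indices = []
        for class_idx in sorted(full_class_to_indices.keys()):
            indices = full_class_to_indices[class_idx]
            # Draw a class-balanced sample without replacement
            sampled = np.random.choice(indices, size=min_class_size, replace=False)
            subsampled_indices.extend(sampled.tolist())
        subsampled_indices.sort()  # keep original dataset order

        # Slice predictions (models x samples x classes), labels, and filenames
        user_dataset = Dataset.__new__(Dataset)
        user_dataset.preds = full_preds[:, subsampled_indices, :]
        user_dataset.labels = full_labels[subsampled_indices]
        user_dataset.device = device
        user_filenames = [full_image_filenames[idx] for idx in subsampled_indices]

        oracle = Oracle(user_dataset, loss_fn=loss_fn)
        selector = CODA(user_dataset)
        return user_dataset, oracle, selector, user_filenames

With such a helper, start_demo() and start_over() would each reduce to a single call, e.g. dataset, oracle, coda_selector, image_filenames = make_user_state().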
images.txt CHANGED
@@ -33,7 +33,6 @@
 8eb30b2a-21bc-11ea-a13a-137349068a90.jpg
 99005e3e-21bc-11ea-a13a-137349068a90.jpg
 86e3b2fa-21bc-11ea-a13a-137349068a90.jpg
-97c99b7a-21bc-11ea-a13a-137349068a90.jpg
 8f988d76-21bc-11ea-a13a-137349068a90.jpg
 9593f3a0-21bc-11ea-a13a-137349068a90.jpg
 988d1cbc-21bc-11ea-a13a-137349068a90.jpg
@@ -188,7 +187,6 @@
 98536d82-21bc-11ea-a13a-137349068a90.jpg
 8f4dd7f4-21bc-11ea-a13a-137349068a90.jpg
 8f88c6e8-21bc-11ea-a13a-137349068a90.jpg
-95ddfafe-21bc-11ea-a13a-137349068a90.jpg
 8aaabb04-21bc-11ea-a13a-137349068a90.jpg
 8768fa28-21bc-11ea-a13a-137349068a90.jpg
 9505e7fe-21bc-11ea-a13a-137349068a90.jpg
iwildcam_demo.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1880c2a8d2dc0297fe1f87ce3eb1875c228e1d35b6c81008d81db7ee990f4c2f
-size 77843
+oid sha256:b44bb58d31c2f26a17a754861ed39d30f38afd0a6a263c358b5dcfafbe287b21
+size 77715
iwildcam_demo_annotations.json CHANGED
@@ -404,17 +404,6 @@
       "file_name": "86e3b2fa-21bc-11ea-a13a-137349068a90.jpg",
       "seq_frame_num": 4
     },
-    {
-      "seq_num_frames": 10,
-      "location": 218,
-      "datetime": "2013-05-11 20:29:09.000",
-      "id": "97c99b7a-21bc-11ea-a13a-137349068a90",
-      "seq_id": "3019d1ce-7d42-11eb-8fb5-0242ac1c0002",
-      "width": 1920,
-      "height": 1080,
-      "file_name": "97c99b7a-21bc-11ea-a13a-137349068a90.jpg",
-      "seq_frame_num": 8
-    },
     {
       "seq_num_frames": 10,
       "location": 408,
@@ -2166,17 +2155,6 @@
       "file_name": "8f88c6e8-21bc-11ea-a13a-137349068a90.jpg",
       "seq_frame_num": 2
     },
-    {
-      "seq_num_frames": 10,
-      "location": 151,
-      "datetime": "2013-02-03 01:47:54.000",
-      "id": "95ddfafe-21bc-11ea-a13a-137349068a90",
-      "seq_id": "302607f0-7d42-11eb-8fb5-0242ac1c0002",
-      "width": 1920,
-      "height": 1080,
-      "file_name": "95ddfafe-21bc-11ea-a13a-137349068a90.jpg",
-      "seq_frame_num": 3
-    },
     {
       "seq_num_frames": 10,
       "location": 151,
@@ -16359,11 +16337,6 @@
       "image_id": "98c3c686-21bc-11ea-a13a-137349068a90",
       "category_id": 10
     },
-    {
-      "id": "9a6547bc-21bc-11ea-a13a-137349068a90",
-      "image_id": "97c99b7a-21bc-11ea-a13a-137349068a90",
-      "category_id": 10
-    },
     {
       "id": "a1d760ac-21bc-11ea-a13a-137349068a90",
       "image_id": "873ffbfa-21bc-11ea-a13a-137349068a90",
@@ -17894,11 +17867,6 @@
       "image_id": "8f88c6e8-21bc-11ea-a13a-137349068a90",
       "category_id": 101
     },
-    {
-      "id": "9cafbc0a-21bc-11ea-a13a-137349068a90",
-      "image_id": "95ddfafe-21bc-11ea-a13a-137349068a90",
-      "category_id": 101
-    },
     {
       "id": "a36c4fd6-21bc-11ea-a13a-137349068a90",
       "image_id": "8aaabb04-21bc-11ea-a13a-137349068a90",
iwildcam_demo_images/95ddfafe-21bc-11ea-a13a-137349068a90.jpg DELETED

Git LFS Details

  • SHA256: 85327e363c6e7422675143412fb8b988e0e22c5bf4c21c941b6441b0a9e3f0b9
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
iwildcam_demo_images/97c99b7a-21bc-11ea-a13a-137349068a90.jpg DELETED

Git LFS Details

  • SHA256: 9729db5cda5979069e507b7b4d9133bcb934beef36976d88a031330b2f166b63
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
iwildcam_demo_labels.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d2a97f8df143685109ade0335ae8c8afc0b74787a12d237a5647fdc4f75842ed
-size 11844
+oid sha256:c6c3f2f49fe5b1dce7f4c9e45380fa68fe4040f004d9bcf7ff73bb3323d096f7
+size 11780