justinkay
commited on
Commit
·
3c56581
1
Parent(s):
991278f
Per-user subsampling
Browse files
app.py
CHANGED
|
@@ -74,61 +74,35 @@ print(f"Loaded {len(images_data)} images for the quiz")
|
|
| 74 |
|
| 75 |
# Load image filenames list
|
| 76 |
with open('images.txt', 'r') as f:
|
| 77 |
-
|
| 78 |
|
| 79 |
-
# Initialize
|
| 80 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
|
| 82 |
# Load full dataset
|
| 83 |
full_preds = torch.load("iwildcam_demo.pt").to(device)
|
| 84 |
full_labels = torch.load("iwildcam_demo_labels.pt").to(device)
|
| 85 |
|
| 86 |
-
#
|
| 87 |
from collections import defaultdict
|
| 88 |
-
|
| 89 |
for idx, label in enumerate(full_labels):
|
| 90 |
class_idx = label.item()
|
| 91 |
-
|
| 92 |
|
| 93 |
# Find minimum class size
|
| 94 |
-
min_class_size = min(len(indices) for indices in
|
| 95 |
-
print(f"
|
| 96 |
-
|
| 97 |
-
# Randomly subsample each class
|
| 98 |
-
subsampled_indices = []
|
| 99 |
-
for class_idx in sorted(class_to_indices.keys()):
|
| 100 |
-
indices = class_to_indices[class_idx]
|
| 101 |
-
sampled = np.random.choice(indices, size=min_class_size, replace=False)
|
| 102 |
-
subsampled_indices.extend(sampled.tolist())
|
| 103 |
-
|
| 104 |
-
# Sort indices to maintain order
|
| 105 |
-
subsampled_indices.sort()
|
| 106 |
-
|
| 107 |
-
# Create subsampled dataset
|
| 108 |
-
subsampled_preds = full_preds[:, subsampled_indices, :]
|
| 109 |
-
subsampled_labels = full_labels[subsampled_indices]
|
| 110 |
-
image_filenames = [image_filenames[idx] for idx in subsampled_indices]
|
| 111 |
-
|
| 112 |
-
# Create Dataset object with subsampled data
|
| 113 |
-
dataset = Dataset.__new__(Dataset)
|
| 114 |
-
dataset.preds = subsampled_preds
|
| 115 |
-
dataset.labels = subsampled_labels
|
| 116 |
-
dataset.device = device
|
| 117 |
|
|
|
|
| 118 |
loss_fn = LOSS_FNS['acc']
|
| 119 |
-
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 120 |
|
| 121 |
-
#
|
| 122 |
-
coda_selector = CODA(dataset)
|
| 123 |
-
|
| 124 |
-
print(f"Initialized CODA with {dataset.preds.shape[1]} samples and {dataset.preds.shape[0]} models")
|
| 125 |
-
|
| 126 |
-
# Global state
|
| 127 |
current_image_info = None
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
iteration_count = 0
|
| 133 |
|
| 134 |
def get_model_predictions(chosen_idx):
|
|
@@ -804,9 +778,34 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
|
| 804 |
|
| 805 |
# Set up button interactions
|
| 806 |
def start_demo():
|
| 807 |
-
global iteration_count, coda_selector
|
|
|
|
| 808 |
# Reset the demo state
|
| 809 |
iteration_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 810 |
coda_selector = CODA(dataset)
|
| 811 |
|
| 812 |
image, status, predictions = get_next_coda_image()
|
|
@@ -817,9 +816,34 @@ with gr.Blocks(title="CODA: Wildlife Photo Classification Challenge",
|
|
| 817 |
return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)
|
| 818 |
|
| 819 |
def start_over():
|
| 820 |
-
global iteration_count, coda_selector
|
|
|
|
| 821 |
# Reset the demo state
|
| 822 |
iteration_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 823 |
coda_selector = CODA(dataset)
|
| 824 |
|
| 825 |
# Reset all displays
|
|
|
|
| 74 |
|
| 75 |
# Load image filenames list
|
| 76 |
with open('images.txt', 'r') as f:
|
| 77 |
+
full_image_filenames = [line.strip() for line in f.readlines() if line.strip()]
|
| 78 |
|
| 79 |
+
# Initialize full dataset (will be subsampled per-user)
|
| 80 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
|
| 82 |
# Load full dataset
|
| 83 |
full_preds = torch.load("iwildcam_demo.pt").to(device)
|
| 84 |
full_labels = torch.load("iwildcam_demo_labels.pt").to(device)
|
| 85 |
|
| 86 |
+
# Pre-compute class indices for subsampling
|
| 87 |
from collections import defaultdict
|
| 88 |
+
full_class_to_indices = defaultdict(list)
|
| 89 |
for idx, label in enumerate(full_labels):
|
| 90 |
class_idx = label.item()
|
| 91 |
+
full_class_to_indices[class_idx].append(idx)
|
| 92 |
|
| 93 |
# Find minimum class size
|
| 94 |
+
min_class_size = min(len(indices) for indices in full_class_to_indices.values())
|
| 95 |
+
print(f"Each user will get {min_class_size} images per class (total: {min_class_size * len(full_class_to_indices)} images per user)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
# Loss function for oracle
|
| 98 |
loss_fn = LOSS_FNS['acc']
|
|
|
|
| 99 |
|
| 100 |
+
# Global state (will be set per-user in start_demo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
current_image_info = None
|
| 102 |
+
coda_selector = None
|
| 103 |
+
oracle = None
|
| 104 |
+
dataset = None
|
| 105 |
+
image_filenames = None
|
| 106 |
iteration_count = 0
|
| 107 |
|
| 108 |
def get_model_predictions(chosen_idx):
|
|
|
|
| 778 |
|
| 779 |
# Set up button interactions
|
| 780 |
def start_demo():
|
| 781 |
+
global iteration_count, coda_selector, dataset, oracle, image_filenames
|
| 782 |
+
|
| 783 |
# Reset the demo state
|
| 784 |
iteration_count = 0
|
| 785 |
+
|
| 786 |
+
# Subsample dataset for this user
|
| 787 |
+
subsampled_indices = []
|
| 788 |
+
for class_idx in sorted(full_class_to_indices.keys()):
|
| 789 |
+
indices = full_class_to_indices[class_idx]
|
| 790 |
+
sampled = np.random.choice(indices, size=min_class_size, replace=False)
|
| 791 |
+
subsampled_indices.extend(sampled.tolist())
|
| 792 |
+
|
| 793 |
+
# Sort indices to maintain order
|
| 794 |
+
subsampled_indices.sort()
|
| 795 |
+
|
| 796 |
+
# Create subsampled dataset for this user
|
| 797 |
+
subsampled_preds = full_preds[:, subsampled_indices, :]
|
| 798 |
+
subsampled_labels = full_labels[subsampled_indices]
|
| 799 |
+
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
|
| 800 |
+
|
| 801 |
+
# Create Dataset object with subsampled data
|
| 802 |
+
dataset = Dataset.__new__(Dataset)
|
| 803 |
+
dataset.preds = subsampled_preds
|
| 804 |
+
dataset.labels = subsampled_labels
|
| 805 |
+
dataset.device = device
|
| 806 |
+
|
| 807 |
+
# Create oracle and CODA selector for this user
|
| 808 |
+
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 809 |
coda_selector = CODA(dataset)
|
| 810 |
|
| 811 |
image, status, predictions = get_next_coda_image()
|
|
|
|
| 816 |
return image, status_html, predictions, prob_plot, acc_plot, gr.update(visible=False), "", gr.update(visible=True)
|
| 817 |
|
| 818 |
def start_over():
|
| 819 |
+
global iteration_count, coda_selector, dataset, oracle, image_filenames
|
| 820 |
+
|
| 821 |
# Reset the demo state
|
| 822 |
iteration_count = 0
|
| 823 |
+
|
| 824 |
+
# Subsample dataset for this user (new random subsample)
|
| 825 |
+
subsampled_indices = []
|
| 826 |
+
for class_idx in sorted(full_class_to_indices.keys()):
|
| 827 |
+
indices = full_class_to_indices[class_idx]
|
| 828 |
+
sampled = np.random.choice(indices, size=min_class_size, replace=False)
|
| 829 |
+
subsampled_indices.extend(sampled.tolist())
|
| 830 |
+
|
| 831 |
+
# Sort indices to maintain order
|
| 832 |
+
subsampled_indices.sort()
|
| 833 |
+
|
| 834 |
+
# Create subsampled dataset for this user
|
| 835 |
+
subsampled_preds = full_preds[:, subsampled_indices, :]
|
| 836 |
+
subsampled_labels = full_labels[subsampled_indices]
|
| 837 |
+
image_filenames = [full_image_filenames[idx] for idx in subsampled_indices]
|
| 838 |
+
|
| 839 |
+
# Create Dataset object with subsampled data
|
| 840 |
+
dataset = Dataset.__new__(Dataset)
|
| 841 |
+
dataset.preds = subsampled_preds
|
| 842 |
+
dataset.labels = subsampled_labels
|
| 843 |
+
dataset.device = device
|
| 844 |
+
|
| 845 |
+
# Create oracle and CODA selector for this user
|
| 846 |
+
oracle = Oracle(dataset, loss_fn=loss_fn)
|
| 847 |
coda_selector = CODA(dataset)
|
| 848 |
|
| 849 |
# Reset all displays
|
images.txt
CHANGED
|
@@ -33,7 +33,6 @@
|
|
| 33 |
8eb30b2a-21bc-11ea-a13a-137349068a90.jpg
|
| 34 |
99005e3e-21bc-11ea-a13a-137349068a90.jpg
|
| 35 |
86e3b2fa-21bc-11ea-a13a-137349068a90.jpg
|
| 36 |
-
97c99b7a-21bc-11ea-a13a-137349068a90.jpg
|
| 37 |
8f988d76-21bc-11ea-a13a-137349068a90.jpg
|
| 38 |
9593f3a0-21bc-11ea-a13a-137349068a90.jpg
|
| 39 |
988d1cbc-21bc-11ea-a13a-137349068a90.jpg
|
|
@@ -188,7 +187,6 @@
|
|
| 188 |
98536d82-21bc-11ea-a13a-137349068a90.jpg
|
| 189 |
8f4dd7f4-21bc-11ea-a13a-137349068a90.jpg
|
| 190 |
8f88c6e8-21bc-11ea-a13a-137349068a90.jpg
|
| 191 |
-
95ddfafe-21bc-11ea-a13a-137349068a90.jpg
|
| 192 |
8aaabb04-21bc-11ea-a13a-137349068a90.jpg
|
| 193 |
8768fa28-21bc-11ea-a13a-137349068a90.jpg
|
| 194 |
9505e7fe-21bc-11ea-a13a-137349068a90.jpg
|
|
|
|
| 33 |
8eb30b2a-21bc-11ea-a13a-137349068a90.jpg
|
| 34 |
99005e3e-21bc-11ea-a13a-137349068a90.jpg
|
| 35 |
86e3b2fa-21bc-11ea-a13a-137349068a90.jpg
|
|
|
|
| 36 |
8f988d76-21bc-11ea-a13a-137349068a90.jpg
|
| 37 |
9593f3a0-21bc-11ea-a13a-137349068a90.jpg
|
| 38 |
988d1cbc-21bc-11ea-a13a-137349068a90.jpg
|
|
|
|
| 187 |
98536d82-21bc-11ea-a13a-137349068a90.jpg
|
| 188 |
8f4dd7f4-21bc-11ea-a13a-137349068a90.jpg
|
| 189 |
8f88c6e8-21bc-11ea-a13a-137349068a90.jpg
|
|
|
|
| 190 |
8aaabb04-21bc-11ea-a13a-137349068a90.jpg
|
| 191 |
8768fa28-21bc-11ea-a13a-137349068a90.jpg
|
| 192 |
9505e7fe-21bc-11ea-a13a-137349068a90.jpg
|
iwildcam_demo.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b44bb58d31c2f26a17a754861ed39d30f38afd0a6a263c358b5dcfafbe287b21
|
| 3 |
+
size 77715
|
iwildcam_demo_annotations.json
CHANGED
|
@@ -404,17 +404,6 @@
|
|
| 404 |
"file_name": "86e3b2fa-21bc-11ea-a13a-137349068a90.jpg",
|
| 405 |
"seq_frame_num": 4
|
| 406 |
},
|
| 407 |
-
{
|
| 408 |
-
"seq_num_frames": 10,
|
| 409 |
-
"location": 218,
|
| 410 |
-
"datetime": "2013-05-11 20:29:09.000",
|
| 411 |
-
"id": "97c99b7a-21bc-11ea-a13a-137349068a90",
|
| 412 |
-
"seq_id": "3019d1ce-7d42-11eb-8fb5-0242ac1c0002",
|
| 413 |
-
"width": 1920,
|
| 414 |
-
"height": 1080,
|
| 415 |
-
"file_name": "97c99b7a-21bc-11ea-a13a-137349068a90.jpg",
|
| 416 |
-
"seq_frame_num": 8
|
| 417 |
-
},
|
| 418 |
{
|
| 419 |
"seq_num_frames": 10,
|
| 420 |
"location": 408,
|
|
@@ -2166,17 +2155,6 @@
|
|
| 2166 |
"file_name": "8f88c6e8-21bc-11ea-a13a-137349068a90.jpg",
|
| 2167 |
"seq_frame_num": 2
|
| 2168 |
},
|
| 2169 |
-
{
|
| 2170 |
-
"seq_num_frames": 10,
|
| 2171 |
-
"location": 151,
|
| 2172 |
-
"datetime": "2013-02-03 01:47:54.000",
|
| 2173 |
-
"id": "95ddfafe-21bc-11ea-a13a-137349068a90",
|
| 2174 |
-
"seq_id": "302607f0-7d42-11eb-8fb5-0242ac1c0002",
|
| 2175 |
-
"width": 1920,
|
| 2176 |
-
"height": 1080,
|
| 2177 |
-
"file_name": "95ddfafe-21bc-11ea-a13a-137349068a90.jpg",
|
| 2178 |
-
"seq_frame_num": 3
|
| 2179 |
-
},
|
| 2180 |
{
|
| 2181 |
"seq_num_frames": 10,
|
| 2182 |
"location": 151,
|
|
@@ -16359,11 +16337,6 @@
|
|
| 16359 |
"image_id": "98c3c686-21bc-11ea-a13a-137349068a90",
|
| 16360 |
"category_id": 10
|
| 16361 |
},
|
| 16362 |
-
{
|
| 16363 |
-
"id": "9a6547bc-21bc-11ea-a13a-137349068a90",
|
| 16364 |
-
"image_id": "97c99b7a-21bc-11ea-a13a-137349068a90",
|
| 16365 |
-
"category_id": 10
|
| 16366 |
-
},
|
| 16367 |
{
|
| 16368 |
"id": "a1d760ac-21bc-11ea-a13a-137349068a90",
|
| 16369 |
"image_id": "873ffbfa-21bc-11ea-a13a-137349068a90",
|
|
@@ -17894,11 +17867,6 @@
|
|
| 17894 |
"image_id": "8f88c6e8-21bc-11ea-a13a-137349068a90",
|
| 17895 |
"category_id": 101
|
| 17896 |
},
|
| 17897 |
-
{
|
| 17898 |
-
"id": "9cafbc0a-21bc-11ea-a13a-137349068a90",
|
| 17899 |
-
"image_id": "95ddfafe-21bc-11ea-a13a-137349068a90",
|
| 17900 |
-
"category_id": 101
|
| 17901 |
-
},
|
| 17902 |
{
|
| 17903 |
"id": "a36c4fd6-21bc-11ea-a13a-137349068a90",
|
| 17904 |
"image_id": "8aaabb04-21bc-11ea-a13a-137349068a90",
|
|
|
|
| 404 |
"file_name": "86e3b2fa-21bc-11ea-a13a-137349068a90.jpg",
|
| 405 |
"seq_frame_num": 4
|
| 406 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
{
|
| 408 |
"seq_num_frames": 10,
|
| 409 |
"location": 408,
|
|
|
|
| 2155 |
"file_name": "8f88c6e8-21bc-11ea-a13a-137349068a90.jpg",
|
| 2156 |
"seq_frame_num": 2
|
| 2157 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2158 |
{
|
| 2159 |
"seq_num_frames": 10,
|
| 2160 |
"location": 151,
|
|
|
|
| 16337 |
"image_id": "98c3c686-21bc-11ea-a13a-137349068a90",
|
| 16338 |
"category_id": 10
|
| 16339 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16340 |
{
|
| 16341 |
"id": "a1d760ac-21bc-11ea-a13a-137349068a90",
|
| 16342 |
"image_id": "873ffbfa-21bc-11ea-a13a-137349068a90",
|
|
|
|
| 17867 |
"image_id": "8f88c6e8-21bc-11ea-a13a-137349068a90",
|
| 17868 |
"category_id": 101
|
| 17869 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17870 |
{
|
| 17871 |
"id": "a36c4fd6-21bc-11ea-a13a-137349068a90",
|
| 17872 |
"image_id": "8aaabb04-21bc-11ea-a13a-137349068a90",
|
iwildcam_demo_images/95ddfafe-21bc-11ea-a13a-137349068a90.jpg
DELETED
Git LFS Details
|
iwildcam_demo_images/97c99b7a-21bc-11ea-a13a-137349068a90.jpg
DELETED
Git LFS Details
|
iwildcam_demo_labels.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c6c3f2f49fe5b1dce7f4c9e45380fa68fe4040f004d9bcf7ff73bb3323d096f7
|
| 3 |
+
size 11780
|