Fixed example accounting glitch
main.py CHANGED

@@ -33,7 +33,7 @@ def parse_arguments() -> dict:
     """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
     parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
-    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=
+    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=500_000_000, help="Save a checkpoint of the model every n tokens processed.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")

@@ -96,7 +96,9 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
             optimiser.step()

             wandb.log(dict(train_loss=loss, elapsed=time.time() - start_time), step=examples_seen)
-
+
+            # Number of tokens processed is batch_size * sequence_length.
+            examples_seen += batch.numel()

             # Save a checkpoint of the model.
             if examples_seen % config.checkpoint_every_n_tokens == 0:

@@ -168,11 +170,10 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
     train_dataset = ds["train"]
     test_dataset = ds["test"]

-    # TODO: tokenise the data before sending it to the model.
     tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
     tokeniser.add_special_tokens({"pad_token": "<PAD>"})

-    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser), batched=True).with_format("torch")
+    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser, 1, config.max_positional_embeddings), batched=True).with_format("torch")
    test_dataset = test_dataset.map(tokenise, batched=True).with_format("torch")

     train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)

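For context on the new accounting: each training batch is a (batch_size, sequence_length) tensor of token IDs, so batch.numel() is exactly the number of tokens processed that step, and examples_seen now advances in tokens rather than batches. A minimal sketch of the arithmetic, assuming the default --batch_size of 40 and an illustrative sequence length of 512 (the sequence length is an assumption for the example, not read from the config):

import torch

# Illustrative shapes only: 40 matches the --batch_size default; the
# sequence length of 512 is assumed, not taken from the repository.
batch = torch.randint(0, 50_000, (40, 512))  # fake batch of token IDs

# Number of tokens processed is batch_size * sequence_length.
tokens_this_step = batch.numel()
assert tokens_this_step == 40 * 512  # 20_480 tokens per optimiser step

At that rate the default checkpoint_every_n_tokens of 500_000_000 corresponds to roughly 24,400 optimiser steps between checkpoints, assuming those shapes.
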
utils.py CHANGED

@@ -42,7 +42,7 @@ class OsSoluConfig:
        self.self_attention_type = args["self_attention_type"]
        self.vocab_size = args["vocab_size"]

-def tokenise(batch, tokeniser, num_gpus: int
+def tokenise(batch, tokeniser, num_gpus: int, context_length: int):
    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4. Code from Neel.

    Args: