Define CustomLlamaConfig
modeling_llama.py  +5 -1
@@ -58,6 +58,10 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LlamaConfig"
 
 
+class CustomLlamaConfig(LlamaConfig):
+    model_type = "custom_llama"
+
+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

@@ -929,7 +933,7 @@ LLAMA_START_DOCSTRING = r"""
     LLAMA_START_DOCSTRING,
 )
 class LlamaPreTrainedModel(PreTrainedModel):
-    config_class = LlamaConfig
+    config_class = CustomLlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
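For context, a config subclass that declares its own model_type is usually also registered with the Auto classes so that configs tagged "custom_llama" resolve to it. The snippet below is a minimal usage sketch, not part of this commit: the import path modeling_llama (the patched file above) and the small config values are placeholder assumptions.

# Minimal usage sketch (not part of this commit). The module path `modeling_llama`
# and the tiny config values below are placeholder assumptions for illustration.
from transformers import AutoConfig, AutoModelForCausalLM

from modeling_llama import CustomLlamaConfig, LlamaForCausalLM  # the patched file above

# Make the Auto classes aware of the new model_type so that configs tagged
# "custom_llama" resolve to the custom config and the Llama causal-LM class.
AutoConfig.register("custom_llama", CustomLlamaConfig)
AutoModelForCausalLM.register(CustomLlamaConfig, LlamaForCausalLM)

# Build a small model from the custom config; since LlamaPreTrainedModel now
# declares config_class = CustomLlamaConfig, the registration above is consistent.
config = CustomLlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = AutoModelForCausalLM.from_config(config)
print(model.config.model_type)  # -> "custom_llama"

Saving such a model with save_pretrained writes "model_type": "custom_llama" into config.json, which is why a registration like the one above (or trust_remote_code with an auto_map) is needed when loading the checkpoint back through the Auto classes.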