Commit fbace78
Parent(s): 036d2ca

modify code to use tt moe

Files changed:
- config.json (+5, -0)
- configuration_qwen3_moe.py (+3, -3)
- modeling_qwen3_moe.py (+15, -1)
config.json CHANGED
@@ -2,6 +2,11 @@
   "architectures": [
     "Qwen3MoeForCausalLM"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_qwen3_moe.Qwen3MoeConfig",
+    "AutoModelForCausalLM": "modeling_qwen3_moe.Qwen3MoeForCausalLM",
+    "AutoModel": "modeling_qwen3_moe.Qwen3MoeModel"
+  },
   "attention_bias": false,
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
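The new auto_map entries point the Auto classes at the custom configuration and modeling files bundled with the checkpoint, so the model loads as remote code. A minimal loading sketch (the repo id below is a placeholder, and trust_remote_code=True is required for auto_map to be honored):

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; substitute the Hub repo that contains this commit.
repo_id = "your-org/qwen3-moe-tt"

# trust_remote_code=True lets the Auto classes follow the auto_map entries and
# import the bundled configuration_qwen3_moe.py / modeling_qwen3_moe.py.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)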
configuration_qwen3_moe.py CHANGED
@@ -14,9 +14,9 @@
 # limitations under the License.
 """Qwen3MoE model configuration"""
 
-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
-from ...utils import logging
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
 
 
 logger = logging.get_logger(__name__)
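The imports switch from package-relative paths (from ...utils import logging) to absolute transformers.* paths, presumably because a file fetched through auto_map / trust_remote_code is imported as a standalone module outside the transformers package, where relative imports would not resolve. A rough check, assuming a local copy of the edited file sits on the import path:

# Import the edited file as a standalone module (no transformers package context).
from configuration_qwen3_moe import Qwen3MoeConfig

cfg = Qwen3MoeConfig()  # library defaults
print(cfg.num_experts, cfg.num_experts_per_tok, cfg.moe_intermediate_size)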
modeling_qwen3_moe.py CHANGED
@@ -39,6 +39,8 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
 from transformers.utils.generic import OutputRecorder, check_model_inputs
 from .configuration_qwen3_moe import Qwen3MoeConfig
 
+from torchtitan.models.moe import MoE, MoEArgs
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -284,10 +286,22 @@ class Qwen3MoeDecoderLayer(GradientCheckpointingLayer):
 
         self.self_attn = Qwen3MoeAttention(config, layer_idx)
 
+        moe_args = MoEArgs(
+            num_experts=config.num_experts,
+            num_shared_experts=0,
+            score_func="softmax",
+            route_norm=config.norm_topk_prob,
+            route_scale=1.0,
+            score_before_experts=False,
+            top_k=config.num_experts_per_tok,
+            use_grouped_mm=True,
+            load_balance_coeff=None,
+        )
+
         if (layer_idx not in config.mlp_only_layers) and (
             config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
         ):
-            self.mlp = Qwen3MoeSparseMoeBlock(config)
+            self.mlp = MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
         else:
             self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
 
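The decoder layer now builds its sparse layers from torchtitan's MoE instead of the stock Qwen3MoeSparseMoeBlock, with MoEArgs filled in from the Hugging Face config: softmax routing, top_k = num_experts_per_tok, no shared experts, grouped-mm expert execution, and no auxiliary load-balancing loss. A rough standalone smoke test of the swapped block, assuming torchtitan's MoE.forward accepts and returns a [batch, seq, hidden] tensor like the HF block it replaces:

import torch
from torchtitan.models.moe import MoE, MoEArgs

# Toy sizes for a quick shape check; the real values come from config.json.
hidden_size, moe_intermediate_size = 256, 512

moe_args = MoEArgs(
    num_experts=8,              # config.num_experts
    num_shared_experts=0,
    score_func="softmax",
    route_norm=True,            # config.norm_topk_prob
    route_scale=1.0,
    score_before_experts=False,
    top_k=2,                    # config.num_experts_per_tok
    use_grouped_mm=True,        # as in the commit; may need GPU/bf16 support
    load_balance_coeff=None,
)
mlp = MoE(moe_args, dim=hidden_size, hidden_dim=moe_intermediate_size)

x = torch.randn(2, 16, hidden_size)
y = mlp(x)                      # assumed to preserve the input shape
assert y.shape == x.shape       # weights still need torchtitan init or a loaded state dict for real use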
|