Jackmin108 committed on
Commit fbace78 · 1 Parent(s): 036d2ca

modify code to use tt moe

config.json CHANGED
@@ -2,6 +2,11 @@
   "architectures": [
     "Qwen3MoeForCausalLM"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_qwen3_moe.Qwen3MoeConfig",
+    "AutoModelForCausalLM": "modeling_qwen3_moe.Qwen3MoeForCausalLM",
+    "AutoModel": "modeling_qwen3_moe.Qwen3MoeModel"
+  },
   "attention_bias": false,
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
configuration_qwen3_moe.py CHANGED
@@ -14,9 +14,9 @@
 # limitations under the License.
 """Qwen3MoE model configuration"""
 
-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
-from ...utils import logging
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
 
 
 logger = logging.get_logger(__name__)
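
Switching from package-relative imports (from ...utils import ...) to absolute ones (from transformers.utils import ...) is what lets this file live in the model repo: remote code is loaded as a standalone module, where relative imports into the transformers package would raise an ImportError. A quick sanity check, assuming the file is on the Python path and that the config falls back to library defaults when constructed without arguments:

    # Illustrative only: import the standalone configuration module directly.
    # With the old relative imports this failed outside the transformers source tree.
    from configuration_qwen3_moe import Qwen3MoeConfig

    config = Qwen3MoeConfig()  # library defaults; real values come from config.json
    print(config.num_experts, config.num_experts_per_tok)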
modeling_qwen3_moe.py CHANGED
@@ -39,6 +39,8 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tu
 from transformers.utils.generic import OutputRecorder, check_model_inputs
 from .configuration_qwen3_moe import Qwen3MoeConfig
 
+from torchtitan.models.moe import MoE, MoEArgs
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -284,10 +286,22 @@ class Qwen3MoeDecoderLayer(GradientCheckpointingLayer):
 
         self.self_attn = Qwen3MoeAttention(config, layer_idx)
 
+        moe_args = MoEArgs(
+            num_experts=config.num_experts,
+            num_shared_experts=0,
+            score_func="softmax",
+            route_norm=config.norm_topk_prob,
+            route_scale=1.0,
+            score_before_experts=False,
+            top_k=config.num_experts_per_tok,
+            use_grouped_mm=True,
+            load_balance_coeff=None,
+        )
+
         if (layer_idx not in config.mlp_only_layers) and (
             config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
         ):
-            self.mlp = Qwen3MoeSparseMoeBlock(config)
+            self.mlp = MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
         else:
             self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
 
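
The gating condition is unchanged from the upstream Qwen3-MoE decoder layer; only the body of the sparse branch now builds torchtitan's MoE (with norm_topk_prob mapped to route_norm and num_experts_per_tok to top_k) instead of Qwen3MoeSparseMoeBlock. A small illustration of which layers end up sparse, using the same condition as the diff (the numbers below are placeholders; the real values come from config.json):

    # Illustrative only: replicate the layer-selection condition from the diff.
    # Hypothetical config values; the actual ones are defined in config.json.
    num_hidden_layers = 8
    mlp_only_layers = []        # config.mlp_only_layers
    num_experts = 128           # config.num_experts
    decoder_sparse_step = 1     # config.decoder_sparse_step

    for layer_idx in range(num_hidden_layers):
        use_moe = (layer_idx not in mlp_only_layers) and (
            num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0
        )
        # With decoder_sparse_step == 1 and no mlp_only_layers, every layer is sparse.
        print(f"layer {layer_idx}: {'torchtitan MoE' if use_moe else 'dense Qwen3MoeMLP'}")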