LucaGroup committed (verified)
Commit b5831e8 · 1 Parent(s): dcb0e36

Update weights and modeling code to latest version
__init__.py ADDED
@@ -0,0 +1,56 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2025, Hey.
+ @author: Hey
+ @email: [email protected]
+ @tel: 137****6540
+ @datetime: 2025/12/30 11:32
+ @project: lucaone
+ @file: __init__
+ @desc: __init__
+ '''
+
+ from .configuration_lucaone import LucaGPLMConfig
+ from .tokenization_lucaone import LucaGPLMTokenizer, LucaGPLMTokenizerFast
+ from .modeling_lucaone import (
+     LucaGPLMModel,
+     LucaGPLMPreTrainedModel,
+     LucaGPLMForMaskedLM,
+     LucaGPLMForSequenceClassification,
+     LucaGPLMForTokenClassification
+ )
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForMaskedLM,
+     AutoModelForSequenceClassification,
+     AutoModelForTokenClassification
+ )
+
+ __all__ = [
+     "LucaGPLMConfig",
+     "LucaGPLMModel",
+     "LucaGPLMPreTrainedModel",
+     "LucaGPLMTokenizer",
+     "LucaGPLMTokenizerFast",
+     "LucaGPLMForMaskedLM",
+     "LucaGPLMForSequenceClassification",
+     "LucaGPLMForTokenClassification"
+ ]
+
+
+ # 1. Register the config class (required)
+ AutoConfig.register("lucaone", LucaGPLMConfig)
+
+ # 2. Register the base model (for AutoModel.from_pretrained)
+ AutoModel.register(LucaGPLMConfig, LucaGPLMModel)
+
+ # 3. Register the sequence classification model (for AutoModelForSequenceClassification)
+ AutoModelForSequenceClassification.register(LucaGPLMConfig, LucaGPLMForSequenceClassification)
+
+ # 4. Register the token classification model (for AutoModelForTokenClassification)
+ AutoModelForTokenClassification.register(LucaGPLMConfig, LucaGPLMForTokenClassification)
+
+ # 5. Register the masked language model (for AutoModelForMaskedLM)
+ AutoModelForMaskedLM.register(LucaGPLMConfig, LucaGPLMForMaskedLM)
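A brief sketch of what these registrations enable once the package is imported; treating the repo files as an importable package named `lucaone` is an assumption for illustration, not something this commit sets up:

```python
# Minimal sketch: after importing the package (which runs the register() calls above),
# the Auto* factories map model_type "lucaone" / LucaGPLMConfig to the custom classes.
# The package name `lucaone` is an illustrative assumption.
import lucaone  # noqa: F401

from transformers import AutoConfig, AutoModel

config = AutoConfig.for_model("lucaone")  # builds a LucaGPLMConfig with default values
model = AutoModel.from_config(config)     # builds a randomly initialized LucaGPLMModel
print(type(config).__name__, type(model).__name__)
```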
config.json CHANGED
@@ -1,30 +1,54 @@
  {
  "alphabet": "gene_prot",
  "architectures": [
- "LucaGPLMModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout_prob": 0.0,
  "embed_scale": 1.0,
  "ffn_dim": 10240,
  "hidden_dropout_prob": 0.0,
  "hidden_size": 2560,
  "ignore_index": -100,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
- "max_position_embeddings": 1280,
- "model_type": "lucagplm",
  "no_position_embeddings": true,
  "no_token_type_embeddings": false,
  "num_attention_heads": 40,
  "num_hidden_layers": 20,
  "pad_token_id": 0,
- "pooling_type": "avg",
- "pooling_units": null,
  "token_dropout": false,
  "torch_dtype": "float32",
- "transformers_version": "4.45.2",
  "type_vocab_size": 2,
  "use_embed_layer_norm": false,
  "use_last_layer_norm": true,
  "vocab_size": 39

  {
  "alphabet": "gene_prot",
  "architectures": [
+ "LucaGPLMForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
+ "auto_map": {
+ "AutoConfig": "configuration_lucaone.LucaGPLMConfig",
+ "AutoModel": "modeling_lucaone.LucaGPLMModel",
+ "AutoModelForMaskedLM": "modeling_lucaone.LucaGPLMForMaskedLM",
+ "AutoModelForSequenceClassification": "modeling_lucaone.LucaGPLMForSequenceClassification",
+ "AutoModelForTokenClassification": "modeling_lucaone.LucaGPLMForTokenClassification",
+ "AutoTokenizer": [
+ "tokenization_lucaone.LucaGPLMTokenizer",
+ null
+ ]
+ },
+ "bos_token_id": 2,
  "classifier_dropout_prob": 0.0,
+ "classifier_loss_reduction": "mean",
+ "classifier_loss_type": "cross_entropy",
+ "classifier_num_labels": -1,
+ "classifier_pooling_type": "value_attention",
+ "classifier_pos_weight": 1.0,
+ "classifier_weight": null,
  "embed_scale": 1.0,
+ "eos_token_id": 3,
  "ffn_dim": 10240,
+ "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 2560,
  "ignore_index": -100,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
+ "mask_token_id": 4,
+ "max_position_embeddings": 4096,
+ "model_type": "lucaone",
  "no_position_embeddings": true,
  "no_token_type_embeddings": false,
  "num_attention_heads": 40,
  "num_hidden_layers": 20,
  "pad_token_id": 0,
+ "sep_token_id": 3,
+ "task_level": "seq_level",
+ "task_type": "embedding",
+ "tie_word_embeddings": false,
  "token_dropout": false,
  "torch_dtype": "float32",
+ "transformers_version": "4.41.2",
  "type_vocab_size": 2,
+ "unk_token_id": 1,
  "use_embed_layer_norm": false,
  "use_last_layer_norm": true,
  "vocab_size": 39
configuration_lucaone.py ADDED
@@ -0,0 +1,102 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ '''
+ @license: (C) Copyright 2025, Hey.
+ @author: Hey
+ @email: [email protected]
+ @tel: 137****6540
+ @datetime: 2025/12/30 11:34
+ @project: lucaone
+ @file: configuration_lucaone
+ @desc: configuration_lucaone
+ '''
+
+ from typing import Literal
+ from transformers import PretrainedConfig
+
+ class LucaGPLMConfig(PretrainedConfig):
+     model_type = "lucaone"
+
+     def __init__(
+         self,
+         vocab_size: int = 39,
+         pad_token_id: int = 0,
+         unk_token_id: int = 1,
+         bos_token_id: int = 2,
+         eos_token_id: int = 3,
+         sep_token_id: int = 3,
+         mask_token_id: int = 4,
+         hidden_act: str = "gelu",
+         max_position_embeddings: int = 4096,
+         type_vocab_size: int = 2,
+         num_hidden_layers: int = 20,
+         num_attention_heads: int = 40,
+         hidden_size: int = 2560,
+         ffn_dim: int = 10240,
+         no_position_embeddings: bool = True,
+         no_token_type_embeddings: bool = False,
+         alphabet: str = "gene_prot",
+         token_dropout: bool = False,
+         attention_probs_dropout_prob: float = 0.0,
+         hidden_dropout_prob: float = 0.0,
+         use_embed_layer_norm: bool = False,
+         use_last_layer_norm: bool = True,
+         embed_scale: float = 1.0,
+         ignore_index: int = -100,
+         layer_norm_eps: float = 1e-12,
+         initializer_range: float = 0.02,
+         task_level: Literal["seq_level", "token_level"] = "seq_level",
+         task_type: Literal["embedding", "mlm", "multi_class", "binary_class", "regression", "multi_label"] = "embedding",
+         classifier_num_labels: int = -1,
+         classifier_dropout_prob: float = 0.1,
+         classifier_pooling_type: Literal["cls", "value_attention", "context_attention", "mean"] = "value_attention",
+         classifier_loss_type: Literal["binary_cross_entropy", "cross_entropy", "mse", "mae"] = "cross_entropy",
+         classifier_loss_reduction: Literal["mean", "sum", "none"] = "mean",
+         classifier_pos_weight: float = 1.0,
+         classifier_weight: list = None,
+         tie_word_embeddings: bool = True,
+         **kwargs
+     ):
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             pad_token_id=pad_token_id,
+             **kwargs
+         )
+
+         self.alphabet = alphabet
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.no_token_type_embeddings = no_token_type_embeddings
+         self.no_position_embeddings = no_position_embeddings
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.ffn_dim = ffn_dim
+         self.token_dropout = token_dropout
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.classifier_dropout_prob = classifier_dropout_prob
+         self.ignore_index = ignore_index
+         self.use_embed_layer_norm = use_embed_layer_norm
+         self.use_last_layer_norm = use_last_layer_norm
+         self.embed_scale = embed_scale
+         self.layer_norm_eps = layer_norm_eps
+         self.initializer_range = initializer_range
+         self.unk_token_id = unk_token_id
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.sep_token_id = sep_token_id
+         self.mask_token_id = mask_token_id
+         self.hidden_act = hidden_act
+         self.classifier_num_labels = classifier_num_labels
+         self.classifier_pooling_type = classifier_pooling_type
+         self.task_level = task_level
+         self.task_type = task_type
+         self.classifier_loss_type = classifier_loss_type
+         self.classifier_loss_reduction = classifier_loss_reduction
+         self.classifier_pos_weight = classifier_pos_weight
+         self.classifier_weight = classifier_weight
+
+
+ __all__ = ["LucaGPLMConfig"]
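Together with the `auto_map` added to config.json above, this config class lets the checkpoint be resolved through the standard Auto* APIs with remote code enabled. A minimal sketch; the repo id `LucaGroup/LucaOne` is a placeholder, not taken from this commit, so substitute the actual repository name:

```python
# Minimal sketch: resolve the custom LucaOne classes via the "auto_map" in config.json.
# "LucaGroup/LucaOne" is a placeholder repo id; trust_remote_code=True is required
# because the config/model/tokenizer classes live in the repo, not in transformers.
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

repo_id = "LucaGroup/LucaOne"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # -> LucaGPLMConfig
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)     # -> LucaGPLMTokenizer (slow)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)  # -> LucaGPLMForMaskedLM
print(config.model_type, config.max_position_embeddings)                       # lucaone 4096
```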
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:abba84d3e29bcafdd171cdbc587402c93694831bc41f276a7007ce23f8ca7aca
- size 4930878944

  version https://git-lfs.github.com/spec/v1
+ oid sha256:f4f62f17030bbf8353cfabd8d8291688b9b7a6632d55710178651b4ea8a529c6
+ size 4930881120
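The weight shards were replaced in this commit, so each Git LFS pointer now records a new sha256 digest and size. A minimal sketch for checking a downloaded shard against its pointer (the local filename is an assumption; the digest is the one recorded above):

```python
# Minimal sketch: verify a downloaded shard against the oid in its LFS pointer.
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "f4f62f17030bbf8353cfabd8d8291688b9b7a6632d55710178651b4ea8a529c6"
print(sha256_of("model-00001-of-00002.safetensors") == expected)
```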
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9882d0048f5392ccb8752df13a963e49a7328de7cd40c0adcb54c56fc1fb39b8
- size 1363720440

  version https://git-lfs.github.com/spec/v1
+ oid sha256:1b2b9ec5fa2159a09c671e5d0dde6c1ea9537219f5ceab51a53e189cb79c5fb5
+ size 1390366196
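The sharded index diff that follows shows the checkpoint keys being renamed with a `lucaone.` prefix and `lm_head.*` tensors being added, consistent with the new `LucaGPLMForMaskedLM` default architecture. A minimal sketch for inspecting the updated index (the local file path is an illustrative assumption):

```python
# Minimal sketch: group the tensors in model.safetensors.index.json by shard and
# list the lm_head entries added in this commit.
import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

per_shard = Counter(index["weight_map"].values())
print(index["metadata"]["total_size"])  # 6321205916 after this commit
print(dict(per_shard))                  # number of tensors per shard file
print([k for k in index["weight_map"] if k.startswith("lm_head.")])
```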
model.safetensors.index.json CHANGED
@@ -1,351 +1,357 @@
1
  {
2
  "metadata": {
3
- "total_size": 6294561280
4
  },
5
  "weight_map": {
6
- "embeddings.embed_tokens.weight": "model-00001-of-00002.safetensors",
7
- "embeddings.embed_type.weight": "model-00001-of-00002.safetensors",
8
- "encoder.last_layer_norm.bias": "model-00002-of-00002.safetensors",
9
- "encoder.last_layer_norm.weight": "model-00002-of-00002.safetensors",
10
- "encoder.layers.0.fc1.bias": "model-00001-of-00002.safetensors",
11
- "encoder.layers.0.fc1.weight": "model-00001-of-00002.safetensors",
12
- "encoder.layers.0.fc2.bias": "model-00001-of-00002.safetensors",
13
- "encoder.layers.0.fc2.weight": "model-00001-of-00002.safetensors",
14
- "encoder.layers.0.post_layer_norm.bias": "model-00001-of-00002.safetensors",
15
- "encoder.layers.0.post_layer_norm.weight": "model-00001-of-00002.safetensors",
16
- "encoder.layers.0.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
17
- "encoder.layers.0.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
18
- "encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
19
- "encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
20
- "encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
21
- "encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
22
- "encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
23
- "encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
24
- "encoder.layers.0.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
25
- "encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
26
- "encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
27
- "encoder.layers.1.fc1.bias": "model-00001-of-00002.safetensors",
28
- "encoder.layers.1.fc1.weight": "model-00001-of-00002.safetensors",
29
- "encoder.layers.1.fc2.bias": "model-00001-of-00002.safetensors",
30
- "encoder.layers.1.fc2.weight": "model-00001-of-00002.safetensors",
31
- "encoder.layers.1.post_layer_norm.bias": "model-00001-of-00002.safetensors",
32
- "encoder.layers.1.post_layer_norm.weight": "model-00001-of-00002.safetensors",
33
- "encoder.layers.1.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
34
- "encoder.layers.1.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
35
- "encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
36
- "encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
37
- "encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
38
- "encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
39
- "encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
40
- "encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
41
- "encoder.layers.1.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
42
- "encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
43
- "encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
44
- "encoder.layers.10.fc1.bias": "model-00001-of-00002.safetensors",
45
- "encoder.layers.10.fc1.weight": "model-00001-of-00002.safetensors",
46
- "encoder.layers.10.fc2.bias": "model-00001-of-00002.safetensors",
47
- "encoder.layers.10.fc2.weight": "model-00001-of-00002.safetensors",
48
- "encoder.layers.10.post_layer_norm.bias": "model-00001-of-00002.safetensors",
49
- "encoder.layers.10.post_layer_norm.weight": "model-00001-of-00002.safetensors",
50
- "encoder.layers.10.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
51
- "encoder.layers.10.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
52
- "encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
53
- "encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
54
- "encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
55
- "encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
56
- "encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
57
- "encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
58
- "encoder.layers.10.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
59
- "encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
60
- "encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
61
- "encoder.layers.11.fc1.bias": "model-00001-of-00002.safetensors",
62
- "encoder.layers.11.fc1.weight": "model-00001-of-00002.safetensors",
63
- "encoder.layers.11.fc2.bias": "model-00001-of-00002.safetensors",
64
- "encoder.layers.11.fc2.weight": "model-00001-of-00002.safetensors",
65
- "encoder.layers.11.post_layer_norm.bias": "model-00001-of-00002.safetensors",
66
- "encoder.layers.11.post_layer_norm.weight": "model-00001-of-00002.safetensors",
67
- "encoder.layers.11.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
68
- "encoder.layers.11.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
69
- "encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
70
- "encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
71
- "encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
72
- "encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
73
- "encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
74
- "encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
75
- "encoder.layers.11.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
76
- "encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
77
- "encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
78
- "encoder.layers.12.fc1.bias": "model-00001-of-00002.safetensors",
79
- "encoder.layers.12.fc1.weight": "model-00001-of-00002.safetensors",
80
- "encoder.layers.12.fc2.bias": "model-00001-of-00002.safetensors",
81
- "encoder.layers.12.fc2.weight": "model-00001-of-00002.safetensors",
82
- "encoder.layers.12.post_layer_norm.bias": "model-00001-of-00002.safetensors",
83
- "encoder.layers.12.post_layer_norm.weight": "model-00001-of-00002.safetensors",
84
- "encoder.layers.12.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
85
- "encoder.layers.12.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
86
- "encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
87
- "encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
88
- "encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
89
- "encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
90
- "encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
91
- "encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
92
- "encoder.layers.12.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
93
- "encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
94
- "encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
95
- "encoder.layers.13.fc1.bias": "model-00001-of-00002.safetensors",
96
- "encoder.layers.13.fc1.weight": "model-00001-of-00002.safetensors",
97
- "encoder.layers.13.fc2.bias": "model-00001-of-00002.safetensors",
98
- "encoder.layers.13.fc2.weight": "model-00001-of-00002.safetensors",
99
- "encoder.layers.13.post_layer_norm.bias": "model-00001-of-00002.safetensors",
100
- "encoder.layers.13.post_layer_norm.weight": "model-00001-of-00002.safetensors",
101
- "encoder.layers.13.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
102
- "encoder.layers.13.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
103
- "encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
104
- "encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
105
- "encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
106
- "encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
107
- "encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
108
- "encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
109
- "encoder.layers.13.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
110
- "encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
111
- "encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
112
- "encoder.layers.14.fc1.bias": "model-00001-of-00002.safetensors",
113
- "encoder.layers.14.fc1.weight": "model-00001-of-00002.safetensors",
114
- "encoder.layers.14.fc2.bias": "model-00001-of-00002.safetensors",
115
- "encoder.layers.14.fc2.weight": "model-00001-of-00002.safetensors",
116
- "encoder.layers.14.post_layer_norm.bias": "model-00001-of-00002.safetensors",
117
- "encoder.layers.14.post_layer_norm.weight": "model-00001-of-00002.safetensors",
118
- "encoder.layers.14.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
119
- "encoder.layers.14.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
120
- "encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
121
- "encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
122
- "encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
123
- "encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
124
- "encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
125
- "encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
126
- "encoder.layers.14.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
127
- "encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
128
- "encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
129
- "encoder.layers.15.fc1.bias": "model-00001-of-00002.safetensors",
130
- "encoder.layers.15.fc1.weight": "model-00001-of-00002.safetensors",
131
- "encoder.layers.15.fc2.bias": "model-00002-of-00002.safetensors",
132
- "encoder.layers.15.fc2.weight": "model-00002-of-00002.safetensors",
133
- "encoder.layers.15.post_layer_norm.bias": "model-00001-of-00002.safetensors",
134
- "encoder.layers.15.post_layer_norm.weight": "model-00001-of-00002.safetensors",
135
- "encoder.layers.15.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
136
- "encoder.layers.15.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
137
- "encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
138
- "encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
139
- "encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
140
- "encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
141
- "encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
142
- "encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
143
- "encoder.layers.15.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
144
- "encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
145
- "encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
146
- "encoder.layers.16.fc1.bias": "model-00002-of-00002.safetensors",
147
- "encoder.layers.16.fc1.weight": "model-00002-of-00002.safetensors",
148
- "encoder.layers.16.fc2.bias": "model-00002-of-00002.safetensors",
149
- "encoder.layers.16.fc2.weight": "model-00002-of-00002.safetensors",
150
- "encoder.layers.16.post_layer_norm.bias": "model-00002-of-00002.safetensors",
151
- "encoder.layers.16.post_layer_norm.weight": "model-00002-of-00002.safetensors",
152
- "encoder.layers.16.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
153
- "encoder.layers.16.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
154
- "encoder.layers.16.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
155
- "encoder.layers.16.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
156
- "encoder.layers.16.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
157
- "encoder.layers.16.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
158
- "encoder.layers.16.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
159
- "encoder.layers.16.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
160
- "encoder.layers.16.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
161
- "encoder.layers.16.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
162
- "encoder.layers.16.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
163
- "encoder.layers.17.fc1.bias": "model-00002-of-00002.safetensors",
164
- "encoder.layers.17.fc1.weight": "model-00002-of-00002.safetensors",
165
- "encoder.layers.17.fc2.bias": "model-00002-of-00002.safetensors",
166
- "encoder.layers.17.fc2.weight": "model-00002-of-00002.safetensors",
167
- "encoder.layers.17.post_layer_norm.bias": "model-00002-of-00002.safetensors",
168
- "encoder.layers.17.post_layer_norm.weight": "model-00002-of-00002.safetensors",
169
- "encoder.layers.17.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
170
- "encoder.layers.17.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
171
- "encoder.layers.17.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
172
- "encoder.layers.17.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
173
- "encoder.layers.17.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
174
- "encoder.layers.17.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
175
- "encoder.layers.17.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
176
- "encoder.layers.17.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
177
- "encoder.layers.17.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
178
- "encoder.layers.17.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
179
- "encoder.layers.17.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
180
- "encoder.layers.18.fc1.bias": "model-00002-of-00002.safetensors",
181
- "encoder.layers.18.fc1.weight": "model-00002-of-00002.safetensors",
182
- "encoder.layers.18.fc2.bias": "model-00002-of-00002.safetensors",
183
- "encoder.layers.18.fc2.weight": "model-00002-of-00002.safetensors",
184
- "encoder.layers.18.post_layer_norm.bias": "model-00002-of-00002.safetensors",
185
- "encoder.layers.18.post_layer_norm.weight": "model-00002-of-00002.safetensors",
186
- "encoder.layers.18.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
187
- "encoder.layers.18.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
188
- "encoder.layers.18.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
189
- "encoder.layers.18.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
190
- "encoder.layers.18.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
191
- "encoder.layers.18.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
192
- "encoder.layers.18.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
193
- "encoder.layers.18.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
194
- "encoder.layers.18.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
195
- "encoder.layers.18.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
196
- "encoder.layers.18.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
197
- "encoder.layers.19.fc1.bias": "model-00002-of-00002.safetensors",
198
- "encoder.layers.19.fc1.weight": "model-00002-of-00002.safetensors",
199
- "encoder.layers.19.fc2.bias": "model-00002-of-00002.safetensors",
200
- "encoder.layers.19.fc2.weight": "model-00002-of-00002.safetensors",
201
- "encoder.layers.19.post_layer_norm.bias": "model-00002-of-00002.safetensors",
202
- "encoder.layers.19.post_layer_norm.weight": "model-00002-of-00002.safetensors",
203
- "encoder.layers.19.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
204
- "encoder.layers.19.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
205
- "encoder.layers.19.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
206
- "encoder.layers.19.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
207
- "encoder.layers.19.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
208
- "encoder.layers.19.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
209
- "encoder.layers.19.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
210
- "encoder.layers.19.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
211
- "encoder.layers.19.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
212
- "encoder.layers.19.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
213
- "encoder.layers.19.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
214
- "encoder.layers.2.fc1.bias": "model-00001-of-00002.safetensors",
215
- "encoder.layers.2.fc1.weight": "model-00001-of-00002.safetensors",
216
- "encoder.layers.2.fc2.bias": "model-00001-of-00002.safetensors",
217
- "encoder.layers.2.fc2.weight": "model-00001-of-00002.safetensors",
218
- "encoder.layers.2.post_layer_norm.bias": "model-00001-of-00002.safetensors",
219
- "encoder.layers.2.post_layer_norm.weight": "model-00001-of-00002.safetensors",
220
- "encoder.layers.2.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
221
- "encoder.layers.2.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
222
- "encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
223
- "encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
224
- "encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
225
- "encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
226
- "encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
227
- "encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
228
- "encoder.layers.2.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
229
- "encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
230
- "encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
231
- "encoder.layers.3.fc1.bias": "model-00001-of-00002.safetensors",
232
- "encoder.layers.3.fc1.weight": "model-00001-of-00002.safetensors",
233
- "encoder.layers.3.fc2.bias": "model-00001-of-00002.safetensors",
234
- "encoder.layers.3.fc2.weight": "model-00001-of-00002.safetensors",
235
- "encoder.layers.3.post_layer_norm.bias": "model-00001-of-00002.safetensors",
236
- "encoder.layers.3.post_layer_norm.weight": "model-00001-of-00002.safetensors",
237
- "encoder.layers.3.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
238
- "encoder.layers.3.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
239
- "encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
240
- "encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
241
- "encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
242
- "encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
243
- "encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
244
- "encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
245
- "encoder.layers.3.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
246
- "encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
247
- "encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
248
- "encoder.layers.4.fc1.bias": "model-00001-of-00002.safetensors",
249
- "encoder.layers.4.fc1.weight": "model-00001-of-00002.safetensors",
250
- "encoder.layers.4.fc2.bias": "model-00001-of-00002.safetensors",
251
- "encoder.layers.4.fc2.weight": "model-00001-of-00002.safetensors",
252
- "encoder.layers.4.post_layer_norm.bias": "model-00001-of-00002.safetensors",
253
- "encoder.layers.4.post_layer_norm.weight": "model-00001-of-00002.safetensors",
254
- "encoder.layers.4.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
255
- "encoder.layers.4.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
256
- "encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
257
- "encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
258
- "encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
259
- "encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
260
- "encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
261
- "encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
262
- "encoder.layers.4.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
263
- "encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
264
- "encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
265
- "encoder.layers.5.fc1.bias": "model-00001-of-00002.safetensors",
266
- "encoder.layers.5.fc1.weight": "model-00001-of-00002.safetensors",
267
- "encoder.layers.5.fc2.bias": "model-00001-of-00002.safetensors",
268
- "encoder.layers.5.fc2.weight": "model-00001-of-00002.safetensors",
269
- "encoder.layers.5.post_layer_norm.bias": "model-00001-of-00002.safetensors",
270
- "encoder.layers.5.post_layer_norm.weight": "model-00001-of-00002.safetensors",
271
- "encoder.layers.5.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
272
- "encoder.layers.5.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
273
- "encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
274
- "encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
275
- "encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
276
- "encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
277
- "encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
278
- "encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
279
- "encoder.layers.5.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
280
- "encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
281
- "encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
282
- "encoder.layers.6.fc1.bias": "model-00001-of-00002.safetensors",
283
- "encoder.layers.6.fc1.weight": "model-00001-of-00002.safetensors",
284
- "encoder.layers.6.fc2.bias": "model-00001-of-00002.safetensors",
285
- "encoder.layers.6.fc2.weight": "model-00001-of-00002.safetensors",
286
- "encoder.layers.6.post_layer_norm.bias": "model-00001-of-00002.safetensors",
287
- "encoder.layers.6.post_layer_norm.weight": "model-00001-of-00002.safetensors",
288
- "encoder.layers.6.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
289
- "encoder.layers.6.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
290
- "encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
291
- "encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
292
- "encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
293
- "encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
294
- "encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
295
- "encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
296
- "encoder.layers.6.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
297
- "encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
298
- "encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
299
- "encoder.layers.7.fc1.bias": "model-00001-of-00002.safetensors",
300
- "encoder.layers.7.fc1.weight": "model-00001-of-00002.safetensors",
301
- "encoder.layers.7.fc2.bias": "model-00001-of-00002.safetensors",
302
- "encoder.layers.7.fc2.weight": "model-00001-of-00002.safetensors",
303
- "encoder.layers.7.post_layer_norm.bias": "model-00001-of-00002.safetensors",
304
- "encoder.layers.7.post_layer_norm.weight": "model-00001-of-00002.safetensors",
305
- "encoder.layers.7.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
306
- "encoder.layers.7.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
307
- "encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
308
- "encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
309
- "encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
310
- "encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
311
- "encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
312
- "encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
313
- "encoder.layers.7.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
314
- "encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
315
- "encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
316
- "encoder.layers.8.fc1.bias": "model-00001-of-00002.safetensors",
317
- "encoder.layers.8.fc1.weight": "model-00001-of-00002.safetensors",
318
- "encoder.layers.8.fc2.bias": "model-00001-of-00002.safetensors",
319
- "encoder.layers.8.fc2.weight": "model-00001-of-00002.safetensors",
320
- "encoder.layers.8.post_layer_norm.bias": "model-00001-of-00002.safetensors",
321
- "encoder.layers.8.post_layer_norm.weight": "model-00001-of-00002.safetensors",
322
- "encoder.layers.8.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
323
- "encoder.layers.8.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
324
- "encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
325
- "encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
326
- "encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
327
- "encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
328
- "encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
329
- "encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
330
- "encoder.layers.8.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
331
- "encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
332
- "encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
333
- "encoder.layers.9.fc1.bias": "model-00001-of-00002.safetensors",
334
- "encoder.layers.9.fc1.weight": "model-00001-of-00002.safetensors",
335
- "encoder.layers.9.fc2.bias": "model-00001-of-00002.safetensors",
336
- "encoder.layers.9.fc2.weight": "model-00001-of-00002.safetensors",
337
- "encoder.layers.9.post_layer_norm.bias": "model-00001-of-00002.safetensors",
338
- "encoder.layers.9.post_layer_norm.weight": "model-00001-of-00002.safetensors",
339
- "encoder.layers.9.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
340
- "encoder.layers.9.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
341
- "encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
342
- "encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
343
- "encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
344
- "encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
345
- "encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
346
- "encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
347
- "encoder.layers.9.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
348
- "encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
349
- "encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors"
350
  }
351
  }
 
1
  {
2
  "metadata": {
3
+ "total_size": 6321205916
4
  },
5
  "weight_map": {
6
+ "lm_head.bias": "model-00002-of-00002.safetensors",
7
+ "lm_head.decoder.weight": "model-00002-of-00002.safetensors",
8
+ "lm_head.dense.bias": "model-00002-of-00002.safetensors",
9
+ "lm_head.dense.weight": "model-00002-of-00002.safetensors",
10
+ "lm_head.layer_norm.bias": "model-00002-of-00002.safetensors",
11
+ "lm_head.layer_norm.weight": "model-00002-of-00002.safetensors",
12
+ "lucaone.embeddings.embed_tokens.weight": "model-00001-of-00002.safetensors",
13
+ "lucaone.embeddings.embed_type.weight": "model-00001-of-00002.safetensors",
14
+ "lucaone.encoder.last_layer_norm.bias": "model-00002-of-00002.safetensors",
15
+ "lucaone.encoder.last_layer_norm.weight": "model-00002-of-00002.safetensors",
16
+ "lucaone.encoder.layers.0.fc1.bias": "model-00001-of-00002.safetensors",
17
+ "lucaone.encoder.layers.0.fc1.weight": "model-00001-of-00002.safetensors",
18
+ "lucaone.encoder.layers.0.fc2.bias": "model-00001-of-00002.safetensors",
19
+ "lucaone.encoder.layers.0.fc2.weight": "model-00001-of-00002.safetensors",
20
+ "lucaone.encoder.layers.0.post_layer_norm.bias": "model-00001-of-00002.safetensors",
21
+ "lucaone.encoder.layers.0.post_layer_norm.weight": "model-00001-of-00002.safetensors",
22
+ "lucaone.encoder.layers.0.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
23
+ "lucaone.encoder.layers.0.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
24
+ "lucaone.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
25
+ "lucaone.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
26
+ "lucaone.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
27
+ "lucaone.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
28
+ "lucaone.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
29
+ "lucaone.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
30
+ "lucaone.encoder.layers.0.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
31
+ "lucaone.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
32
+ "lucaone.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
33
+ "lucaone.encoder.layers.1.fc1.bias": "model-00001-of-00002.safetensors",
34
+ "lucaone.encoder.layers.1.fc1.weight": "model-00001-of-00002.safetensors",
35
+ "lucaone.encoder.layers.1.fc2.bias": "model-00001-of-00002.safetensors",
36
+ "lucaone.encoder.layers.1.fc2.weight": "model-00001-of-00002.safetensors",
37
+ "lucaone.encoder.layers.1.post_layer_norm.bias": "model-00001-of-00002.safetensors",
38
+ "lucaone.encoder.layers.1.post_layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "lucaone.encoder.layers.1.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
40
+ "lucaone.encoder.layers.1.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
41
+ "lucaone.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
42
+ "lucaone.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
43
+ "lucaone.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
44
+ "lucaone.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
45
+ "lucaone.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
46
+ "lucaone.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
47
+ "lucaone.encoder.layers.1.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
48
+ "lucaone.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
49
+ "lucaone.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
50
+ "lucaone.encoder.layers.10.fc1.bias": "model-00001-of-00002.safetensors",
51
+ "lucaone.encoder.layers.10.fc1.weight": "model-00001-of-00002.safetensors",
52
+ "lucaone.encoder.layers.10.fc2.bias": "model-00001-of-00002.safetensors",
53
+ "lucaone.encoder.layers.10.fc2.weight": "model-00001-of-00002.safetensors",
54
+ "lucaone.encoder.layers.10.post_layer_norm.bias": "model-00001-of-00002.safetensors",
55
+ "lucaone.encoder.layers.10.post_layer_norm.weight": "model-00001-of-00002.safetensors",
56
+ "lucaone.encoder.layers.10.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
57
+ "lucaone.encoder.layers.10.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
58
+ "lucaone.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
59
+ "lucaone.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
60
+ "lucaone.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
61
+ "lucaone.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
62
+ "lucaone.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
63
+ "lucaone.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
64
+ "lucaone.encoder.layers.10.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
65
+ "lucaone.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
66
+ "lucaone.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
67
+ "lucaone.encoder.layers.11.fc1.bias": "model-00001-of-00002.safetensors",
68
+ "lucaone.encoder.layers.11.fc1.weight": "model-00001-of-00002.safetensors",
69
+ "lucaone.encoder.layers.11.fc2.bias": "model-00001-of-00002.safetensors",
70
+ "lucaone.encoder.layers.11.fc2.weight": "model-00001-of-00002.safetensors",
71
+ "lucaone.encoder.layers.11.post_layer_norm.bias": "model-00001-of-00002.safetensors",
72
+ "lucaone.encoder.layers.11.post_layer_norm.weight": "model-00001-of-00002.safetensors",
73
+ "lucaone.encoder.layers.11.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
74
+ "lucaone.encoder.layers.11.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
75
+ "lucaone.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
76
+ "lucaone.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
77
+ "lucaone.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
78
+ "lucaone.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
79
+ "lucaone.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
80
+ "lucaone.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
81
+ "lucaone.encoder.layers.11.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
82
+ "lucaone.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
83
+ "lucaone.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
84
+ "lucaone.encoder.layers.12.fc1.bias": "model-00001-of-00002.safetensors",
85
+ "lucaone.encoder.layers.12.fc1.weight": "model-00001-of-00002.safetensors",
86
+ "lucaone.encoder.layers.12.fc2.bias": "model-00001-of-00002.safetensors",
87
+ "lucaone.encoder.layers.12.fc2.weight": "model-00001-of-00002.safetensors",
88
+ "lucaone.encoder.layers.12.post_layer_norm.bias": "model-00001-of-00002.safetensors",
89
+ "lucaone.encoder.layers.12.post_layer_norm.weight": "model-00001-of-00002.safetensors",
90
+ "lucaone.encoder.layers.12.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
91
+ "lucaone.encoder.layers.12.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
92
+ "lucaone.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
93
+ "lucaone.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
94
+ "lucaone.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
95
+ "lucaone.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
96
+ "lucaone.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
97
+ "lucaone.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
98
+ "lucaone.encoder.layers.12.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
99
+ "lucaone.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
100
+ "lucaone.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
101
+ "lucaone.encoder.layers.13.fc1.bias": "model-00001-of-00002.safetensors",
102
+ "lucaone.encoder.layers.13.fc1.weight": "model-00001-of-00002.safetensors",
103
+ "lucaone.encoder.layers.13.fc2.bias": "model-00001-of-00002.safetensors",
104
+ "lucaone.encoder.layers.13.fc2.weight": "model-00001-of-00002.safetensors",
105
+ "lucaone.encoder.layers.13.post_layer_norm.bias": "model-00001-of-00002.safetensors",
106
+ "lucaone.encoder.layers.13.post_layer_norm.weight": "model-00001-of-00002.safetensors",
107
+ "lucaone.encoder.layers.13.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
108
+ "lucaone.encoder.layers.13.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
109
+ "lucaone.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
110
+ "lucaone.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
111
+ "lucaone.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
112
+ "lucaone.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
113
+ "lucaone.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
114
+ "lucaone.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
115
+ "lucaone.encoder.layers.13.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
116
+ "lucaone.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
117
+ "lucaone.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
118
+ "lucaone.encoder.layers.14.fc1.bias": "model-00001-of-00002.safetensors",
119
+ "lucaone.encoder.layers.14.fc1.weight": "model-00001-of-00002.safetensors",
120
+ "lucaone.encoder.layers.14.fc2.bias": "model-00001-of-00002.safetensors",
121
+ "lucaone.encoder.layers.14.fc2.weight": "model-00001-of-00002.safetensors",
122
+ "lucaone.encoder.layers.14.post_layer_norm.bias": "model-00001-of-00002.safetensors",
123
+ "lucaone.encoder.layers.14.post_layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "lucaone.encoder.layers.14.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
125
+ "lucaone.encoder.layers.14.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
126
+ "lucaone.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
127
+ "lucaone.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
128
+ "lucaone.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
129
+ "lucaone.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
130
+ "lucaone.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
131
+ "lucaone.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
132
+ "lucaone.encoder.layers.14.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
133
+ "lucaone.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
134
+ "lucaone.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
135
+ "lucaone.encoder.layers.15.fc1.bias": "model-00001-of-00002.safetensors",
136
+ "lucaone.encoder.layers.15.fc1.weight": "model-00001-of-00002.safetensors",
137
+ "lucaone.encoder.layers.15.fc2.bias": "model-00002-of-00002.safetensors",
138
+ "lucaone.encoder.layers.15.fc2.weight": "model-00002-of-00002.safetensors",
139
+ "lucaone.encoder.layers.15.post_layer_norm.bias": "model-00001-of-00002.safetensors",
140
+ "lucaone.encoder.layers.15.post_layer_norm.weight": "model-00001-of-00002.safetensors",
141
+ "lucaone.encoder.layers.15.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
142
+ "lucaone.encoder.layers.15.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
143
+ "lucaone.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
144
+ "lucaone.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
145
+ "lucaone.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
146
+ "lucaone.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
147
+ "lucaone.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
148
+ "lucaone.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
149
+ "lucaone.encoder.layers.15.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
150
+ "lucaone.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
151
+ "lucaone.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
152
+ "lucaone.encoder.layers.16.fc1.bias": "model-00002-of-00002.safetensors",
153
+ "lucaone.encoder.layers.16.fc1.weight": "model-00002-of-00002.safetensors",
154
+ "lucaone.encoder.layers.16.fc2.bias": "model-00002-of-00002.safetensors",
155
+ "lucaone.encoder.layers.16.fc2.weight": "model-00002-of-00002.safetensors",
156
+ "lucaone.encoder.layers.16.post_layer_norm.bias": "model-00002-of-00002.safetensors",
157
+ "lucaone.encoder.layers.16.post_layer_norm.weight": "model-00002-of-00002.safetensors",
158
+ "lucaone.encoder.layers.16.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
159
+ "lucaone.encoder.layers.16.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "lucaone.encoder.layers.16.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
161
+ "lucaone.encoder.layers.16.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
162
+ "lucaone.encoder.layers.16.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
163
+ "lucaone.encoder.layers.16.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
164
+ "lucaone.encoder.layers.16.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
165
+ "lucaone.encoder.layers.16.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
166
+ "lucaone.encoder.layers.16.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
167
+ "lucaone.encoder.layers.16.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
168
+ "lucaone.encoder.layers.16.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
169
+ "lucaone.encoder.layers.17.fc1.bias": "model-00002-of-00002.safetensors",
170
+ "lucaone.encoder.layers.17.fc1.weight": "model-00002-of-00002.safetensors",
171
+ "lucaone.encoder.layers.17.fc2.bias": "model-00002-of-00002.safetensors",
172
+ "lucaone.encoder.layers.17.fc2.weight": "model-00002-of-00002.safetensors",
173
+ "lucaone.encoder.layers.17.post_layer_norm.bias": "model-00002-of-00002.safetensors",
174
+ "lucaone.encoder.layers.17.post_layer_norm.weight": "model-00002-of-00002.safetensors",
175
+ "lucaone.encoder.layers.17.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
176
+ "lucaone.encoder.layers.17.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
177
+ "lucaone.encoder.layers.17.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
178
+ "lucaone.encoder.layers.17.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
179
+ "lucaone.encoder.layers.17.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
180
+ "lucaone.encoder.layers.17.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
181
+ "lucaone.encoder.layers.17.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
182
+ "lucaone.encoder.layers.17.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
183
+ "lucaone.encoder.layers.17.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
184
+ "lucaone.encoder.layers.17.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
185
+ "lucaone.encoder.layers.17.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
186
+ "lucaone.encoder.layers.18.fc1.bias": "model-00002-of-00002.safetensors",
187
+ "lucaone.encoder.layers.18.fc1.weight": "model-00002-of-00002.safetensors",
188
+ "lucaone.encoder.layers.18.fc2.bias": "model-00002-of-00002.safetensors",
189
+ "lucaone.encoder.layers.18.fc2.weight": "model-00002-of-00002.safetensors",
190
+ "lucaone.encoder.layers.18.post_layer_norm.bias": "model-00002-of-00002.safetensors",
191
+ "lucaone.encoder.layers.18.post_layer_norm.weight": "model-00002-of-00002.safetensors",
192
+ "lucaone.encoder.layers.18.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
193
+ "lucaone.encoder.layers.18.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
194
+ "lucaone.encoder.layers.18.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
195
+ "lucaone.encoder.layers.18.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
196
+ "lucaone.encoder.layers.18.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
197
+ "lucaone.encoder.layers.18.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
198
+ "lucaone.encoder.layers.18.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
199
+ "lucaone.encoder.layers.18.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
200
+ "lucaone.encoder.layers.18.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
201
+ "lucaone.encoder.layers.18.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
202
+ "lucaone.encoder.layers.18.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
203
+ "lucaone.encoder.layers.19.fc1.bias": "model-00002-of-00002.safetensors",
204
+ "lucaone.encoder.layers.19.fc1.weight": "model-00002-of-00002.safetensors",
205
+ "lucaone.encoder.layers.19.fc2.bias": "model-00002-of-00002.safetensors",
206
+ "lucaone.encoder.layers.19.fc2.weight": "model-00002-of-00002.safetensors",
207
+ "lucaone.encoder.layers.19.post_layer_norm.bias": "model-00002-of-00002.safetensors",
208
+ "lucaone.encoder.layers.19.post_layer_norm.weight": "model-00002-of-00002.safetensors",
209
+ "lucaone.encoder.layers.19.pre_layer_norm.bias": "model-00002-of-00002.safetensors",
210
+ "lucaone.encoder.layers.19.pre_layer_norm.weight": "model-00002-of-00002.safetensors",
211
+ "lucaone.encoder.layers.19.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
212
+ "lucaone.encoder.layers.19.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
213
+ "lucaone.encoder.layers.19.self_attn.out_proj.bias": "model-00002-of-00002.safetensors",
214
+ "lucaone.encoder.layers.19.self_attn.out_proj.weight": "model-00002-of-00002.safetensors",
215
+ "lucaone.encoder.layers.19.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
216
+ "lucaone.encoder.layers.19.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
217
+ "lucaone.encoder.layers.19.self_attn.rot_emb.inv_freq": "model-00002-of-00002.safetensors",
218
+ "lucaone.encoder.layers.19.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
219
+ "lucaone.encoder.layers.19.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
220
+ "lucaone.encoder.layers.2.fc1.bias": "model-00001-of-00002.safetensors",
221
+ "lucaone.encoder.layers.2.fc1.weight": "model-00001-of-00002.safetensors",
222
+ "lucaone.encoder.layers.2.fc2.bias": "model-00001-of-00002.safetensors",
223
+ "lucaone.encoder.layers.2.fc2.weight": "model-00001-of-00002.safetensors",
224
+ "lucaone.encoder.layers.2.post_layer_norm.bias": "model-00001-of-00002.safetensors",
225
+ "lucaone.encoder.layers.2.post_layer_norm.weight": "model-00001-of-00002.safetensors",
226
+ "lucaone.encoder.layers.2.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
227
+ "lucaone.encoder.layers.2.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
228
+ "lucaone.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
229
+ "lucaone.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
230
+ "lucaone.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
231
+ "lucaone.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
232
+ "lucaone.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
233
+ "lucaone.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
234
+ "lucaone.encoder.layers.2.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
235
+ "lucaone.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
236
+ "lucaone.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
237
+ "lucaone.encoder.layers.3.fc1.bias": "model-00001-of-00002.safetensors",
238
+ "lucaone.encoder.layers.3.fc1.weight": "model-00001-of-00002.safetensors",
239
+ "lucaone.encoder.layers.3.fc2.bias": "model-00001-of-00002.safetensors",
240
+ "lucaone.encoder.layers.3.fc2.weight": "model-00001-of-00002.safetensors",
241
+ "lucaone.encoder.layers.3.post_layer_norm.bias": "model-00001-of-00002.safetensors",
242
+ "lucaone.encoder.layers.3.post_layer_norm.weight": "model-00001-of-00002.safetensors",
243
+ "lucaone.encoder.layers.3.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
244
+ "lucaone.encoder.layers.3.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
245
+ "lucaone.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
246
+ "lucaone.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
247
+ "lucaone.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
248
+ "lucaone.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
249
+ "lucaone.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
250
+ "lucaone.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
251
+ "lucaone.encoder.layers.3.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
252
+ "lucaone.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
253
+ "lucaone.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
254
+ "lucaone.encoder.layers.4.fc1.bias": "model-00001-of-00002.safetensors",
255
+ "lucaone.encoder.layers.4.fc1.weight": "model-00001-of-00002.safetensors",
256
+ "lucaone.encoder.layers.4.fc2.bias": "model-00001-of-00002.safetensors",
257
+ "lucaone.encoder.layers.4.fc2.weight": "model-00001-of-00002.safetensors",
258
+ "lucaone.encoder.layers.4.post_layer_norm.bias": "model-00001-of-00002.safetensors",
259
+ "lucaone.encoder.layers.4.post_layer_norm.weight": "model-00001-of-00002.safetensors",
260
+ "lucaone.encoder.layers.4.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
261
+ "lucaone.encoder.layers.4.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
262
+ "lucaone.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
263
+ "lucaone.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
264
+ "lucaone.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
265
+ "lucaone.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
266
+ "lucaone.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
267
+ "lucaone.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
268
+ "lucaone.encoder.layers.4.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
269
+ "lucaone.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
270
+ "lucaone.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
271
+ "lucaone.encoder.layers.5.fc1.bias": "model-00001-of-00002.safetensors",
272
+ "lucaone.encoder.layers.5.fc1.weight": "model-00001-of-00002.safetensors",
273
+ "lucaone.encoder.layers.5.fc2.bias": "model-00001-of-00002.safetensors",
274
+ "lucaone.encoder.layers.5.fc2.weight": "model-00001-of-00002.safetensors",
275
+ "lucaone.encoder.layers.5.post_layer_norm.bias": "model-00001-of-00002.safetensors",
276
+ "lucaone.encoder.layers.5.post_layer_norm.weight": "model-00001-of-00002.safetensors",
277
+ "lucaone.encoder.layers.5.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
278
+ "lucaone.encoder.layers.5.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
279
+ "lucaone.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
280
+ "lucaone.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
281
+ "lucaone.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
282
+ "lucaone.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
283
+ "lucaone.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
284
+ "lucaone.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
285
+ "lucaone.encoder.layers.5.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
286
+ "lucaone.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
287
+ "lucaone.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
288
+ "lucaone.encoder.layers.6.fc1.bias": "model-00001-of-00002.safetensors",
289
+ "lucaone.encoder.layers.6.fc1.weight": "model-00001-of-00002.safetensors",
290
+ "lucaone.encoder.layers.6.fc2.bias": "model-00001-of-00002.safetensors",
291
+ "lucaone.encoder.layers.6.fc2.weight": "model-00001-of-00002.safetensors",
292
+ "lucaone.encoder.layers.6.post_layer_norm.bias": "model-00001-of-00002.safetensors",
293
+ "lucaone.encoder.layers.6.post_layer_norm.weight": "model-00001-of-00002.safetensors",
294
+ "lucaone.encoder.layers.6.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
295
+ "lucaone.encoder.layers.6.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
296
+ "lucaone.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
297
+ "lucaone.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
298
+ "lucaone.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
299
+ "lucaone.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
300
+ "lucaone.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
301
+ "lucaone.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
302
+ "lucaone.encoder.layers.6.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
303
+ "lucaone.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
304
+ "lucaone.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
305
+ "lucaone.encoder.layers.7.fc1.bias": "model-00001-of-00002.safetensors",
306
+ "lucaone.encoder.layers.7.fc1.weight": "model-00001-of-00002.safetensors",
307
+ "lucaone.encoder.layers.7.fc2.bias": "model-00001-of-00002.safetensors",
308
+ "lucaone.encoder.layers.7.fc2.weight": "model-00001-of-00002.safetensors",
309
+ "lucaone.encoder.layers.7.post_layer_norm.bias": "model-00001-of-00002.safetensors",
310
+ "lucaone.encoder.layers.7.post_layer_norm.weight": "model-00001-of-00002.safetensors",
311
+ "lucaone.encoder.layers.7.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
312
+ "lucaone.encoder.layers.7.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
313
+ "lucaone.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
314
+ "lucaone.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
315
+ "lucaone.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
316
+ "lucaone.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
317
+ "lucaone.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
318
+ "lucaone.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
319
+ "lucaone.encoder.layers.7.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
320
+ "lucaone.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
321
+ "lucaone.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
322
+ "lucaone.encoder.layers.8.fc1.bias": "model-00001-of-00002.safetensors",
323
+ "lucaone.encoder.layers.8.fc1.weight": "model-00001-of-00002.safetensors",
324
+ "lucaone.encoder.layers.8.fc2.bias": "model-00001-of-00002.safetensors",
325
+ "lucaone.encoder.layers.8.fc2.weight": "model-00001-of-00002.safetensors",
326
+ "lucaone.encoder.layers.8.post_layer_norm.bias": "model-00001-of-00002.safetensors",
327
+ "lucaone.encoder.layers.8.post_layer_norm.weight": "model-00001-of-00002.safetensors",
328
+ "lucaone.encoder.layers.8.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
329
+ "lucaone.encoder.layers.8.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
330
+ "lucaone.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
331
+ "lucaone.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
332
+ "lucaone.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
333
+ "lucaone.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
334
+ "lucaone.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
335
+ "lucaone.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
336
+ "lucaone.encoder.layers.8.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
337
+ "lucaone.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
338
+ "lucaone.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
339
+ "lucaone.encoder.layers.9.fc1.bias": "model-00001-of-00002.safetensors",
340
+ "lucaone.encoder.layers.9.fc1.weight": "model-00001-of-00002.safetensors",
341
+ "lucaone.encoder.layers.9.fc2.bias": "model-00001-of-00002.safetensors",
342
+ "lucaone.encoder.layers.9.fc2.weight": "model-00001-of-00002.safetensors",
343
+ "lucaone.encoder.layers.9.post_layer_norm.bias": "model-00001-of-00002.safetensors",
344
+ "lucaone.encoder.layers.9.post_layer_norm.weight": "model-00001-of-00002.safetensors",
345
+ "lucaone.encoder.layers.9.pre_layer_norm.bias": "model-00001-of-00002.safetensors",
346
+ "lucaone.encoder.layers.9.pre_layer_norm.weight": "model-00001-of-00002.safetensors",
347
+ "lucaone.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
348
+ "lucaone.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
349
+ "lucaone.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00002.safetensors",
350
+ "lucaone.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00002.safetensors",
351
+ "lucaone.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
352
+ "lucaone.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
353
+ "lucaone.encoder.layers.9.self_attn.rot_emb.inv_freq": "model-00001-of-00002.safetensors",
354
+ "lucaone.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
355
+ "lucaone.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors"
356
  }
357
  }
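For reference, a minimal loading sketch, assuming the files in this commit are hosted in a Hugging Face model repository (the repo id "LucaGroup/LucaOne" below is a placeholder): trust_remote_code=True is required because the configuration, tokenization, and modeling code ship with the repository, and the sharded weights are resolved automatically through model.safetensors.index.json.

from transformers import AutoModelForMaskedLM

# Placeholder repo id; substitute the actual repository path.
model = AutoModelForMaskedLM.from_pretrained(
    "LucaGroup/LucaOne",
    trust_remote_code=True,  # load configuration_lucaone.py / modeling_lucaone.py from the repo
)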
modeling_lucaone.py ADDED
@@ -0,0 +1,1344 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2025, Hey.
5
+ @author: Hey
6
+ @email: [email protected]
7
+ @tel: 137****6540
8
+ @datetime: 2025/12/30 11:35
9
+ @project: lucaone
10
+ @file: modeling_lucaone
11
+ @desc: modeling_lucaone
12
+ '''
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers import PreTrainedModel
18
+ from transformers.modeling_outputs import BaseModelOutput
19
+ from transformers.modeling_outputs import MaskedLMOutput
20
+ from transformers.modeling_outputs import SequenceClassifierOutput
21
+ from transformers.modeling_outputs import TokenClassifierOutput
22
+ from typing import Optional, List, Union, Tuple
23
+ from .configuration_lucaone import LucaGPLMConfig
24
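+ # Prefer apex's FusedLayerNorm when it is installed; otherwise fall back to torch.nn.LayerNorm.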
+ try:
25
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
26
+ class LucaGPLM1bLayerNorm(_FusedLayerNorm):
27
+ @torch.jit.unused
28
+ def forward(self, x):
29
+ if not x.is_cuda:
30
+ return super().forward(x)
31
+ else:
32
+ with torch.cuda.device(x.device):
33
+ return super().forward(x)
34
+ except ImportError:
35
+ from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
36
+
37
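+ # Exact GELU via the Gaussian error function (erf), rather than the tanh approximation.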
+ def gelu(x):
38
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
39
+
40
+ def rotate_half(x):
41
+ x1, x2 = x.chunk(2, dim=-1)
42
+ return torch.cat((-x2, x1), dim=-1)
43
+
44
+ def apply_rotary_pos_emb(x, cos, sin):
45
+ cos = cos[:, : x.shape[-2], :]
46
+ sin = sin[:, : x.shape[-2], :]
47
+ return (x * cos) + (rotate_half(x) * sin)
48
+
49
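+ # Rotary position embedding (RoPE): positions are encoded by rotating query/key feature pairs by position-dependent angles, so no absolute position-embedding table is needed; the cos/sin tables are cached per sequence length and device.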
+ class LucaGPLMRotaryEmbedding(torch.nn.Module):
50
+ def __init__(self, dim: int, *_, **__):
51
+ super().__init__()
52
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
53
+ self.register_buffer("inv_freq", inv_freq)
54
+
55
+ self._seq_len_cached = None
56
+ self._cos_cached = None
57
+ self._sin_cached = None
58
+
59
+ def _update_cos_sin_tables(self, x, seq_dimension=1):
60
+ seq_len = x.shape[seq_dimension]
61
+
62
+ if (seq_len != self._seq_len_cached or
63
+ self._cos_cached is None or
64
+ self._sin_cached is None or
65
+ self._cos_cached.device != x.device):
66
+ self._seq_len_cached = seq_len
67
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
68
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
69
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
70
+
71
+ self._cos_cached = emb.cos()[None, :, :]
72
+ self._sin_cached = emb.sin()[None, :, :]
73
+
74
+ return self._cos_cached, self._sin_cached
75
+
76
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
78
+
79
+ return (
80
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
81
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
82
+ )
83
+
84
+ class LucaGPLMGlobalMaskWeightedAttentionPooling1D(nn.Module):
85
+ def __init__(self, embed_size, use_bias=False):
86
+ super(LucaGPLMGlobalMaskWeightedAttentionPooling1D, self).__init__()
87
+ self.embed_size = embed_size
88
+ self.use_bias = use_bias
89
+
90
+ self.W = nn.Parameter(torch.Tensor(self.embed_size))
91
+ nn.init.trunc_normal_(self.W, std=0.01)
92
+ if self.use_bias:
93
+ self.b = nn.Parameter(torch.Tensor(1))
94
+ nn.init.trunc_normal_(self.b, std=0.01)
95
+
96
+ def forward(self, x, mask=None):
97
+ # (B, Len, Embed) x (Embed,) = (B, Len)
98
+ logits = torch.matmul(x, self.W)
99
+ if self.use_bias:
100
+ logits += self.b
101
+
102
+ if mask is not None:
103
+ attention_probs = nn.Softmax(dim=-1)(logits + (1.0 - mask) * -10000)
104
+ else:
105
+ attention_probs = nn.Softmax(dim=-1)(logits)
106
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
107
+ return x
108
+
109
+ def __repr__(self):
110
+ return self.__class__.__name__ + ' (' + str(self.embed_size) + (', bias=%r)' % self.use_bias)
111
+
112
+ class LucaGPLMGlobalMaskContextAttentionPooling1D(nn.Module):
113
+ def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
114
+ super(LucaGPLMGlobalMaskContextAttentionPooling1D, self).__init__()
115
+ self.embed_size = embed_size
116
+ self.use_additive_bias = use_additive_bias
117
+ self.use_attention_bias = use_attention_bias
118
+ self.units = units if units else embed_size
119
+
120
+ self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
121
+ self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
122
+ if self.use_additive_bias:
123
+ self.b1 = nn.Parameter(torch.Tensor(self.units))
124
+ nn.init.trunc_normal_(self.b1, std=0.01)
125
+ if self.use_attention_bias:
126
+ self.b2 = nn.Parameter(torch.Tensor(1))
127
+ nn.init.trunc_normal_(self.b2, std=0.01)
128
+
129
+ self.c = nn.Parameter(torch.Tensor(self.units))
130
+
131
+ nn.init.trunc_normal_(self.U, std=0.01)
132
+ nn.init.trunc_normal_(self.V, std=0.01)
133
+ nn.init.trunc_normal_(self.c, std=0.01)
134
+
135
+ def forward(self, x, mask=None):
136
+ # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
137
+ q = torch.matmul(x, self.U)
138
+ k = torch.matmul(x, self.V)
139
+ if self.use_additive_bias:
140
+ h = torch.tanh(q + k + self.b1)
141
+ else:
142
+ h = torch.tanh(q + k)
143
+
144
+ if self.use_attention_bias:
145
+ e = torch.matmul(h, self.c) + self.b2
146
+ else:
147
+ e = torch.matmul(h, self.c)
148
+ if mask is not None:
149
+ attention_probs = nn.Softmax(dim=-1)(e + (1.0 - mask) * -10000)
150
+ else:
151
+ attention_probs = nn.Softmax(dim=-1)(e)
152
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
153
+ return x
154
+
155
+ def __repr__(self):
156
+ return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % (self.use_additive_bias, self.use_attention_bias)
157
+
158
+ class LucaGPLMGlobalMaskValueAttentionPooling1D(nn.Module):
159
+ def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
160
+ super(LucaGPLMGlobalMaskValueAttentionPooling1D, self).__init__()
161
+ self.embed_size = embed_size
162
+ self.use_additive_bias = use_additive_bias
163
+ self.use_attention_bias = use_attention_bias
164
+ self.units = units if units else embed_size
165
+
166
+ self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
167
+ self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
168
+ if self.use_additive_bias:
169
+ self.b1 = nn.Parameter(torch.Tensor(self.units))
170
+ nn.init.trunc_normal_(self.b1, std=0.01)
171
+ if self.use_attention_bias:
172
+ self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
173
+ nn.init.trunc_normal_(self.b2, std=0.01)
174
+
175
+ self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))
176
+
177
+ nn.init.trunc_normal_(self.U, std=0.01)
178
+ nn.init.trunc_normal_(self.V, std=0.01)
179
+ nn.init.trunc_normal_(self.W, std=0.01)
180
+
181
+ def forward(self, x, mask=None):
182
+ # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
183
+ q = torch.matmul(x, self.U)
184
+ k = torch.matmul(x, self.V)
185
+ if self.use_additive_bias:
186
+ h = torch.tanh(q + k + self.b1)
187
+ else:
188
+ h = torch.tanh(q + k)
189
+
190
+ # (B, Len, Units) x (Units, Embed) = (B, Len, Embed)
191
+ if self.use_attention_bias:
192
+ e = torch.matmul(h, self.W) + self.b2
193
+ else:
194
+ e = torch.matmul(h, self.W)
195
+ if mask is not None:
196
+ attention_probs = nn.Softmax(dim=1)(e + torch.unsqueeze((1.0 - mask) * -10000, dim=-1))
197
+ else:
198
+ attention_probs = nn.Softmax(dim=1)(e)
199
+ x = torch.sum(attention_probs * x, dim=1)
200
+ return x
201
+
202
+ def __repr__(self):
203
+ return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % (self.use_additive_bias, self.use_attention_bias)
204
+
205
+ class LucaGPLM1LayerNorm(nn.Module):
206
+ def __init__(self, hidden_size, eps=1e-12, affine=True):
207
+ super().__init__()
208
+ self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
209
+ self.eps = eps
210
+ self.affine = bool(affine)
211
+ if self.affine:
212
+ self.weight = nn.Parameter(torch.ones(hidden_size))
213
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
214
+ else:
215
+ self.weight, self.bias = None, None
216
+
217
+ def forward(self, x):
218
+ dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
219
+ means = x.mean(dims, keepdim=True)
220
+ x_zeromean = x - means
221
+ variances = x_zeromean.pow(2).mean(dims, keepdim=True)
222
+ x = x_zeromean / torch.sqrt(variances + self.eps)
223
+ if self.affine:
224
+ x = (self.weight * x) + self.bias
225
+ return x
226
+
227
+ class LucaGPLMMultiheadAttention(nn.Module):
228
+ def __init__(
229
+ self,
230
+ embed_dim,
231
+ num_heads,
232
+ kdim=None,
233
+ vdim=None,
234
+ dropout=0.0,
235
+ bias=True,
236
+ add_bias_kv: bool = False,
237
+ add_zero_attn: bool = False,
238
+ self_attention: bool = False,
239
+ encoder_decoder_attention: bool = False,
240
+ use_rotary_embeddings: bool = False,
241
+ ):
242
+ super().__init__()
243
+ self.embed_dim = embed_dim
244
+ self.kdim = kdim if kdim is not None else embed_dim
245
+ self.vdim = vdim if vdim is not None else embed_dim
246
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
247
+
248
+ self.num_heads = num_heads
249
+ self.dropout = dropout
250
+ self.head_dim = embed_dim // num_heads
251
+ assert (
252
+ self.head_dim * num_heads == self.embed_dim
253
+ ), "embed_dim must be divisible by num_heads"
254
+ self.scaling = self.head_dim**-0.5
255
+
256
+ self.self_attention = self_attention
257
+ self.encoder_decoder_attention = encoder_decoder_attention
258
+
259
+ assert not self.self_attention or self.qkv_same_dim, (
260
+ "Self-attention requires query, key and " "value to be of the same size"
261
+ )
262
+
263
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
264
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
265
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
266
+
267
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
268
+
269
+ if add_bias_kv:
270
+ self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
271
+ self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
272
+ else:
273
+ self.bias_k = self.bias_v = None
274
+
275
+ self.add_zero_attn = add_zero_attn
276
+
277
+ self.reset_parameters()
278
+
279
+ self.rot_emb = None
280
+ if use_rotary_embeddings:
281
+ self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)
282
+
283
+ def reset_parameters(self):
284
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
285
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
286
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
287
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
288
+
289
+ if self.out_proj.bias is not None:
290
+ nn.init.constant_(self.out_proj.bias, 0.0)
291
+ if self.bias_k is not None:
292
+ nn.init.xavier_normal_(self.bias_k)
293
+ if self.bias_v is not None:
294
+ nn.init.xavier_normal_(self.bias_v)
295
+
296
+ def forward(
297
+ self,
298
+ query,
299
+ key: Optional[torch.Tensor] = None,
300
+ value: Optional[torch.Tensor] = None,
301
+ key_padding_mask: Optional[torch.Tensor] = None,
302
+ need_weights: bool = True,
303
+ attn_mask: Optional[torch.Tensor] = None,
304
+ need_head_weights: bool = False,
305
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
306
+ if need_head_weights:
307
+ need_weights = True
308
+
309
+ tgt_len, bsz, embed_dim = query.size()
310
+ assert embed_dim == self.embed_dim
311
+
312
+ if self.self_attention:
313
+ q = self.q_proj(query)
314
+ k = self.k_proj(query)
315
+ v = self.v_proj(query)
316
+ else:
317
+ assert key is not None and value is not None
318
+ q = self.q_proj(query)
319
+ k = self.k_proj(key)
320
+ v = self.v_proj(value)
321
+
322
+ q *= self.scaling
323
+
324
+ if self.bias_k is not None:
325
+ assert self.bias_v is not None
326
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
327
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
328
+ if attn_mask is not None:
329
+ attn_mask = torch.cat(
330
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
331
+ )
332
+ if key_padding_mask is not None:
333
+ key_padding_mask = torch.cat(
334
+ [
335
+ key_padding_mask,
336
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
337
+ ],
338
+ dim=1,
339
+ )
340
+
341
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
342
+ if k is not None:
343
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
344
+ if v is not None:
345
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
346
+
347
+ assert k is not None
348
+ src_len = k.size(1)
349
+
350
+ if self.rot_emb:
351
+ q, k = self.rot_emb(q, k)
352
+
353
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
354
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
355
+
356
+ if attn_mask is not None:
357
+ attn_mask = attn_mask.unsqueeze(0)
358
+ attn_weights += attn_mask
359
+
360
+ if key_padding_mask is not None:
361
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
362
+ attn_weights = attn_weights.masked_fill(
363
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
364
+ )
365
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
366
+
367
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
368
+ attn_weights = attn_weights_float.type_as(attn_weights)
369
+ attn_probs = F.dropout(
370
+ attn_weights_float.type_as(attn_weights),
371
+ p=self.dropout,
372
+ training=self.training,
373
+ )
374
+
375
+ assert v is not None
376
+ attn = torch.bmm(attn_probs, v)
377
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
378
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
379
+ attn = self.out_proj(attn)
380
+
381
+ attn_weights_output: Optional[torch.Tensor] = None
382
+ if need_weights:
383
+ attn_weights_output = attn_weights_float.view(
384
+ bsz, self.num_heads, tgt_len, src_len
385
+ ).type_as(attn).transpose(1, 0)
386
+ if not need_head_weights:
387
+ # average attention weights over heads
388
+ attn_weights_output = attn_weights_output.mean(dim=0)
389
+
390
+ return attn, attn_weights_output
391
+
392
+ class LucaGPLMMultiheadAttentionWithSDPA(nn.Module):
393
+ def __init__(
394
+ self,
395
+ embed_dim,
396
+ num_heads,
397
+ kdim=None,
398
+ vdim=None,
399
+ dropout=0.0,
400
+ bias=True,
401
+ add_bias_kv: bool = False,
402
+ add_zero_attn: bool = False,
403
+ self_attention: bool = False,
404
+ encoder_decoder_attention: bool = False,
405
+ use_rotary_embeddings: bool = True,
406
+ ):
407
+ super().__init__()
408
+ self.embed_dim = embed_dim
409
+ self.kdim = kdim if kdim is not None else embed_dim
410
+ self.vdim = vdim if vdim is not None else embed_dim
411
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
412
+
413
+ self.num_heads = num_heads
414
+ self.dropout = dropout
415
+ self.head_dim = embed_dim // num_heads
416
+ assert (
417
+ self.head_dim * num_heads == self.embed_dim
418
+ ), "embed_dim must be divisible by num_heads"
419
+ self.scaling = self.head_dim**-0.5
420
+
421
+ self.self_attention = self_attention
422
+ self.encoder_decoder_attention = encoder_decoder_attention
423
+
424
+ assert not self.self_attention or self.qkv_same_dim, (
425
+ "Self-attention requires query, key and " "value to be of the same size"
426
+ )
427
+
428
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
429
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
430
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
431
+
432
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
433
+
434
+ if add_bias_kv:
435
+ self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
436
+ self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
437
+ else:
438
+ self.bias_k = self.bias_v = None
439
+
440
+ self.add_zero_attn = add_zero_attn
441
+
442
+ self.reset_parameters()
443
+
444
+ self.rot_emb = None
445
+ if use_rotary_embeddings:
446
+ self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)
447
+
448
+ def reset_parameters(self):
449
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
450
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
451
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
452
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
453
+
454
+ if self.out_proj.bias is not None:
455
+ nn.init.constant_(self.out_proj.bias, 0.0)
456
+ if self.bias_k is not None:
457
+ nn.init.xavier_normal_(self.bias_k)
458
+ if self.bias_v is not None:
459
+ nn.init.xavier_normal_(self.bias_v)
460
+
461
+ def forward(
462
+ self,
463
+ query,
464
+ key: Optional[torch.Tensor] = None,
465
+ value: Optional[torch.Tensor] = None,
466
+ key_padding_mask: Optional[torch.Tensor] = None,
467
+ need_weights: bool = True,
468
+ attn_mask: Optional[torch.Tensor] = None,
469
+ need_head_weights: bool = False,
470
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
471
+
472
+ tgt_len, bsz, embed_dim = query.size()
473
+ assert embed_dim == self.embed_dim
474
+
475
+ if self.self_attention:
476
+ q = self.q_proj(query)
477
+ k = self.k_proj(query)
478
+ v = self.v_proj(query)
479
+ else:
480
+ assert key is not None and value is not None
481
+ q = self.q_proj(query)
482
+ k = self.k_proj(key)
483
+ v = self.v_proj(value)
484
+
485
+ if self.bias_k is not None:
486
+ assert self.bias_v is not None
487
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
488
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
489
+ if attn_mask is not None:
490
+ attn_mask = torch.cat(
491
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
492
+ )
493
+ if key_padding_mask is not None:
494
+ key_padding_mask = torch.cat(
495
+ [
496
+ key_padding_mask,
497
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
498
+ ],
499
+ dim=1,
500
+ )
501
+
502
+ # ----------------------------------------------------------------------
503
+ # Flash Attention Optimization
504
+ # ----------------------------------------------------------------------
505
+ # Use F.scaled_dot_product_attention (Flash Attention) when head weights are not needed and the PyTorch version supports it
506
+ if not need_head_weights and hasattr(F, "scaled_dot_product_attention"):
507
+ # Reshape inputs to (Batch, Head, Seq_Len, Dim) for SDPA
508
+ # q, k, v input shape: (Seq_Len, Batch, Embed_Dim)
509
+ q_sdpa = q.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
510
+ k_sdpa = k.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
511
+ v_sdpa = v.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
512
+
513
+ # Apply Rotary Embedding if needed
514
+ if self.rot_emb:
515
+ # Rotary expects inputs (..., Seq_Len, Dim)
516
+ # It handles broadcasting over Batch and Head
517
+ q_sdpa, k_sdpa = self.rot_emb(q_sdpa, k_sdpa)
518
+
519
+ # Prepare Mask
520
+ # SDPA accepts a broadcastable boolean mask or float mask
521
+ # key_padding_mask is (Batch, Seq_Len), True where padding
522
+ sdpa_mask = None
523
+ if attn_mask is not None or key_padding_mask is not None:
524
+ # Start with a float mask suitable for SDPA
525
+ target_shape = (bsz, 1, tgt_len, k_sdpa.size(2))
526
+ sdpa_mask = torch.zeros(target_shape, device=q.device, dtype=q.dtype)
527
+
528
+ if key_padding_mask is not None:
529
+ # key_padding_mask is (Batch, Seq_Len) -> (Batch, 1, 1, Seq_Len)
530
+ sdpa_mask = sdpa_mask.masked_fill(
531
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
532
+ float("-inf")
533
+ )
534
+
535
+ if attn_mask is not None:
536
+ if attn_mask.dim() == 2:
537
+ sdpa_mask = sdpa_mask + attn_mask.unsqueeze(0).unsqueeze(0)
538
+ elif attn_mask.dim() == 3:
539
+ pass
540
+ else:
541
+ sdpa_mask = sdpa_mask + attn_mask
542
+
543
+ # Call Flash Attention
544
+ # Note: q is not pre-scaled on this path; scaled_dot_product_attention applies its default 1/sqrt(head_dim) scaling internally (the manual q scaling below only applies to the fallback path)
545
+ attn_output = F.scaled_dot_product_attention(
546
+ q_sdpa,
547
+ k_sdpa,
548
+ v_sdpa,
549
+ attn_mask=sdpa_mask,
550
+ dropout_p=self.dropout if self.training else 0.0,
551
+ is_causal=False
552
+ )
553
+
554
+ # Reshape back to (Seq_Len, Batch, Embed_Dim)
555
+ # (B, H, L, D) -> (L, B, H, D) -> (L, B, E)
556
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
557
+
558
+ # Linear projection
559
+ attn_output = self.out_proj(attn_output)
560
+
561
+ # Return None for weights (optimization trade-off)
562
+ return attn_output, None
563
+
564
+ q = q * self.scaling
565
+ # ----------------------------------------------------------------------
566
+ # Original Implementation (Fallback)
567
+ # ----------------------------------------------------------------------
568
+ # print('Fall back to slow implementation.')
569
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
570
+ if k is not None:
571
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
572
+ if v is not None:
573
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
574
+
575
+ assert k is not None
576
+ src_len = k.size(1)
577
+
578
+ if self.rot_emb:
579
+ q, k = self.rot_emb(q, k)
580
+
581
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
582
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
583
+
584
+ if attn_mask is not None:
585
+ attn_mask = attn_mask.unsqueeze(0)
586
+ attn_weights += attn_mask
587
+
588
+ if key_padding_mask is not None:
589
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
590
+ attn_weights = attn_weights.masked_fill(
591
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
592
+ )
593
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
594
+
595
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
596
+ attn_weights = attn_weights_float.type_as(attn_weights)
597
+ attn_probs = F.dropout(
598
+ attn_weights_float.type_as(attn_weights),
599
+ p=self.dropout,
600
+ training=self.training,
601
+ )
602
+
603
+ assert v is not None
604
+ attn = torch.bmm(attn_probs, v)
605
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
606
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
607
+ attn = self.out_proj(attn)
608
+
609
+ attn_weights_output: Optional[torch.Tensor] = None
610
+ if need_weights:
611
+ attn_weights_output = attn_weights_float.view(
612
+ bsz, self.num_heads, tgt_len, src_len
613
+ ).type_as(attn).transpose(1, 0)
614
+ if not need_head_weights:
615
+ # average attention weights over heads
616
+ attn_weights_output = attn_weights_output.mean(dim=0)
617
+
618
+ return attn, attn_weights_output
619
+
620
+ class LucaGPLMRobertaLMHead(nn.Module):
621
+ def __init__(self, embed_dim, output_dim):
622
+ super().__init__()
623
+ self.dense = nn.Linear(embed_dim, embed_dim)
624
+ self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
625
+ # Use a standard nn.Linear as the decoder
626
+ self.decoder = nn.Linear(embed_dim, output_dim, bias=False)
627
+ self.bias = nn.Parameter(torch.zeros(output_dim))
628
+
629
+ def forward(self, features):
630
+ x = self.dense(features)
631
+ x = gelu(x)
632
+ x = self.layer_norm(x)
633
+ # project back to size of vocabulary with bias
634
+ # x = F.linear(x, self.weight) + self.bias
635
+ x = self.decoder(x) + self.bias
636
+ return x
637
+
638
+ class LucaGPLMTransformerLayer(nn.Module):
639
+ def __init__(
640
+ self,
641
+ embed_dim,
642
+ ffn_embed_dim,
643
+ attention_heads,
644
+ add_bias_kv=True,
645
+ use_lucagplm1b_layer_norm=False,
646
+ use_rotary_embeddings: bool=True,
647
+ ):
648
+ super().__init__()
649
+ self.embed_dim = embed_dim
650
+ self.ffn_embed_dim = ffn_embed_dim
651
+ self.attention_heads = attention_heads
652
+ self.use_rotary_embeddings = use_rotary_embeddings
653
+
654
+ LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
655
+
656
+ self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
657
+
658
+ self.self_attn = LucaGPLMMultiheadAttentionWithSDPA(
659
+ self.embed_dim,
660
+ self.attention_heads,
661
+ add_bias_kv=add_bias_kv,
662
+ add_zero_attn=False,
663
+ self_attention=True,
664
+ use_rotary_embeddings=self.use_rotary_embeddings,
665
+ )
666
+
667
+ # post layer norm
668
+ self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
669
+
670
+ # dimension increase by the fully connected layer
671
+ self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
672
+
673
+ # dimension reduction by the fully connected layer
674
+ self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
675
+
676
+ def forward(
677
+ self,
678
+ x,
679
+ self_attn_mask=None,
680
+ self_attn_padding_mask=None,
681
+ need_head_weights=False
682
+ ):
683
+ residual = x
684
+ x = self.pre_layer_norm(x)
685
+ x, attn = self.self_attn(
686
+ query=x,
687
+ key=x,
688
+ value=x,
689
+ key_padding_mask=self_attn_padding_mask,
690
+ need_weights=True,
691
+ need_head_weights=need_head_weights,
692
+ attn_mask=self_attn_mask,
693
+ )
694
+ x = residual + x
695
+
696
+ residual = x
697
+ x = self.post_layer_norm(x)
698
+ x = gelu(self.fc1(x))
699
+ x = self.fc2(x)
700
+ x = residual + x
701
+
702
+ return x, attn
703
+
704
+ class LucaGPLMEmbeddings(nn.Module):
705
+ def __init__(self, config: LucaGPLMConfig):
706
+ super().__init__()
707
+
708
+ # Store config flags for forward pass
709
+ self.no_position_embeddings = getattr(config, 'no_position_embeddings', False)
710
+ self.no_token_type_embeddings = getattr(config, 'no_token_type_embeddings', False)
711
+ self.use_embed_layer_norm = getattr(config, 'use_embed_layer_norm', True)
712
+ self.embed_scale = getattr(config, 'embed_scale', 1.0)
713
+ self.token_dropout = getattr(config, 'token_dropout', False)
714
+
715
+ # Token ids for special tokens (matching old model)
716
+ self.mask_idx = getattr(config, 'mask_token_id', 4)
717
+ self.padding_idx = getattr(config, 'pad_token_id', 0)
718
+
719
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
720
+
721
+ # Only create position embeddings if not disabled
722
+ if not self.no_position_embeddings:
723
+ self.embed_pos = nn.Embedding(config.max_position_embeddings, config.hidden_size)
724
+ else:
725
+ self.embed_pos = None
726
+
727
+ # Only create token type embeddings if not disabled
728
+ if not self.no_token_type_embeddings:
729
+ self.embed_type = nn.Embedding(config.type_vocab_size, config.hidden_size)
730
+ else:
731
+ self.embed_type = None
732
+
733
+ # Only create layer norm if enabled
734
+ if self.use_embed_layer_norm:
735
+ self.embed_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
736
+ else:
737
+ self.embed_layer_norm = None
738
+
739
+ def forward(
740
+ self,
741
+ input_ids: torch.Tensor,
742
+ token_type_ids: Optional[torch.Tensor] = None,
743
+ position_ids: Optional[torch.Tensor] = None,
744
+ ) -> torch.Tensor:
745
+ input_shape = input_ids.size()
746
+ seq_length = input_shape[1]
747
+
748
+ # Start with token embeddings and apply embed_scale
749
+ inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
750
+
751
+ # Add position embeddings if enabled
752
+ if not self.no_position_embeddings and self.embed_pos is not None:
753
+ if position_ids is None:
754
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
755
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
756
+ position_embeddings = self.embed_scale * self.embed_pos(position_ids)
757
+ inputs_embeds = inputs_embeds + position_embeddings
758
+
759
+ # Add token type embeddings if enabled
760
+ if not self.no_token_type_embeddings and self.embed_type is not None:
761
+ if token_type_ids is None:
762
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
763
+ token_type_embeddings = self.embed_scale * self.embed_type(token_type_ids)
764
+ inputs_embeds = inputs_embeds + token_type_embeddings
765
+
766
+ # Apply layer norm if enabled
767
+ if self.use_embed_layer_norm and self.embed_layer_norm is not None:
768
+ embeddings = self.embed_layer_norm(inputs_embeds)
769
+ else:
770
+ embeddings = inputs_embeds
771
+
772
+ # Apply token dropout (matching old model behavior)
773
+ if self.token_dropout and self.training:
774
+ # Zero out masked token embeddings
775
+ embeddings = embeddings.masked_fill((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
776
+
777
+ # Apply token dropout scaling
778
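+ # 0.15 * 0.8 is the expected fraction of tokens replaced by the mask token during pretraining (15% of tokens selected, 80% of those masked); embeddings are rescaled by (1 - expected) / (1 - observed) to compensate for the zeroed mask-token embeddings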
+ mask_ratio_train = 0.15 * 0.8
779
+ padding_mask = input_ids.eq(self.padding_idx)
780
+ src_lengths = (~padding_mask).sum(-1)
781
+ mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(embeddings.dtype) / src_lengths
782
+ embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
783
+
784
+ # Apply padding mask to embeddings
785
+ padding_mask = input_ids.eq(self.padding_idx)
786
+ if padding_mask.any():
787
+ embeddings = embeddings * (1 - padding_mask.unsqueeze(-1).type_as(embeddings))
788
+
789
+ return embeddings
790
+
791
+ class LucaGPLMEncoder(nn.Module):
792
+ def __init__(self, config: LucaGPLMConfig):
793
+ super().__init__()
794
+
795
+ self.layers = nn.ModuleList([
796
+ LucaGPLMTransformerLayer(
797
+ config.hidden_size,
798
+ 4 * config.hidden_size, # ffn_embed_dim = 4 * embed_dim
799
+ config.num_attention_heads,
800
+ add_bias_kv=False,
801
+ use_lucagplm1b_layer_norm=True,
802
+ use_rotary_embeddings=True,
803
+ )
804
+ for _ in range(config.num_hidden_layers)
805
+ ])
806
+
807
+ self.use_last_layer_norm = getattr(config, 'use_last_layer_norm', True)
808
+ if self.use_last_layer_norm:
809
+ self.last_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
810
+ else:
811
+ self.last_layer_norm = None
812
+
813
+ self.padding_idx = config.pad_token_id
814
+ self.gradient_checkpointing = False
815
+
816
+ def forward(
817
+ self,
818
+ hidden_states: torch.Tensor,
819
+ attention_mask: Optional[torch.Tensor] = None,
820
+ output_attentions: bool = False,
821
+ output_hidden_states: bool = False,
822
+ return_dict: bool = True,
823
+ need_head_weights: bool = False,
824
+ repr_layers: Optional[List[int]] = None,
825
+ use_last_layer_norm: bool = True,
826
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
827
+ all_hidden_states = () if output_hidden_states else None
828
+ all_attentions = () if output_attentions else None
829
+
830
+ if repr_layers is None:
831
+ repr_layers = [-1]
832
+
833
+ # Convert to the original model's layer-indexing scheme: -1 maps to the last layer (num_layers), 0 refers to the embedding output
834
+ layer_size = len(self.layers)
835
+ repr_layers = [(i + layer_size + 1) % (layer_size + 1) for i in repr_layers]
836
+ repr_layers = set(repr_layers)
837
+ hidden_representations = {}
838
+
839
+ # Process attention mask - the original model expects a padding mask
840
+ if attention_mask is None:
841
+ padding_mask = hidden_states.new_zeros(hidden_states.shape[:2]).eq(self.padding_idx)
842
+ else:
843
+ # In the original model, padding_mask is True at padding positions
844
+ padding_mask = attention_mask.eq(0)
845
+
846
+ # 0: embedding layer
847
+ if 0 in repr_layers:
848
+ hidden_representations[0] = hidden_states
849
+
850
+ # Transpose to (seq_len, batch_size, hidden_size), matching the original model
851
+ hidden_states = hidden_states.transpose(0, 1)
852
+
853
+ if not padding_mask.any():
854
+ padding_mask = None
855
+
856
+ # Whether head weights need to be returned
857
+ if need_head_weights or output_attentions:
858
+ attn_weights = []
859
+
860
+ for layer_idx, layer_module in enumerate(self.layers):
861
+ if output_hidden_states:
862
+ all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)
863
+
864
+ if self.gradient_checkpointing and self.training:
865
+ layer_outputs = self._gradient_checkpointing_func(
866
+ layer_module.__call__,
867
+ hidden_states,
868
+ None, # self_attn_mask
869
+ padding_mask,
870
+ need_head_weights or output_attentions,
871
+ )
872
+ else:
873
+ layer_outputs = layer_module(
874
+ hidden_states,
875
+ self_attn_mask=None,
876
+ self_attn_padding_mask=padding_mask,
877
+ need_head_weights=need_head_weights or output_attentions,
878
+ )
879
+
880
+ hidden_states, attn = layer_outputs
881
+
882
+ if (layer_idx + 1) in repr_layers:
883
+ hidden_representations[layer_idx + 1] = hidden_states.transpose(0, 1)
884
+
885
+ if need_head_weights or output_attentions:
886
+ # (H, B, L, L) => (B, H, L, L)
887
+ attn_weights.append(attn.transpose(1, 0))
888
+
889
+ # Apply the final layer norm
890
+ if self.last_layer_norm is not None and use_last_layer_norm:
891
+ hidden_states = self.last_layer_norm(hidden_states)
892
+
893
+ # Transpose back to (batch_size, seq_len, hidden_size)
894
+ hidden_states = hidden_states.transpose(0, 1)
895
+
896
+ # last hidden representation should have layer norm applied
897
+ if (layer_idx + 1) in repr_layers:
898
+ hidden_representations[layer_idx + 1] = hidden_states
899
+
900
+ if output_hidden_states:
901
+ all_hidden_states = all_hidden_states + (hidden_states,)
902
+
903
+ if need_head_weights or output_attentions:
904
+ # Convert the attention weights to the expected format
905
+ if attn_weights:
906
+ # B x Layers x H x L x L
907
+ all_attentions = torch.stack(attn_weights, 1)
908
+ if padding_mask is not None:
909
+ attention_mask_expanded = 1 - padding_mask.type_as(all_attentions)
910
+ attention_mask_expanded = attention_mask_expanded.unsqueeze(1) * attention_mask_expanded.unsqueeze(2)
911
+ all_attentions = all_attentions * attention_mask_expanded[:, None, None, :, :]
912
+
913
+ if not output_attentions:
914
+ all_attentions = None
915
+
916
+ if not return_dict:
917
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
918
+
919
+ return BaseModelOutput(
920
+ last_hidden_state=hidden_states,
921
+ hidden_states=all_hidden_states,
922
+ attentions=all_attentions,
923
+ )
924
+
925
+ class LucaGPLMPreTrainedModel(PreTrainedModel):
926
+ config_class = LucaGPLMConfig
927
+ base_model_prefix = "lucaone"
928
+ supports_gradient_checkpointing = True
929
+ _no_split_modules = ["LucaGPLMTransformerLayer"]
930
+
931
+ def _init_weights(self, module):
932
+ if isinstance(module, nn.Linear):
933
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
934
+ if module.bias is not None:
935
+ module.bias.data.zero_()
936
+ elif isinstance(module, nn.Embedding):
937
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
938
+ if module.padding_idx is not None:
939
+ module.weight.data[module.padding_idx].zero_()
940
+ elif isinstance(module, (LucaGPLM1LayerNorm, LucaGPLM1bLayerNorm)):
941
+ if hasattr(module, 'weight') and module.weight is not None:
942
+ module.weight.data.fill_(1.0)
943
+ if hasattr(module, 'bias') and module.bias is not None:
944
+ module.bias.data.zero_()
945
+
946
+ class LucaGPLMModel(LucaGPLMPreTrainedModel):
947
+ """
948
+ The LucaGPLM model for extracting sequence representations and optionally predicting contacts.
949
+ Based on the original LucaGPLM implementation but restructured to use modern transformers architecture.
950
+ """
951
+
952
+ def __init__(self, config: LucaGPLMConfig):
953
+ super().__init__(config)
954
+ self.config = config
955
+ self.embeddings = LucaGPLMEmbeddings(self.config)
956
+ self.encoder = LucaGPLMEncoder(self.config)
957
+ self.post_init()
958
+
959
+ def get_input_embeddings(self):
960
+ return self.embeddings.embed_tokens
961
+
962
+ def set_input_embeddings(self, value):
963
+ self.embeddings.embed_tokens = value
964
+
965
+ def forward(
966
+ self,
967
+ input_ids: Optional[torch.Tensor] = None,
968
+ attention_mask: Optional[torch.Tensor] = None,
969
+ token_type_ids: Optional[torch.Tensor] = None,
970
+ position_ids: Optional[torch.Tensor] = None,
971
+ inputs_embeds: Optional[torch.Tensor] = None,
972
+ output_attentions: Optional[bool] = None,
973
+ output_hidden_states: Optional[bool] = None,
974
+ return_contacts: Optional[bool] = None,
975
+ return_dict: Optional[bool] = None,
976
+ need_head_weights: Optional[bool] = None,
977
+ repr_layers: Optional[List[int]] = None,
978
+ use_last_layer_norm: Optional[bool] = True,
979
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
980
+
981
+ output_attentions = output_attentions if output_attentions is not None else getattr(self.config, 'output_attentions', False)
982
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, 'output_hidden_states', False)
983
+ return_contacts = return_contacts if return_contacts is not None else False
984
+ return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
985
+ need_head_weights = need_head_weights if need_head_weights is not None else return_contacts # Need attention weights for contacts
986
+ use_last_layer_norm = use_last_layer_norm if use_last_layer_norm is not None else True
987
+
988
+ # Force output_attentions=True when return_contacts=True since we need attention weights
989
+ if return_contacts:
990
+ output_attentions = True
991
+ need_head_weights = True
992
+
993
+ if input_ids is not None and inputs_embeds is not None:
994
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
995
+ elif input_ids is not None:
996
+ input_shape = input_ids.size()
997
+ elif inputs_embeds is not None:
998
+ input_shape = inputs_embeds.size()[:-1]
999
+ else:
1000
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1001
+
1002
+ batch_size, seq_length = input_shape
1003
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1004
+
1005
+ # Create attention mask if not provided
1006
+ if attention_mask is None:
1007
+ attention_mask = torch.ones(input_shape, device=device)
1008
+
1009
+ # Get embeddings
1010
+ if inputs_embeds is None:
1011
+ embedding_output = self.embeddings(
1012
+ input_ids=input_ids,
1013
+ position_ids=position_ids,
1014
+ token_type_ids=token_type_ids,
1015
+ )
1016
+ else:
1017
+ embedding_output = inputs_embeds
1018
+
1019
+ # Pass through encoder
1020
+ encoder_outputs = self.encoder(
1021
+ embedding_output,
1022
+ attention_mask=attention_mask,
1023
+ output_attentions=output_attentions,
1024
+ output_hidden_states=output_hidden_states,
1025
+ return_dict=return_dict,
1026
+ need_head_weights=need_head_weights,
1027
+ repr_layers=repr_layers,
1028
+ use_last_layer_norm=use_last_layer_norm,
1029
+ )
1030
+
1031
+ sequence_output = encoder_outputs[0]
1032
+
1033
+ # Handle contact prediction
1034
+ contacts = None
1035
+ if return_contacts and encoder_outputs.attentions is not None:
1036
+ # Simple contact prediction using attention weights
1037
+ # This is a simplified implementation - you can enhance this later
1038
+ attentions = encoder_outputs.attentions
1039
+ # Average over layers and heads, then symmetrize
1040
+ averaged_attention = attentions.mean(dim=(1, 2)) # Average over layers and heads
1041
+ contacts = (averaged_attention + averaged_attention.transpose(-1, -2)) / 2
1042
+
1043
+ # Remove special tokens (BOS/EOS) if present
1044
+ if attention_mask is not None:
1045
+ # Find actual sequence positions (non-padding)
1046
+ seq_lens = attention_mask.sum(dim=1)
1047
+ # For now, keep the full contact map - you can trim special tokens later if needed
1048
+
1049
+ if not return_dict:
1050
+ outputs = (sequence_output, ) + encoder_outputs[1:]
1051
+ if contacts is not None:
1052
+ outputs = outputs + (contacts,)
1053
+ return outputs
1054
+
1055
+ # Create output object with contacts
1056
+ output = BaseModelOutput(
1057
+ last_hidden_state=sequence_output,
1058
+ hidden_states=encoder_outputs.hidden_states,
1059
+ attentions=encoder_outputs.attentions,
1060
+ )
1061
+
1062
+ # Add contacts as an attribute if computed
1063
+ if contacts is not None:
1064
+ output.contacts = contacts
1065
+
1066
+ return output
1067
+
1068
+ class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
1069
+ def __init__(self, config):
1070
+ super().__init__(config)
1071
+ # Base encoder
1072
+ self.lucaone = LucaGPLMModel(config)
1073
+
1074
+ # MLM prediction head
1075
+ self.lm_head = LucaGPLMRobertaLMHead(
1076
+ embed_dim=config.hidden_size,
1077
+ output_dim=config.vocab_size
1078
+ )
1079
+ self._tied_weights_keys = [
1080
+ "lucaone.embeddings.embed_tokens.weight",
1081
+ "lm_head.decoder.weight"
1082
+ ]
1083
+ # Initialize weights and tie the decoder weights to the input embeddings
1084
+ self.post_init()
1085
+
1086
+ def get_input_embeddings(self):
1087
+ return self.lucaone.get_input_embeddings()
1088
+
1089
+ def get_output_embeddings(self):
1090
+ return self.lm_head.decoder
1091
+
1092
+ def set_output_embeddings(self, new_embeddings):
1093
+ self.lm_head.decoder = new_embeddings
1094
+
1095
+ def forward(
1096
+ self,
1097
+ input_ids: Optional[torch.Tensor] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ token_type_ids: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ labels: Optional[torch.Tensor] = None,  # labels for MLM training
1102
+ output_attentions: Optional[bool] = None,
1103
+ output_hidden_states: Optional[bool] = None,
1104
+ return_dict: Optional[bool] = None,
1105
+ ) -> Union[Tuple, MaskedLMOutput]:
1106
+
1107
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1108
+
1109
+ # 1. Get the base model output (hidden states)
1110
+ outputs = self.lucaone(
1111
+ input_ids,
1112
+ attention_mask=attention_mask,
1113
+ token_type_ids=token_type_ids,
1114
+ position_ids=position_ids,
1115
+ output_attentions=output_attentions,
1116
+ output_hidden_states=output_hidden_states,
1117
+ return_dict=return_dict,
1118
+ )
1119
+
1120
+ sequence_output = outputs[0] # (batch_size, seq_len, hidden_size)
1121
+
1122
+ # 2. Run the MLM head to get prediction logits
1123
+ prediction_scores = self.lm_head(sequence_output)
1124
+
1125
+ masked_lm_loss = None
1126
+ if labels is not None:
1127
+ # 3. Compute the MLM loss
1128
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # -100 is the default ignore_index
1129
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1130
+
1131
+ if not return_dict:
1132
+ output = (prediction_scores,) + outputs[2:]
1133
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1134
+
1135
+ return MaskedLMOutput(
1136
+ loss=masked_lm_loss,
1137
+ logits=prediction_scores,
1138
+ hidden_states=outputs.hidden_states,
1139
+ attentions=outputs.attentions,
1140
+ )
1141
+
1142
+ class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
1143
+ def __init__(self, config):
1144
+ super().__init__(config)
1145
+ self.num_labels = config.classifier_num_labels
1146
+ self.task_level = config.task_level
1147
+ self.task_type = config.task_type
1148
+ assert self.task_level == "seq_level"
1149
+ self.classifier_pooling_type = config.classifier_pooling_type
1150
+ self.classifier_loss_type = config.classifier_loss_type
1151
+ self.classifier_loss_reduction = config.classifier_loss_reduction
1152
+ self.classifier_pos_weight = config.classifier_pos_weight
1153
+ self.classifier_weight = config.classifier_weight
1154
+ self.lucaone = LucaGPLMModel(config)  # base encoder
1155
+ if self.classifier_pooling_type == "value_attention":
1156
+ self.pooler = LucaGPLMGlobalMaskValueAttentionPooling1D(config.hidden_size)
1157
+ elif self.classifier_pooling_type == "context_attention":
1158
+ self.pooler = LucaGPLMGlobalMaskContextAttentionPooling1D(embed_size=config.hidden_size)
1159
+ elif self.classifier_pooling_type == "weighted_attention":
1160
+ self.pooler = LucaGPLMGlobalMaskWeightedAttentionPooling1D(embed_size=config.hidden_size)
1161
+ else:
1162
+ self.pooler = None
1163
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1164
+
1165
+ self.classifier = nn.Linear(config.hidden_size, self.num_labels)  # use classifier_num_labels, consistent with the losses below
1166
+ if self.task_type == "multi_class":
1167
+ weight = None
1168
+ if self.classifier_weight:
1169
+ if isinstance(self.classifier_weight, str) or isinstance(self.classifier_weight, int):
1170
+ weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
1171
+ elif isinstance(self.classifier_weight, float):
1172
+ weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
1173
+ elif isinstance(self.classifier_weight, list):
1174
+ weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
1175
+ self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
1176
+ elif self.task_type == "binary_class":
1177
+ pos_weight = None
1178
+ if self.classifier_pos_weight:
1179
+ if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
1180
+ pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
1181
+ elif isinstance(self.classifier_pos_weight, float):
1182
+ pos_weight = torch.tensor([self.classifier_pos_weight], dtype=torch.float32)
1183
+ self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
1184
+ elif self.task_type == "regression":
1185
+ if self.classifier_loss_type == "mae":
1186
+ self.loss_fct = nn.L1Loss(reduction="mean")
1187
+ else:
1188
+ self.loss_fct = nn.MSELoss(reduction="mean")
1189
+ elif self.task_type == "multi_label":
1190
+ pos_weight = None
1191
+ if self.classifier_pos_weight:
1192
+ if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
1193
+ pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
1194
+ elif isinstance(self.classifier_pos_weight, float):
1195
+ pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
1196
+ self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
1197
+ else:
1198
+ raise ValueError("Invalid task type: %s" % self.task_type)
1199
+ self.post_init()
1200
+
1201
+ def forward(
1202
+ self,
1203
+ input_ids=None,
1204
+ token_type_ids=None,
1205
+ attention_mask=None,
1206
+ labels=None,
1207
+ return_dict=None
1208
+ ):
1209
+ return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
1210
+ outputs = self.lucaone(
1211
+ input_ids,
1212
+ token_type_ids=token_type_ids,
1213
+ attention_mask=attention_mask,
1214
+ return_dict=return_dict
1215
+ )
1216
+ if self.pooler is not None:
1217
+ pooled_output = self.pooler(outputs[0])
1218
+ elif self.classifier_pooling_type == "cls":
1219
+ # Take the CLS token
1220
+ pooled_output = outputs[0][:, 0, :]
1221
+ elif self.classifier_pooling_type == "mean":
1222
+ pooled_output = outputs[0].mean(dim=1)
1223
+ else:
1224
+ raise ValueError("Invalid classifier pooling type: %s" % self.classifier_pooling_type)
1225
+
1226
+ pooled_output = self.dropout(pooled_output)
1227
+ logits = self.classifier(pooled_output)
1228
+
1229
+ loss = None
1230
+ if labels is not None:
1231
+ if self.task_type == "multi_class":
1232
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1233
+ elif self.task_type == "binary_class":
1234
+ loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
1235
+ elif self.task_type == "regression":
1236
+ loss = self.loss_fct(logits.view(-1), labels.view(-1))
1237
+ elif self.task_type == "multi_label":
1238
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
1239
+ else:
1240
+ raise ValueError("Invalid task type: %s" % self.task_type)
1241
+
1242
+ if not return_dict:
1243
+ output = (logits,) + outputs[1:]
1244
+ return ((loss,) + output) if loss is not None else output
1245
+
1246
+ return SequenceClassifierOutput(loss=loss, logits=logits)
1247
+
1248
+ class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
1249
+ def __init__(self, config):
1250
+ super().__init__(config)
1251
+ self.num_labels = config.classifier_num_labels
1252
+ self.task_level = config.task_level
1253
+ self.task_type = config.task_type
1254
+ assert self.task_level == "token_level"
1255
+ self.classifier_pooling_type = config.classifier_pooling_type
1256
+ self.classifier_loss_type = config.classifier_loss_type
1257
+ self.classifier_loss_reduction = config.classifier_loss_reduction
1258
+ self.classifier_pos_weight = config.classifier_pos_weight
1259
+ self.classifier_weight = config.classifier_weight
1260
+ self.lucaone = LucaGPLMModel(config)  # base encoder
1261
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1262
+ self.classifier = nn.Linear(config.hidden_size, self.num_labels)  # use classifier_num_labels, consistent with the losses below
1263
+ if self.task_type == "multi_class":
1264
+ weight = None
1265
+ if self.classifier_weight:
1266
+ # [1, 1, 1, 1, ...] with length num_labels
1267
+ if isinstance(self.classifier_weight, str) or isinstance(self.classifier_weight, int):
1268
+ weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
1269
+ elif isinstance(self.classifier_weight, float):
1270
+ weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
1271
+ elif isinstance(self.classifier_weight, list):
1272
+ weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
1273
+ self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
1274
+ elif self.task_type == "binary_class":
1275
+ pos_weight = None
1276
+ if self.classifier_pos_weight:
1277
+ if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
1278
+ pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
1279
+ elif isinstance(self.classifier_pos_weight, float):
1280
+ pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
1281
+ self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
1282
+ elif self.task_type == "regression":
1283
+ if self.classifier_loss_type == "mae":
1284
+ self.loss_fct = nn.L1Loss(reduction="mean")
1285
+ else:
1286
+ self.loss_fct = nn.MSELoss(reduction="mean")
1287
+ elif self.task_type == "multi_label":
1288
+ pos_weight = None
1289
+ if self.classifier_pos_weight:
1290
+ if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
1291
+ pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
1292
+ elif isinstance(self.classifier_pos_weight, float):
1293
+ pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
1294
+ self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
1295
+ else:
1296
+ raise ValueError("Invalid task type: %s" % self.task_type)
1297
+ self.post_init()
1298
+
1299
+ def forward(
1300
+ self,
1301
+ input_ids=None,
1302
+ token_type_ids=None,
1303
+ attention_mask=None,
1304
+ labels=None,
1305
+ return_dict=None
1306
+ ):
1307
+ return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
1308
+ outputs = self.lucaone(
1309
+ input_ids,
1310
+ token_type_ids=token_type_ids,
1311
+ attention_mask=attention_mask,
1312
+ return_dict=return_dict
1313
+ )
1314
+ sequence_output = outputs[0][:, 1:-1, :]  # drop CLS/SEP tokens -> (B, L-2, H)
1315
+
1316
+ sequence_output = self.dropout(sequence_output)
1317
+ logits = self.classifier(sequence_output)
1318
+
1319
+ loss = None
1320
+ if labels is not None:
1321
+ if self.task_type == "multi_class":
1322
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1323
+ elif self.task_type == "binary_class":
1324
+ loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
1325
+ elif self.task_type == "regression":
1326
+ loss = self.loss_fct(logits.view(-1), labels.view(-1))
1327
+ elif self.task_type == "multi_label":
1328
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
1329
+ else:
1330
+ raise ValueError("Invalid task type: %s" % self.task_type)
1331
+
1332
+
1333
+ if not return_dict:
1334
+ output = (logits,) + outputs[1:]
1335
+ return ((loss,) + output) if loss is not None else output
1336
+ return TokenClassifierOutput(loss=loss, logits=logits)
1337
+
1338
+ __all__ = [
1339
+ "LucaGPLMModel",
1340
+ "LucaGPLMPreTrainedModel",
1341
+ "LucaGPLMForMaskedLM",
1342
+ "LucaGPLMForSequenceClassification",
1343
+ "LucaGPLMForTokenClassification"
1344
+ ]
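
The classification heads above read a number of task-specific fields from the config (classifier_num_labels, task_level, task_type, classifier_pooling_type, classifier_loss_type, classifier_loss_reduction, classifier_pos_weight, classifier_weight, classifier_dropout_prob). Below is a minimal fine-tuning sketch, assuming a local checkpoint path (placeholder) and illustrative attribute values rather than shipped defaults:

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification

checkpoint = "path/to/lucaone-checkpoint"  # placeholder path

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
# Fields consumed by LucaGPLMForSequenceClassification.__init__ (values are examples).
config.classifier_num_labels = 3
config.task_level = "seq_level"            # asserted in __init__
config.task_type = "multi_class"           # multi_class / binary_class / regression / multi_label
config.classifier_pooling_type = "mean"    # or cls / value_attention / context_attention / weighted_attention
config.classifier_loss_type = "ce"
config.classifier_loss_reduction = "mean"
config.classifier_pos_weight = None
config.classifier_weight = None
config.classifier_dropout_prob = 0.1

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, config=config, trust_remote_code=True
)

# Toy forward pass with random standard-token ids (5..38) and dummy labels.
inputs = {
    "input_ids": torch.randint(5, 39, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
}
outputs = model(**inputs, labels=torch.tensor([0, 2]))
print(outputs.loss, outputs.logits.shape)  # logits: (2, 3)
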
tokenization_lucaone.py ADDED
@@ -0,0 +1,432 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2025, Hey.
5
+ @author: Hey
6
+ @email: [email protected]
7
+ @tel: 137****6540
8
+ @datetime: 2025/12/30 11:33
9
+ @project: lucaone
10
+ @file: tokenization_lucaone
11
+ @desc: tokenization_lucaone
12
+ '''
13
+
14
+ import os
15
+ import json
16
+ import itertools
17
+ from typing import List, Optional, Dict, Any, Tuple, Union
18
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
19
+
20
+ def gene_seq_replace(seq):
21
+ """
22
+ Gene sequence preprocessing: A->1, U/T->2, C->3, G->4, N->5
23
+ Optimized for performance.
24
+ """
25
+ # A dict lookup is faster than chained if/else checks
26
+ mapping = {
27
+ 'A': '1', 'a': '1',
28
+ 'T': '2', 't': '2', 'U': '2', 'u': '2',
29
+ 'C': '3', 'c': '3',
30
+ 'G': '4', 'g': '4'
31
+ }
32
+ # Characters not in the mapping (e.g. N) default to '5'
33
+ return "".join([mapping.get(ch, '5') for ch in seq])
34
+
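+ # Illustrative note (added, not part of the original source): applying the mapping above,
+ # gene_seq_replace("ATGCN") == "12435"   # A->1, T->2, G->4, C->3, N (unmapped)->5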
35
+ class LucaGPLMTokenizer(PreTrainedTokenizer):
36
+ """
37
+ HuggingFace-compatible tokenizer that performs identical tokenization
38
+ to the old model's Alphabet class.
39
+ """
40
+
41
+ # Vocabulary definitions matching the old model
42
+ gene_prepend_toks = ['[PAD]', '[UNK]']
43
+ gene_append_toks = ['[CLS]', '[SEP]', '[MASK]']
44
+ gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
45
+
46
+ prot_prepend_toks = ['[PAD]', '[UNK]']
47
+ prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
48
+ prot_standard_toks = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
49
+
50
+ gene_prot_prepend_toks = ['[PAD]', '[UNK]']
51
+ gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
52
+ # EXACT VOCABULARY ORDER FROM ORIGINAL ALPHABET CLASS
53
+
54
+ gene_prot_standard_toks = [
55
+ '1', # 5 - gene A (after gene_seq_replace)
56
+ '2', # 6 - gene T/U (after gene_seq_replace)
57
+ '3', # 7 - gene C (after gene_seq_replace)
58
+ '4', # 8 - gene G (after gene_seq_replace)
59
+ '5', # 9 - gene N/unknown
60
+ 'L', # 10 - protein
61
+ 'A', # 11 - protein
62
+ 'G', # 12 - protein
63
+ 'V', # 13 - protein
64
+ 'S', # 14 - protein
65
+ 'E', # 15 - protein
66
+ 'R', # 16 - protein
67
+ 'T', # 17 - protein
68
+ 'I', # 18 - protein
69
+ 'D', # 19 - protein
70
+ 'P', # 20 - protein
71
+ 'K', # 21 - protein
72
+ 'Q', # 22 - protein
73
+ 'N', # 23 - protein
74
+ 'F', # 24 - protein
75
+ 'Y', # 25 - protein
76
+ 'M', # 26 - protein
77
+ 'H', # 27 - protein
78
+ 'W', # 28 - protein
79
+ 'C', # 29 - protein
80
+ 'X', # 30 - protein unknown
81
+ 'B', # 31 - protein
82
+ 'U', # 32 - protein
83
+ 'Z', # 33 - protein
84
+ 'O', # 34 - protein
85
+ 'J', # 35 - protein
86
+ '.', # 36 - special
87
+ '-', # 37 - special
88
+ '*' # 38 - special
89
+ ]
90
+
91
+ def __init__(
92
+ self,
93
+ vocab_type: str = "gene_prot",
94
+ prepend_bos: bool = True,
95
+ append_eos: bool = True,
96
+ unk_token="[UNK]",
97
+ pad_token="[PAD]",
98
+ cls_token="[CLS]",
99
+ sep_token="[SEP]",
100
+ mask_token="[MASK]",
101
+ **kwargs
102
+ ):
103
+ # Set vocabulary based on type
104
+ if vocab_type.lower() == "prot":
105
+ prepend_toks = self.prot_prepend_toks
106
+ append_toks = self.prot_append_toks
107
+ standard_toks = self.prot_standard_toks
108
+ elif vocab_type.lower() == "gene":
109
+ prepend_toks = self.gene_prepend_toks
110
+ append_toks = self.gene_append_toks
111
+ standard_toks = self.gene_standard_toks
112
+ elif vocab_type.lower() in ["gene_prot", "prot_gene"]:
113
+ prepend_toks = self.gene_prot_prepend_toks
114
+ append_toks = self.gene_prot_append_toks
115
+ standard_toks = self.gene_prot_standard_toks
116
+ else:
117
+ raise ValueError(f"Not support tokenizer vocab_type: {vocab_type}")
118
+
119
+ # Build vocabulary
120
+ self.all_toks = list(prepend_toks) + list(append_toks) + list(standard_toks)
121
+ self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
122
+ self.idx_to_tok = {i: tok for i, tok in enumerate(self.all_toks)}
123
+
124
+ # Store configuration
125
+ self.vocab_type = vocab_type
126
+ self.prepend_bos = prepend_bos
127
+ self.append_eos = append_eos
128
+ self.unique_no_split_tokens = self.all_toks.copy()
129
+
130
+ # Special token indices
131
+ self.unk_idx = self.tok_to_idx.get("[UNK]", 1)
132
+ self.padding_idx = self.tok_to_idx.get("[PAD]", 0)
133
+ self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
134
+ self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
135
+ self.eos_idx = self.tok_to_idx.get("[SEP]", 3)
136
+
137
+ super().__init__(
138
+ unk_token=unk_token,
139
+ pad_token=pad_token,
140
+ cls_token=cls_token,
141
+ sep_token=sep_token,
142
+ mask_token=mask_token,
143
+ **kwargs
144
+ )
145
+
146
+ def get_vocab(self) -> Dict[str, int]:
147
+ return self.tok_to_idx.copy()
148
+
149
+ @property
150
+ def vocab_size(self) -> int:
151
+ return len(self.all_toks)
152
+
153
+ def get_idx(self, tok):
154
+ return self.tok_to_idx.get(tok, self.unk_idx)
155
+
156
+ def get_tok(self, idx):
157
+ return self.idx_to_tok.get(idx, "[UNK]")
158
+
159
+ def _tokenize_char_level(self, text: str) -> List[str]:
160
+ """Simple character-level tokenization (fallback)"""
161
+ return list(text)
162
+
163
+ def _tokenize(self, text: str) -> List[str]:
164
+ """
165
+ Tokenize text using the same logic as the old Alphabet.tokenize() method
166
+ """
167
+ text = text.strip()
168
+ if not text:
169
+ return []
170
+
171
+ return list(text)
172
+
173
+ def _convert_token_to_id(self, token: str) -> int:
174
+ return self.get_idx(token)
175
+
176
+ def _convert_id_to_token(self, index: int) -> str:
177
+ return self.get_tok(index)
178
+
179
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
180
+ return "".join(tokens)
181
+
182
+ def _convert_text_to_ids(self, text: str, seq_type: str) -> List[int]:
183
+ """Internal helper to convert text to IDs without special tokens."""
184
+ if seq_type == "gene":
185
+ text = gene_seq_replace(text)
186
+ tokens = self._tokenize(text)
187
+ return [self._convert_token_to_id(token) for token in tokens]
188
+
189
+ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
190
+ """
191
+ Build model inputs from a sequence by adding special tokens.
192
+ This mimics the old model's prepend_bos and append_eos behavior.
193
+ """
194
+ result = token_ids_0.copy()
195
+
196
+ if self.prepend_bos:
197
+ result = [self.cls_idx] + result
198
+ if self.append_eos:
199
+ result = result + [self.eos_idx]
200
+
201
+ return result
202
+
203
+ def get_special_tokens_mask(
204
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
205
+ ) -> List[int]:
206
+ """
207
+ Retrieve sequence ids from a token list.
208
+ """
209
+ if already_has_special_tokens:
210
+ return super().get_special_tokens_mask(
211
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
212
+ )
213
+
214
+ result = [0] * len(token_ids_0)
215
+ if self.prepend_bos:
216
+ result = [1] + result
217
+ if self.append_eos:
218
+ result = result + [1]
219
+ return result
220
+
221
+ def encode(
222
+ self,
223
+ text: str,
224
+ seq_type: str = "gene",
225
+ add_special_tokens: bool = True,
226
+ padding: Union[bool, str] = False,  # encode itself does not pad; kept for API compatibility
227
+ truncation: bool = False,  # key parameter
228
+ max_length: Optional[int] = None,  # key parameter
229
+ **kwargs
230
+ ) -> List[int]:
231
+
232
+ # 1. Basic conversion to token ids
233
+ token_ids = self._convert_text_to_ids(text, seq_type)
234
+
235
+ # 2. Add special tokens
236
+ if add_special_tokens:
237
+ token_ids = self.build_inputs_with_special_tokens(token_ids)
238
+
239
+ # 3. Apply truncation (fix: this logic was previously missing)
240
+ if truncation and max_length is not None and len(token_ids) > max_length:
241
+ token_ids = token_ids[:max_length]
242
+ # If append_eos is enabled, force the last position after truncation back to SEP
243
+ if add_special_tokens and self.append_eos:
244
+ token_ids[-1] = self.eos_idx
245
+
246
+ return token_ids
247
+
248
+ def __call__(
249
+ self,
250
+ text: Union[str, List[str]],
251
+ text_pair: Optional[Union[str, List[str]]] = None,
252
+ seq_type: str = "gene",
253
+ add_special_tokens: bool = True,
254
+ padding: Union[bool, str] = False,
255
+ max_length: Optional[int] = None,
256
+ return_attention_mask: bool = True,
257
+ return_token_type_ids: bool = True,
258
+ return_tensors: Optional[str] = None,
259
+ truncation: bool = False,
260
+ **kwargs
261
+ ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
262
+ """
263
+ Main callable method for tokenization - HuggingFace standard interface
264
+ """
265
+ if isinstance(text, list):
266
+ # Handle batch processing
267
+ return self.batch_encode_plus(
268
+ text,
269
+ text_pair=text_pair,
270
+ seq_type=seq_type,
271
+ add_special_tokens=add_special_tokens,
272
+ padding=padding,
273
+ max_length=max_length,
274
+ return_attention_mask=return_attention_mask,
275
+ return_token_type_ids=return_token_type_ids,
276
+ return_tensors=return_tensors,
277
+ truncation=truncation,
278
+ **kwargs
279
+ )
280
+ else:
281
+ # Handle single text
282
+ return self.encode_plus(
283
+ text,
284
+ text_pair=text_pair,
285
+ seq_type=seq_type,
286
+ add_special_tokens=add_special_tokens,
287
+ padding=padding,
288
+ max_length=max_length,
289
+ return_attention_mask=return_attention_mask,
290
+ return_token_type_ids=return_token_type_ids,
291
+ return_tensors=return_tensors,
292
+ truncation=truncation,
293
+ **kwargs
294
+ )
295
+
296
+ def batch_encode_plus(self, *args, **kwargs):
297
+ # Delegate to the parent class; it internally calls the fixed encode_plus
298
+ return super().batch_encode_plus(*args, **kwargs)
299
+
300
+ def encode_plus(
301
+ self,
302
+ text: str,
303
+ text_pair: Optional[str] = None,
304
+ seq_type: str = "gene",
305
+ add_special_tokens: bool = True,
306
+ padding: Union[bool, str] = False,
307
+ max_length: Optional[int] = None,
308
+ return_attention_mask: bool = True,
309
+ return_token_type_ids: bool = True,
310
+ return_tensors: Optional[str] = None,
311
+ truncation: bool = False,
312
+ **kwargs
313
+ ) -> Dict[str, Any]:
314
+
315
+ # Call the fixed encode, which now handles truncation correctly
316
+ token_ids = self.encode(
317
+ text,
318
+ seq_type=seq_type,
319
+ add_special_tokens=add_special_tokens,
320
+ truncation=truncation,
321
+ max_length=max_length
322
+ )
323
+
324
+ # Handle padding
325
+ attention_mask = [1] * len(token_ids)
326
+ if padding == "max_length" and max_length is not None:
327
+ if len(token_ids) < max_length:
328
+ pad_length = max_length - len(token_ids)
329
+ token_ids.extend([self.padding_idx] * pad_length)
330
+ attention_mask.extend([0] * pad_length)
331
+ # Note: padding=True (dynamic padding) is normally handled at batch level, not for a single sequence here
332
+
333
+ result = {"input_ids": token_ids}
334
+
335
+ if return_attention_mask:
336
+ result["attention_mask"] = attention_mask
337
+
338
+ if return_token_type_ids:
339
+ # 0 for gene, 1 for protein
340
+ type_value = 0 if seq_type == "gene" else 1
341
+ result["token_type_ids"] = [type_value] * len(token_ids)
342
+
343
+ if return_tensors == "pt":
344
+ import torch
345
+ for key, value in result.items():
346
+ result[key] = torch.tensor(value, dtype=torch.long).unsqueeze(0)
347
+
348
+ return result
349
+
350
+ def encode_old_model_style(
351
+ self,
352
+ text: str,
353
+ seq_type: str = "gene",
354
+ max_length: int = None
355
+ ) -> List[int]:
356
+ """
357
+ Encode using the EXACT same process as the old model's encoder function.
358
+ This replicates the logic from src/llm/lucaone_virus/get_embedding.py:encoder()
359
+ """
360
+ # Preprocess gene sequences (done in get_embedding function BEFORE calling encoder)
361
+ if seq_type == "gene":
362
+ text = gene_seq_replace(text)
363
+
364
+ # Call tokenizer.encode (which does NOT include BOS/EOS in old model)
365
+ seq_encoded = self.encode(text, seq_type=seq_type, add_special_tokens=False)
366
+
367
+ # Apply max_length truncation if specified
368
+ if max_length and len(seq_encoded) > max_length:
369
+ seq_encoded = seq_encoded[:max_length]
370
+
371
+ # Calculate processed_seq_len (as done in old model)
372
+ processed_seq_len = len(seq_encoded) + int(self.prepend_bos) + int(self.append_eos)
373
+
374
+ # Create input_ids tensor (as done in old model encoder function)
375
+ input_ids = [self.padding_idx] * processed_seq_len
376
+
377
+ # Add BOS token if enabled
378
+ if self.prepend_bos:
379
+ input_ids[0] = self.cls_idx
380
+
381
+ # Place the encoded sequence
382
+ start_idx = int(self.prepend_bos)
383
+ for i, token_id in enumerate(seq_encoded):
384
+ input_ids[start_idx + i] = token_id
385
+
386
+ # Add EOS token if enabled
387
+ if self.append_eos:
388
+ input_ids[len(seq_encoded) + int(self.prepend_bos)] = self.eos_idx
389
+
390
+ return input_ids
391
+
392
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
393
+ """
394
+ Save the tokenizer vocabulary to a JSON file.
395
+ Required by HuggingFace tokenizer interface.
396
+ """
397
+ if filename_prefix is None:
398
+ filename_prefix = ""
399
+ else:
400
+ filename_prefix = filename_prefix + "-"
401
+
402
+ vocab_file = os.path.join(save_directory, f"{filename_prefix}vocab.json")
403
+ vocab_dict = self.get_vocab()
404
+ with open(vocab_file, "w", encoding="utf-8") as f:
405
+ json.dump(vocab_dict, f, ensure_ascii=False, indent=2)
406
+
407
+ return (vocab_file,)
408
+
409
+ @classmethod
410
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
411
+ """
412
+ Load tokenizer from pretrained model path (standard HuggingFace interface)
413
+ """
414
+ vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
415
+ if os.path.exists(vocab_file):
416
+ print("Load from saved vocabulary (not implemented yet, use default)")
417
+ return cls(vocab_type="gene_prot", **kwargs)
418
+ else:
419
+ return cls(vocab_type="gene_prot", **kwargs)
420
+
421
+ class LucaGPLMTokenizerFast(PreTrainedTokenizerFast):
422
+ """
423
+ Fast tokenizer version - currently just delegates to slow tokenizer
424
+ """
425
+ slow_tokenizer_class = LucaGPLMTokenizer
426
+
427
+ def __init__(self, **kwargs):
428
+ # For now, this is just a placeholder
429
+ # In a full implementation, you would use the tokenizers library
430
+ super().__init__(**kwargs)
431
+
432
+ __all__ = ["LucaGPLMTokenizer", "LucaGPLMTokenizerFast", "gene_seq_replace"]
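
A short usage sketch for the tokenizer above, assuming the file is importable directly as a module (it can equally be loaded via AutoTokenizer with trust_remote_code); the sequences are toy examples:

from tokenization_lucaone import LucaGPLMTokenizer

tok = LucaGPLMTokenizer(vocab_type="gene_prot")

# Gene input: characters are first mapped A->1, T/U->2, C->3, G->4, other->5,
# then [CLS]/[SEP] are added because prepend_bos/append_eos default to True.
gene = tok("ATGCN", seq_type="gene", return_tensors="pt")
print(gene["input_ids"])         # shape (1, 7): CLS + 5 tokens + SEP
print(gene["token_type_ids"])    # all 0 for gene input

# Protein input: character-level tokenization, token_type_ids are all 1,
# padded to a fixed length here.
prot = tok("MKTLLV", seq_type="prot", return_tensors="pt",
           padding="max_length", max_length=16, truncation=True)
print(prot["input_ids"].shape, int(prot["attention_mask"].sum()))  # (1, 16), 8
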
tokenizer_config.json CHANGED
@@ -41,12 +41,18 @@
41
  "special": true
42
  }
43
  },
44
- "clean_up_tokenization_spaces": false,
45
  "cls_token": "[CLS]",
46
  "mask_token": "[MASK]",
47
  "model_max_length": 1000000000000000019884624838656,
48
  "pad_token": "[PAD]",
49
  "sep_token": "[SEP]",
50
  "tokenizer_class": "LucaGPLMTokenizer",
51
- "unk_token": "[UNK]"
52
- }
 
 
 
 
 
 
 
41
  "special": true
42
  }
43
  },
44
+ "clean_up_tokenization_spaces": true,
45
  "cls_token": "[CLS]",
46
  "mask_token": "[MASK]",
47
  "model_max_length": 1000000000000000019884624838656,
48
  "pad_token": "[PAD]",
49
  "sep_token": "[SEP]",
50
  "tokenizer_class": "LucaGPLMTokenizer",
51
+ "unk_token": "[UNK]",
52
+ "auto_map": {
53
+ "AutoTokenizer": [
54
+ "tokenization_lucaone.LucaGPLMTokenizer",
55
+ null
56
+ ]
57
+ }
58
+ }
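
With the auto_map registration above, the checkpoint can be loaded through the Auto* classes by passing trust_remote_code=True. A minimal end-to-end sketch; the repository id is a placeholder:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo = "path/or/hub-id/of/this/checkpoint"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True)
model.eval()

# seq_type="gene" triggers the A/T/C/G -> 1/2/3/4 remapping and token_type_ids of 0;
# use seq_type="prot" for protein sequences (token_type_ids of 1).
inputs = tokenizer("ATGCATGC", seq_type="gene", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
print(out.logits.shape)  # (1, sequence length + 2 special tokens, vocab_size=39)
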