Removed stray cuda call
#2
by
justbruno
- opened
- modeling_aria.py +1 -1
modeling_aria.py
CHANGED
|
@@ -620,7 +620,7 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
|
|
| 620 |
def precompute_causal_mask(max_seq_len: int):
|
| 621 |
return torch.tril(
|
| 622 |
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
|
| 623 |
-
)
|
| 624 |
|
| 625 |
|
| 626 |
def precompute_freqs_cis(
|
|
|
|
| 620 |
def precompute_causal_mask(max_seq_len: int):
|
| 621 |
return torch.tril(
|
| 622 |
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
|
| 623 |
+
)
|
| 624 |
|
| 625 |
|
| 626 |
def precompute_freqs_cis(
|