Removed stray cuda call

#2
by justbruno - opened
Files changed (1) hide show
  1. modeling_aria.py +1 -1
modeling_aria.py CHANGED
@@ -620,7 +620,7 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
620
  def precompute_causal_mask(max_seq_len: int):
621
  return torch.tril(
622
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
623
- ).cuda()
624
 
625
 
626
  def precompute_freqs_cis(
 
620
  def precompute_causal_mask(max_seq_len: int):
621
  return torch.tril(
622
  torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
623
+ )
624
 
625
 
626
  def precompute_freqs_cis(