howard-hou commited on
Commit
586d6a1
1 Parent(s): 1d2fc64

Update modeling_rwkv.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv.py +1 -0
modeling_rwkv.py CHANGED
@@ -1043,6 +1043,7 @@ class RWKV(MyModule):
1043
  elif embs is not None and tokens is not None:
1044
  seq_mode = len(tokens) > 1
1045
  x = w['emb.weight'][tokens if seq_mode else tokens[0]]
 
1046
  x = torch.cat([x, embs], dim=0)
1047
  else:
1048
  raise ValueError('Either tokens or embs must be provided')
 
1043
  elif embs is not None and tokens is not None:
1044
  seq_mode = len(tokens) > 1
1045
  x = w['emb.weight'][tokens if seq_mode else tokens[0]]
1046
+ x = x.to(device=embs.device, dtype=embs.dtype, non_blocking=True)
1047
  x = torch.cat([x, embs], dim=0)
1048
  else:
1049
  raise ValueError('Either tokens or embs must be provided')