fix: return the out tensor rather then the functions return value (#2361)
This commit is contained in:
parent
dd47a3dac4
commit
29b8d19cdf
|
@ -292,8 +292,7 @@ else:
|
|||
)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
return flash_attn_cuda.fwd(
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
@ -309,4 +308,5 @@ else:
|
|||
False,
|
||||
0,
|
||||
None,
|
||||
)[0]
|
||||
)
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue