Alternative attentions: ReBased linear flashattn and LWM's RingAttention #114

Closed · wants to merge 1 commit
44 changes: 44 additions & 0 deletions opensora/models/layers/blocks.py
@@ -109,6 +109,17 @@ def forward(self, x):
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x

try:
    # requires https://github.com/corl-team/rebased to be installed
    from fla.ops.triton.rebased_fast import parallel_rebased
    REBASED_IS_AVAILABLE = True
except ImportError:
    REBASED_IS_AVAILABLE = False

try:
    # requires https://github.com/lucidrains/ring-attention-pytorch to be installed
    from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda
    RING_ATTENTION_IS_AVAILABLE = True
except ImportError:
    RING_ATTENTION_IS_AVAILABLE = False

class Attention(nn.Module):
def __init__(
@@ -121,6 +132,9 @@ def __init__(
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
enable_flashlinearattn: bool = False,
enable_ringattn: bool = False,
eps: float = 1e-12, causal: bool = True, ring_bucket_size: int = 1024,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
@@ -129,6 +143,8 @@ def __init__(
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn = enable_flashattn
self.enable_flashlinearattn = enable_flashlinearattn
self.enable_ringattn = enable_ringattn

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -137,6 +153,10 @@ def __init__(
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.eps = eps
self.causal = causal
self.ring_bucket_size = ring_bucket_size

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
@@ -158,6 +178,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
elif self.enable_flashlinearattn:
    if not REBASED_IS_AVAILABLE:
        raise ImportError("Flash linear attention is selected, but ReBased is not installed!")
    # ReBased linear attention kernel from https://github.com/corl-team/rebased
    x = parallel_rebased(q, k, v, self.eps, True, True)
elif self.enable_ringattn:
    if not RING_ATTENTION_IS_AVAILABLE:
        raise ImportError("Ring attention is selected, but ring-attention-pytorch is not installed!")
    # LWM-style RingAttention kernel from https://github.com/lucidrains/ring-attention-pytorch
    x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size)
else:
dtype = q.dtype
q = q * self.scale
@@ -188,6 +217,9 @@ def __init__(
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
enable_flashlinearattn: bool = False,
enable_ringattn: bool = False,
eps: float = 1e-12, causal: bool = True, ring_bucket_size: int = 1024,
) -> None:
super().__init__(
dim=dim,
@@ -198,6 +230,9 @@ def __init__(
proj_drop=proj_drop,
norm_layer=norm_layer,
enable_flashattn=enable_flashattn,
enable_flashlinearattn=enable_flashlinearattn,
enable_ringattn=enable_ringattn,
eps=eps, causal=causal, ring_bucket_size=ring_bucket_size,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -231,6 +266,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dropout_p=self.attn_drop.p if self.training else 0.0,
softmax_scale=self.scale,
)
elif self.enable_flashlinearattn:
    if not REBASED_IS_AVAILABLE:
        raise ImportError("Flash linear attention is selected, but ReBased is not installed!")
    # ReBased linear attention kernel from https://github.com/corl-team/rebased
    x = parallel_rebased(q, k, v, self.eps, True, True)
elif self.enable_ringattn:
    if not RING_ATTENTION_IS_AVAILABLE:
        raise ImportError("Ring attention is selected, but ring-attention-pytorch is not installed!")
    # LWM-style RingAttention kernel from https://github.com/lucidrains/ring-attention-pytorch
    x = ring_flash_attn_cuda(q, k, v, causal=self.causal, bucket_size=self.ring_bucket_size)
else:
dtype = q.dtype
q = q * self.scale
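For context, a minimal sketch of how the new switches would be enabled on the patched `Attention` block. This is an untested illustration, not part of the PR: `dim` and `num_heads` are placeholder sizes rather than values from an Open-Sora config, and actually running the new branches in `forward()` additionally requires the optional packages (and, for ring attention, a suitable multi-GPU setup).

```python
import torch
from opensora.models.layers.blocks import Attention

# Illustrative sizes only; the real values come from the model config.
dim, num_heads = 1152, 16

# ReBased linear attention: forward() dispatches to parallel_rebased.
attn_linear = Attention(dim, num_heads=num_heads, enable_flashlinearattn=True, eps=1e-12)

# LWM-style RingAttention: forward() dispatches to ring_flash_attn_cuda.
attn_ring = Attention(dim, num_heads=num_heads, enable_ringattn=True,
                      causal=True, ring_bucket_size=1024)

x = torch.randn(2, 64, dim)      # (B, N, C) token layout used by the block
print(attn_linear.qkv(x).shape)  # the shared qkv projection is unaffected by the new flags
```

Constructing the modules works even when the optional packages are missing, because the imports are wrapped in try/except and only checked when the corresponding branch is taken in `forward()`.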