Arch/mixtral #169

Open · wants to merge 17 commits into main
Empty file added tests/__init__.py
Empty file.
74 changes: 74 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,74 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
import os
import random
from contextlib import contextmanager

import numpy as np
import pytest
import torch
import torch._dynamo as dynamo


@contextmanager
def set_seed(seed: int = 0):
    """Seed all RNGs and force deterministic algorithms for the duration of the block."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
    yield

@pytest.fixture(autouse=True)
def reset_dynamo_state():
    """Reset torch._dynamo and free CUDA memory around every test."""
    cache_limit = dynamo.config.cache_size_limit
    try:
        dynamo.config.cache_size_limit = 512
        dynamo.reset()
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()
        yield {}
    finally:
        dynamo.config.cache_size_limit = cache_limit
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()


def assert_all_close(a: torch.Tensor, b: torch.Tensor, rtol=0, atol=1e-1) -> None:
    """
    Check that all elements of tensors a and b are within the provided thresholds.
    """
    assert a.shape == b.shape, f"Shapes don't match: {a.shape} != {b.shape}"
    assert a.dtype == b.dtype, f"Dtypes don't match: {a.dtype} != {b.dtype}"
    assert a.device == b.device, f"Devices don't match: {a.device} != {b.device}"
    max_abs_diff = torch.max(torch.abs(a - b))
    max_rel_diff = torch.max(torch.abs(a - b) / torch.abs(b))
    mismatch_elements = torch.sum(torch.abs(a - b) > atol + rtol * torch.abs(b))
    nb_elements = torch.numel(a)
    msg = (
        f"Differences: "
        f"{max_abs_diff:.3f} (max abs), "
        f"{max_rel_diff:.3f} (max rel), "
        f"{mismatch_elements}/{nb_elements} (mismatched elements)"
    )
    assert torch.allclose(a, b, rtol=rtol, atol=atol), msg
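Note for reviewers: because set_seed is built with contextlib.contextmanager, the object returned by set_seed() works both as a context manager and as a decorator, which is how the kernel tests below apply it. A minimal usage sketch of these helpers (not part of the diff):

def test_example():
    # Seed NumPy/PyTorch and force deterministic algorithms for this block.
    with set_seed(0):
        a = torch.randn(128, 128, device="cuda")
        b = a.clone()
    # Compares shape, dtype, device and values, with a readable failure message.
    assert_all_close(a, b, rtol=1e-05, atol=1e-08)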
Empty file added tests/kernels/__init__.py
Empty file.
41 changes: 41 additions & 0 deletions tests/kernels/test_gelu.py
@@ -0,0 +1,41 @@
import pytest
import torch
import triton

from unsloth.kernels.gelu import gelu_forward_kenel, gelu_backward_kernel
from tests.conftest import set_seed, assert_all_close

@set_seed()
@pytest.fixture(params=[(100, 100), (1024, 1024), (5000, 1024), (12345, 5678)])
def test_matrix(request):
shape = request.param
x = torch.randn(shape, device='cuda')
return x

# Test function for the GELU forward kernel
def test_gelu_kernel_fwd(test_matrix):
    # Apply the Triton-based GELU forward kernel
    triton_output = gelu_forward_kenel(test_matrix)

    # Apply PyTorch's GELU for comparison
    torch_gelu = torch.nn.GELU()
    torch_output = torch_gelu(test_matrix)

    # Check that the outputs are close enough using assert_all_close
    assert_all_close(triton_output, torch_output, rtol=1e-05, atol=1e-08)


# Test function for the GELU backward kernel
def test_gelu_backward_kernel(test_matrix):
    test_matrix.requires_grad_(True)

    # Upstream gradient flowing into the GELU (e.g. random gradients)
    grad_output = torch.randn_like(test_matrix)

    # Apply the Triton-based GELU backward kernel
    triton_output = gelu_backward_kernel(test_matrix, grad_output)

    # Compute PyTorch's GELU gradient for comparison
    torch_gelu = torch.nn.GELU()
    torch_output = torch_gelu(test_matrix)
    torch_grad = torch.autograd.grad(torch_output, test_matrix, grad_outputs=grad_output)[0]

    # Check that the gradients are close enough using assert_all_close
    assert_all_close(triton_output, torch_grad, rtol=1e-05, atol=1e-08)
54 changes: 54 additions & 0 deletions tests/kernels/test_layernorm.py
@@ -0,0 +1,54 @@
import pytest
import torch
import triton

from unsloth.kernels.layernorm import LayerNorm
from tests.conftest import set_seed, assert_all_close

# Fixture for test matrices and associated parameters
@set_seed()
@pytest.fixture(params=[(64, 64), (1024, 512), (2048, 1024)])
def test_data(request):
torch.manual_seed(0) # For reproducibility
batch_size, num_features = request.param
x = torch.randn(batch_size, num_features, device='cuda')
weight = torch.randn(num_features, device='cuda')
bias = torch.randn(num_features, device='cuda')
eps = 1e-5
return x, weight, bias, eps

# Test forward pass
def test_layer_norm_forward(test_data):
x, weight, bias, eps = test_data

# Triton layer norm forward
triton_output = LayerNorm.apply(x, x.size(), weight, bias, eps)

# PyTorch layer norm forward
pytorch_layer_norm = torch.nn.LayerNorm(x.size()[1:], eps=eps, elementwise_affine=True)
pytorch_layer_norm.weight = torch.nn.Parameter(weight)
pytorch_layer_norm.bias = torch.nn.Parameter(bias)
pytorch_output = pytorch_layer_norm(x)

# Check if outputs are close using assert_all_close
assert_all_close(triton_output, pytorch_output, rtol=1e-05, atol=1e-08)


def test_layer_norm_backward(test_data):
x, weight, bias, eps = test_data
x.requires_grad = True

# Triton layer norm backward
triton_output = LayerNorm.apply(x, x.size(), weight, bias, eps)
triton_grad = torch.autograd.grad(triton_output.sum(), x)[0]

# PyTorch layer norm backward
pytorch_layer_norm = torch.nn.LayerNorm(x.size()[1:], eps=eps, elementwise_affine=True)
pytorch_layer_norm.weight = torch.nn.Parameter(weight)
pytorch_layer_norm.bias = torch.nn.Parameter(bias)
pytorch_output = pytorch_layer_norm(x)
pytorch_output.sum().backward()
pytorch_grad = x.grad

# Check if gradients are close using assert_all_close
assert_all_close(triton_grad, pytorch_grad, rtol=1e-05, atol=1e-08)
25 changes: 25 additions & 0 deletions tests/kernels/test_relu.py
@@ -0,0 +1,25 @@
import pytest
import torch
import triton

from unsloth.kernels.relu import relu_kernel
from tests.conftest import set_seed, assert_all_close

@set_seed()
@pytest.fixture(params=[(100, 100), (1024, 1024), (5000, 1024), (12345, 5678)])
def test_matrix(request):
shape = request.param
x = torch.randn(shape, device='cuda')
return x

# Test function
def test_relu_kernel(test_matrix):
# Apply your Triton-based ReLU kernel
triton_output = relu_kernel(test_matrix)

# Apply PyTorch's ReLU for comparison
torch_relu = torch.nn.ReLU()
torch_output = torch_relu(test_matrix)

# Check if the outputs are close enough using assert_all_close
assert_all_close(triton_output, torch_output, rtol=1e-05, atol=1e-08)
27 changes: 27 additions & 0 deletions tests/profiles/profile_phi2.py
@@ -0,0 +1,27 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from unsloth.kernels.utils import profile_generate_method

torch.set_default_device("cuda")

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

inputs = tokenizer('''def print_prime(n):
"""
Print all primes between 1 and n
"""''', return_tensors="pt", return_attention_mask=False)


generate_args = {
    **inputs,                # tokenized prompt (input_ids) produced above
    "max_new_tokens": 100,
    "do_sample": True,
}

# Profile the model's generate() call with the arguments above
prof = profile_generate_method(model, generate_args)
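Note for reviewers: profile_generate_method is imported from unsloth.kernels.utils, which is not part of the hunks shown here. Purely as a point of reference, a minimal sketch of what such a helper could look like with torch.profiler; the signature and the printed summary table are assumptions, not the PR's implementation.

from torch.profiler import profile, ProfilerActivity

def profile_generate_method(model, generate_args):
    # Hypothetical helper: profile a single generate() call on CPU and CUDA.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with torch.no_grad():
            model.generate(**generate_args)
    # Print a per-op summary sorted by CUDA time.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    return prof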

81 changes: 81 additions & 0 deletions unsloth/kernels/dropout.py
@@ -0,0 +1,81 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import triton
import triton.language as tl


BLOCK_SIZE = 1024

@triton.jit
def _seeded_dropout(
    x_ptr,                     # Pointer to the input tensor
    output_ptr,                # Pointer to the output tensor
    n_elements,                # Number of elements in the input tensor
    p,                         # Dropout probability
    seed,                      # Seed for random number generation
    BLOCK_SIZE: tl.constexpr,  # Block size, a compile-time constant
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE * 4

off0 = block_start + BLOCK_SIZE * 0 + tl.arange(0, BLOCK_SIZE)
off1 = block_start + BLOCK_SIZE * 1 + tl.arange(0, BLOCK_SIZE)
off2 = block_start + BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE)
off3 = block_start + BLOCK_SIZE * 3 + tl.arange(0, BLOCK_SIZE)

mask0 = off0 < n_elements
mask1 = off1 < n_elements
mask2 = off2 < n_elements
mask3 = off3 < n_elements

x0 = tl.load(x_ptr + off0, mask = mask0)
x1 = tl.load(x_ptr + off1, mask = mask1)
x2 = tl.load(x_ptr + off2, mask = mask2)
x3 = tl.load(x_ptr + off3, mask = mask3)

r0, r1, r2, r3 = tl.random.rand4x(seed, off0)
keep0, keep1, keep2, keep3 = r0 > p, r1 > p, r2 > p, r3 > p

o0 = tl.where(keep0, x0 / (1 - p), 0.0)
o1 = tl.where(keep1, x1 / (1 - p), 0.0)
o2 = tl.where(keep2, x2 / (1 - p), 0.0)
o3 = tl.where(keep3, x3 / (1 - p), 0.0)

tl.store(output_ptr + off0, o0, mask = mask0)
tl.store(output_ptr + off1, o1, mask = mask1)
tl.store(output_ptr + off2, o2, mask = mask2)
tl.store(output_ptr + off3, o3, mask = mask3)

pass

def seeded_dropout(
x: torch.Tensor,
p: float,
seed: int,
BLOCK_SIZE: int = 1024
) -> torch.Tensor:
assert x.is_cuda and x.is_contiguous(), "Tensor must be on GPU and stored in contiguous block!"
output = torch.empty_like(x)
n_elements = x.numel()
# Define the grid size based on the number of elements and BLOCK_SIZE
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"] * 4),)
# Launch the kernel with the grid configuration
_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=BLOCK_SIZE) # Pass BLOCK_SIZE as a named argument
return output
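For reference (not part of the diff), a minimal usage sketch of seeded_dropout on a CUDA device. Because the mask is derived from the seed and element offsets, the same seed reproduces the same mask, and kept elements are rescaled by 1 / (1 - p):

x = torch.randn(4096, 4096, device="cuda")
out       = seeded_dropout(x, p=0.5, seed=123)
out_again = seeded_dropout(x, p=0.5, seed=123)   # same seed -> identical dropout mask
assert torch.equal(out, out_again)
print(f"fraction zeroed: {(out == 0).float().mean().item():.3f}")   # roughly p = 0.5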