Merge branch 'main' into mor--inference
morgandu committed Apr 26, 2024
2 parents f8c89e9 + 18ba1a7 commit 0a1a790
Showing 34 changed files with 151 additions and 90 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/UnitTests.yml
@@ -27,7 +27,7 @@ on:
- cron: '0 */2 * * *'

jobs:
- # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'gpu' job
+ # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'gpu' job
tpu:
strategy:
fail-fast: false
@@ -99,7 +99,7 @@ jobs:
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
'python3 pedagogical_examples/shmap_collective_matmul.py'
- # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job
+ # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'tpu' job
gpu:
strategy:
fail-fast: false
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,7 @@
repos:
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.4
    hooks:
      - id: codespell
        name: Running codespell for typos
        entry: codespell -w --skip="*.txt,pylintrc,.*" .
2 changes: 1 addition & 1 deletion MaxText/accelerator_to_spec_map.py
@@ -17,7 +17,7 @@
""" Static map of TPU names such as v4-8 to properties such as chip layout."""

""" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO
+ IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO
UserFacingNameToSystemCharacteristics in xpk/xpk.py !!!!! """

from dataclasses import dataclass
4 changes: 2 additions & 2 deletions MaxText/configs/base.yml
@@ -142,7 +142,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
- dcn_tensor_parallelism: 1 # never recommeneded
+ dcn_tensor_parallelism: 1 # never recommended
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
@@ -197,7 +197,7 @@ prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from d
autoregressive_decode_assert: ""

enable_profiler: False
- # If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane reuslt from the first host.
+ # If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane result from the first host.
upload_all_profiler_results: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
File renamed without changes.
2 changes: 1 addition & 1 deletion MaxText/configs/v5e/128b.sh
@@ -43,7 +43,7 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xl

python3 MaxText/$EXECUTABLE MaxText/configs/base.yml\
steps=30 per_device_batch_size=1 enable_checkpointing=false\
- enable_profiler=false remat_policy=minimal_offloaded global_parameter_scale=128\
+ enable_profiler=false remat_policy=qkv_proj_offloaded global_parameter_scale=128\
ici_fsdp_parallelism=16 ici_tensor_parallelism=16\
max_target_length=2048 base_output_directory=gs://runner-maxtext-logs\
use_iota_embed=true reuse_example_batch=1\
4 changes: 2 additions & 2 deletions MaxText/convert_gemma_chkpt.py
@@ -143,15 +143,15 @@ def main(raw_args=None) -> None:

layer_weight["self_attention"] = copy.deepcopy(self_attention)
jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
- jax_weights = jax.tree_map(jnp.array, jax_weights)
+ jax_weights = jax.tree_util.tree_map(jnp.array, jax_weights)

def astype_fn(x):
if isinstance(x, jnp.ndarray):
return x.astype(jnp.bfloat16)
else:
return x

- jax_weights = jax.tree_map(astype_fn, jax_weights)
+ jax_weights = jax.tree_util.tree_map(astype_fn, jax_weights)

enable_checkpointing = True
async_checkpointing = False
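Most of this merge swaps `jax.tree_map` for `jax.tree_util.tree_map`; the former has since been deprecated in JAX, and the latter behaves identically for these calls. A minimal sketch of the migrated pattern, using a toy pytree rather than Gemma's real weight structure:

```python
import jax
import jax.numpy as jnp

# Toy pytree standing in for converted checkpoint weights (illustrative only).
weights = {"decoder": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Same cast-to-bfloat16 pass as astype_fn above, spelled with the non-deprecated API.
bf16_weights = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.bfloat16) if isinstance(x, jnp.ndarray) else x,
    weights,
)

print(jax.tree_util.tree_map(lambda x: x.dtype, bf16_weights))
```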
12 changes: 6 additions & 6 deletions MaxText/generate_param_only_checkpoint.py
@@ -15,7 +15,7 @@
"""

# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
- """Trasforms a "full state" including optimzer state to a bfloat16 "parameter state" without optimizer state.
+ """Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state.
This typically used for turning a state output by training.py into a state than can be consumed by decode.py.
The input "fullstate" is passed in via:
@@ -54,13 +54,13 @@ def _possibly_unroll_params(config, training_state, training_state_annotations,
def new_pspec(x):
return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])

- new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers)
- new_per_layer_state_sharding = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
+ new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers)
+ new_per_layer_state_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)

for i in range(config.num_decoder_layers):

def slice_ith(input_layers):
- return jax.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)
+ return jax.tree_util.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)

new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)(training_state_layers)

@@ -70,7 +70,7 @@ def slice_ith(input_layers):
del training_state.params["params"]["decoder"]["layers"]
del training_state_annotations.params["params"]["decoder"]["layers"]

- jax.tree_map(lambda x: x.delete(), training_state_layers)
+ jax.tree_util.tree_map(lambda x: x.delete(), training_state_layers)


def _read_train_checkpoint(config, checkpoint_manager, mesh):
@@ -90,7 +90,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
def _save_decode_checkpoint(config, state, checkpoint_manager):
"""Generate checkpoint for decode from the training_state."""
with jax.spmd_mode("allow_all"):
- decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
+ decode_state = max_utils.init_decode_state(None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, 0, decode_state):
max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}")
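For context on the hunks above: with scanned layers, every decoder layer is stacked along `param_scan_axis`, and `_possibly_unroll_params` peels off one layer at a time with `jnp.take`. A rough sketch under simplified assumptions (toy shapes, scan axis 0, no sharding or deletion of the stacked copy):

```python
import jax
import jax.numpy as jnp

param_scan_axis = 0                               # assumption; MaxText reads this from config
stacked_layers = {"kernel": jnp.ones((4, 8, 8))}  # 4 scanned layers of a toy 8x8 kernel

def slice_ith(i, layers):
  # Mirrors slice_ith above: take layer i along the scan axis for every leaf.
  return jax.tree_util.tree_map(lambda x: jnp.take(x, i, axis=param_scan_axis), layers)

layer_2 = slice_ith(2, stacked_layers)
print(layer_2["kernel"].shape)  # (8, 8)
```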
6 changes: 3 additions & 3 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -91,7 +91,7 @@ def reduce_concat_tokens(
):
"""Token-preprocessor to concatenate multiple unrelated documents.
If we want to generate examples of exactly the right length,
- (to avoid wasting space on padding), then we use this function, folowed by
+ (to avoid wasting space on padding), then we use this function, followed by
split_tokens.
Args:
dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
@@ -219,7 +219,7 @@ def get_datasets(
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})

eval_ds = eval_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
- # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length
+ # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"})

@@ -243,7 +243,7 @@ def preprocess_dataset(
train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length)
train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)

- # note eval_ds is pre tokenized, reduce_concated and splitted to target_length
+ # note eval_ds is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
train_ds = sequence_packing.pack_dataset(train_ds, config.max_target_length)
eval_ds = sequence_packing.pack_dataset(eval_ds, config.max_target_length)
4 changes: 2 additions & 2 deletions MaxText/input_pipeline/input_pipeline_interface.py
@@ -98,7 +98,7 @@ def __init__(self, config, mesh):
self.mesh = mesh
self.config = config
data_pspec = P(*config.data_sharding)
- data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
self.data_generator = jax.jit(
SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
)
@@ -112,7 +112,7 @@ def __next__(self):

@staticmethod
def raw_generate_synthetic_data(config):
- """Generates a single batch of syntehtic data"""
+ """Generates a single batch of synthetic data"""
output = {}
output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
output["inputs_position"] = jax.numpy.zeros(
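The `SyntheticDataIterator` change above maps a pytree of `PartitionSpec`s to `NamedSharding`s so the jitted generator can emit data already placed on the mesh. A minimal sketch of that mapping over a single-axis mesh; the field names and specs are made up, not `config.data_sharding`:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over whatever local devices are available (CPU works too).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Hypothetical per-field partition specs.
data_pspecs = {"inputs": P("data"), "targets": P("data")}

data_shardings = jax.tree_util.tree_map(lambda p: NamedSharding(mesh, p), data_pspecs)
print(data_shardings["inputs"])
```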
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
@@ -853,7 +853,7 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
raise ValueError("num_kv_heads is not defined.")

if self.num_query_heads % self.num_kv_heads != 0:
- raise ValueError("Invaid num_kv_heads for GQA.")
+ raise ValueError("Invalid num_kv_heads for GQA.")

kv_proj = DenseGeneral(
features=(self.num_kv_heads, self.head_dim),
@@ -918,7 +918,7 @@ def __call__(
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
- There are three modes: training, prefill and autoregression. During training, the KV cahce
+ There are three modes: training, prefill and autoregression. During training, the KV cache
is ignored. During prefill, the cache is filled. During autoregression the cache is used.
In the cache initialization call, `inputs_q` has a shape [batch, length,
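Background for the `kv_projection` check above: grouped-query attention shares each KV head across a fixed-size group of query heads, so the query-head count must be an exact multiple of the KV-head count. A tiny illustration with made-up head counts (not a MaxText config):

```python
num_query_heads, num_kv_heads = 16, 4  # illustrative values only

if num_query_heads % num_kv_heads != 0:
  raise ValueError("Invalid num_kv_heads for GQA.")

group_size = num_query_heads // num_kv_heads
print(f"each KV head serves {group_size} query heads")  # 4
```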
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
@@ -120,7 +120,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind):
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
if quantizations.in_serve_mode(self.quant):
# During aqt convert state we delete kernel weight from params to save memory.
- # Instead they are retreived from the tensors stored in the 'aqt' collection.
+ # Instead they are retrieved from the tensors stored in the 'aqt' collection.
kernel = jnp.zeros(kernel_shape)
else:
kernel = self.param(
19 changes: 16 additions & 3 deletions MaxText/layers/quantizations.py
@@ -46,12 +46,25 @@ class AqtQuantization:

def dot_general_cls(self):
"""Returns dot_general configured with aqt params."""
- aqt_dg_cls = functools.partial(aqt_flax.AqtDotGeneral, self.quant_dg, rhs_quant_mode=self.quant_mode)
+ aqt_dg_cls = functools.partial(
+     aqt_flax.AqtDotGeneral,
+     self.quant_dg,
+     rhs_quant_mode=self.quant_mode,
+     lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
+     rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
+ )
return aqt_dg_cls

def einsum(self):
- """Returns einsum configured with aqt params"""
- aqt_einsum = functools.partial(aqt_flax.AqtEinsum(cfg=self.quant_dg, lhs_quant_mode=self.quant_mode))
+ """Returns einsum configured with aqt params."""
+ aqt_einsum = functools.partial(
+     aqt_flax.AqtEinsum(
+         cfg=self.quant_dg,
+         lhs_quant_mode=self.quant_mode,
+         lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
+         rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
+     )
+ )
return aqt_einsum


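The quantization hunks above pre-bind the AQT config and the new `FreezerMode` arguments with `functools.partial`, so call sites can construct the dot-general or einsum op without repeating them. A sketch of that binding pattern with a stand-in class; `ToyDotGeneral` and its argument values are hypothetical and only mimic the constructor shape, not AQT's behaviour:

```python
import functools

class ToyDotGeneral:
  """Stand-in for aqt_flax.AqtDotGeneral that just records its configuration."""

  def __init__(self, cfg, rhs_quant_mode=None, lhs_freeze_mode=None, rhs_freeze_mode=None):
    self.cfg = cfg
    self.rhs_quant_mode = rhs_quant_mode
    self.lhs_freeze_mode = lhs_freeze_mode
    self.rhs_freeze_mode = rhs_freeze_mode

# Pre-bind the configuration once, the way dot_general_cls does above.
dot_general_cls = functools.partial(
    ToyDotGeneral,
    {"bits": 8},                             # hypothetical quant config
    rhs_quant_mode="serve",                  # hypothetical mode value
    lhs_freeze_mode="none",
    rhs_freeze_mode="calibration_and_value",
)

dg = dot_general_cls()  # later call sites add nothing
print(dg.rhs_quant_mode, dg.rhs_freeze_mode)
```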
2 changes: 1 addition & 1 deletion MaxText/llama_or_mistral_ckpt.py
@@ -323,7 +323,7 @@ def checkpoint_device_put(arr):
return jax.device_put(arr, device=s3)

# convert all weights to jax.numpy with sharding if applicable
- jax_weights = jax.tree_map(checkpoint_device_put, jax_weights)
+ jax_weights = jax.tree_util.tree_map(checkpoint_device_put, jax_weights)

# dummy configs for the checkpoint_manager
step_number_to_save_new_ckpt = 0
10 changes: 5 additions & 5 deletions MaxText/max_utils.py
@@ -49,7 +49,7 @@ def find_nans_and_infs(pytree):
def finder(x):
return jnp.any(jnp.isinf(x) | jnp.isnan(x))

- bad_pytree = jax.tree_map(finder, pytree)
+ bad_pytree = jax.tree_util.tree_map(finder, pytree)
return jax.tree_util.tree_flatten(bad_pytree)


@@ -546,15 +546,15 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss:
logits: [batch, length, num_classes] float array.
targets: categorical one-hot targets [batch, length, num_classes] float
array.
- z_loss: coefficient for auxilliary z-loss loss term.
+ z_loss: coefficient for auxiliary z-loss loss term.
Returns:
tuple with the total loss and the z_loss, both
float arrays with shape [batch, length].
"""
logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
log_softmax = logits - logits_sum
loss = -jnp.sum(targets * log_softmax, axis=-1)
- # Add auxilliary z-loss term.
+ # Add auxiliary z-loss term.
log_z = jnp.squeeze(logits_sum, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
@@ -574,7 +574,7 @@ def _cross_entropy_with_logits_fwd(
sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
log_softmax = shifted - jnp.log(sum_exp)
loss = -jnp.sum(targets * log_softmax, axis=-1)
- # Add auxilliary z-loss term.
+ # Add auxiliary z-loss term.
log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
@@ -680,7 +680,7 @@ def delete_leaf(leaf):
leaf.delete()
del leaf

- jax.tree_map(delete_leaf, p)
+ jax.tree_util.tree_map(delete_leaf, p)


def summarize_pytree_data(params, name="Params", raw=False):
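The cross-entropy hunks above only touch comments, but for reference the auxiliary z-loss they describe condenses to a short self-contained sketch (toy logits and targets; the `z_loss` weight is illustrative, not a MaxText default):

```python
import jax
import jax.numpy as jnp

def cross_entropy_with_z_loss(logits, targets_onehot, z_loss=1e-4):
  # Softmax cross entropy plus z_loss * log(Z)^2, following max_utils.cross_entropy_with_logits.
  logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
  log_softmax = logits - logits_sum
  loss = -jnp.sum(targets_onehot * log_softmax, axis=-1)
  log_z = jnp.squeeze(logits_sum, axis=-1)       # log of the partition function
  total_z_loss = z_loss * jax.lax.square(log_z)  # auxiliary term that keeps logits from drifting
  return loss + total_z_loss, total_z_loss

logits = jnp.array([[[2.0, 0.5, -1.0]]])         # [batch=1, length=1, classes=3]
targets = jax.nn.one_hot(jnp.array([[0]]), 3)
loss, z_term = cross_entropy_with_z_loss(logits, targets)
print(loss.shape, z_term.shape)                  # (1, 1) (1, 1)
```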
12 changes: 6 additions & 6 deletions MaxText/maxengine.py
@@ -79,14 +79,14 @@ def load_params(self, *args, **kwargs) -> Params:
"""Load Parameters, typically from GCS"""
# pylint: disable=unused-argument
state, self.state_mesh_annotations = max_utils.setup_decode_state(self.model, self.config, self.rng, self._mesh, None)
- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
)
self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh)
- self.kv_cache_shardings = jax.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
+ self.kv_cache_shardings = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)

if not self.model.quant:
- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
)
return state.params
@@ -113,7 +113,7 @@ def model_apply(_p, _rng):
# Remove param values which have corresponding qtensors in aqt to save memory.
params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])

- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
)

@@ -342,13 +342,13 @@ def init(abstract_params):
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
mesh_annotations = nn.logical_to_mesh(logical_annotations)

- shardings = jax.tree_map(
+ shardings = jax.tree_util.tree_map(
lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
)

@functools.partial(jax.jit, out_shardings=shardings)
def initialize():
- return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)
+ return jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)

cache = initialize()["cache"]

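The `abstract_params` lines above replace concrete weights with `jax.ShapeDtypeStruct` placeholders that keep shape and dtype (and, in MaxText, sharding) metadata without holding device memory. A stripped-down sketch of the idea, with sharding omitted and a toy parameter tree:

```python
import jax
import jax.numpy as jnp

params = {"decoder": {"kernel": jnp.ones((8, 8), dtype=jnp.bfloat16)}}

abstract_params = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), params
)

print(abstract_params["decoder"]["kernel"])  # shape (8, 8), dtype bfloat16, no data
```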
10 changes: 5 additions & 5 deletions MaxText/maxtext_utils.py
@@ -33,8 +33,8 @@ def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations
functional_train = get_functional_train_step(train_step, model, config)
functional_train.__name__ = "train_step"
data_pspec = P(*config.data_sharding)
- state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
- data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+ data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shardings = (state_mesh_shardings, None) # State, metrics
static_argnums = () # We partial out the static argnums of model and config
@@ -51,8 +51,8 @@ def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations,
functional_eval = get_functional_eval_step(eval_step, model, config)
functional_eval.__name__ = "eval_step"
data_pspec = P(*config.data_sharding)
- state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
- data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+ data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shardings = None # metrics
static_argnums = () # We partial out the static argnums of model, config
@@ -169,7 +169,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01):
perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
"Number of parameters per chip must not be less than in the ideal sharded "
- "scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
+ "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
)
assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, (
f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters."
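The wording fix above sits in `assert_params_sufficiently_sharded`, whose check is plain arithmetic: the parameter count actually resident on one chip must stay within a tolerance of the perfectly sharded count. A numeric illustration with made-up numbers, not a real MaxText run:

```python
total_num_params = 128e9                 # hypothetical model size
devices_for_weight_sharding = 64         # product over fsdp/fsdp_transpose/sequence/tensor axes
total_num_params_per_chip = 2.01e9       # hypothetical per-chip measurement
tolerance = 0.01

perfectly_sharded_params_per_chip = total_num_params / devices_for_weight_sharding
assert total_num_params_per_chip >= perfectly_sharded_params_per_chip
assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance
```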
14 changes: 7 additions & 7 deletions MaxText/optimizers.py
@@ -59,7 +59,7 @@ def adam_pax(
) -> optax.GradientTransformation:
"""Standard Adam optimizer that supports weight decay.
- Follows the implemenation in pax/praxis sharded_adam
+ Follows the implementation in pax/praxis sharded_adam
https://github.com/google/praxis/blob/545e00ab126b823265d70c715950d39333484f38/praxis/optimizers.py#L621
Args:
Expand Down Expand Up @@ -129,19 +129,19 @@ def _update_momentum(update, mu, nu):
nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
return _slot_opt_state(mu=mu, nu=nu)

- updated_moments = jax.tree_map(_update_momentum, updates, state.mu, state.nu)
+ updated_moments = jax.tree_util.tree_map(_update_momentum, updates, state.mu, state.nu)

- mu = jax.tree_map(lambda x: x.mu, updated_moments)
- nu = jax.tree_map(lambda x: x.nu, updated_moments)
+ mu = jax.tree_util.tree_map(lambda x: x.mu, updated_moments)
+ nu = jax.tree_util.tree_map(lambda x: x.nu, updated_moments)

- updates = jax.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)
+ updates = jax.tree_util.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)

if weight_decay > 0:
- updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params)
+ updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)

step_size = -1.0 * learning_rate_fn(count)
# Finally, fold in step size.
- updates = jax.tree_map(lambda x: step_size * x, updates)
+ updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)

updated_states = optax.ScaleByAdamState(count=count + 1, mu=mu, nu=nu)
return updates, updated_states
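The adam_pax hunk above relies on `jax.tree_util.tree_map` accepting several pytrees of matching structure, which lets the moment update combine `updates`, `mu`, and `nu` leaf by leaf. A compact sketch of that multi-tree pattern (constants are illustrative, not MaxText defaults):

```python
import jax
import jax.numpy as jnp

updates = {"w": jnp.array([0.1, -0.2])}  # toy gradients
mu = {"w": jnp.zeros(2)}                 # first moment
nu = {"w": jnp.zeros(2)}                 # second moment
beta1, beta2, eps = 0.9, 0.999, 1e-8     # illustrative hyperparameters

mu = jax.tree_util.tree_map(lambda g, m: (1.0 - beta1) * g + beta1 * m, updates, mu)
nu = jax.tree_util.tree_map(lambda g, v: (1.0 - beta2) * g**2 + beta2 * v, updates, nu)
step = jax.tree_util.tree_map(lambda m, v: m / (jnp.sqrt(v) + eps), mu, nu)
print(step["w"])
```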
