Merge pull request #262 from EricLBuehler/lora_swapping
Implement dynamic LoRA swapping
EricLBuehler committed May 12, 2024
2 parents 62e4402 + f5400d3 commit 96f25d5
Showing 51 changed files with 1,820 additions and 353 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ Mistral.rs is a fast LLM inference platform supporting inference on a variety of
- Fast LoRA support with weight merging.
- First X-LoRA inference platform with first class support.
- Speculative Decoding: Mix supported models as the draft model or the target model
- Dynamic LoRA adapter swapping at runtime with adapter preloading: [examples and docs](docs/ADAPTER_MODELS.md#adapter-model-dynamic-adapter-activation)


This is a demo of interactive mode with streaming running Mistral GGUF:
32 changes: 27 additions & 5 deletions docs/ADAPTER_MODELS.md
@@ -1,5 +1,5 @@
# Adapter model support
- An adapter model is a model with X-LoRA or LoRA. X-LoRA support is provided by selecting the `x-lora-*` architecture, and LoRA support by selecting the `lora-*` architecture. For both X-LoRA and LoRA, an ordering file (see [this section](#adapter-ordering-file) for preparing the ordering file) must be provided. The ordering file describes the ordering of layers and which adapters to use (and what order to use them in for X-LoRA).
+ An adapter model is a model with X-LoRA or LoRA. X-LoRA support is provided by selecting an `XLora*` architecture, and LoRA support by selecting the `Lora*` architecture. For both X-LoRA and LoRA, an ordering file (see [this section](#adapter-ordering-file) for preparing the ordering file) must be provided. The ordering file describes the ordering of layers and which adapters to use (and what order to use them in for X-LoRA).

When using an adapter model with a quantized base model, if the ordering file specifies unsupported layers you will receive an error.

@@ -23,7 +23,7 @@ When using an adapter model with a quantized base model, if the ordering file sp

## Adapter ordering file
**Preparing the X-LoRA/LoRA Ordering File**
- The X-LoRA/LoRA ordering file is necessary to prepare before inference with an X-LoRA model. However, it is easy with a provided [`script`](scripts/create_ordering.py)!
+ The X-LoRA/LoRA ordering file is necessary to prepare before inference with an X-LoRA model. However, it is easy with a provided [`script`](../scripts/create_ordering.py)!

The X-LoRA/LoRA ordering JSON file contains 2 parts. The first is the order of the adapters and the second is the layer ordering. The layer ordering has been automatically generated and should not be manipulated, as it controls the application of scalings. However, the order of adapters should be an array of strings: the adapter names, in the order the adapters were specified during training.
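
For example, if the adapters were specified as a dictionary during training, the pattern might look like this (a sketch; the adapter names and paths are hypothetical):

```python
# Hypothetical adapter specification used at training time.
adapters = {
    "adapter_1": "./adapters/adapter_1/",
    "adapter_2": "./adapters/adapter_2/",
}
```

The `order` field of the ordering file should then be `["adapter_1", "adapter_2"]`, matching the order in which the adapters were specified.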

@@ -43,13 +43,13 @@ There are 2 scripts to prepare the ordering file. The ordering file is specific

1) From scratch: No ordering file for the architecture and target modules

- A script [`create_ordering.py`](scripts/create_ordering.py) is provided which prompts the user for the model ID, target modules, and adapter names. The user is prompted for an output file location, relative to the working directory.
+ A script [`create_ordering.py`](../scripts/create_ordering.py) is provided which prompts the user for the model ID, target modules, and adapter names. The user is prompted for an output file location, relative to the working directory.

2) Create a new ordering file from an existing ordering file for an architecture and target modules

- A script [`modify_names.py`](scripts/modify_names.py) is provided which prompts the user for the adapter names and the old ordering file. The user is prompted for an output file location, relative to the working directory.
+ A script [`set_names.py`](../scripts/set_names.py) is provided which prompts the user for the adapter names and the old ordering file. The user is prompted for an output file location, relative to the working directory.

- We provide an [ordering file](scripts/xlora-paper-ordering.json) which contains the ordering for the X-LoRA model associated with [the paper](https://arxiv.org/abs/2402.07148) and the Huggingface repository: https://huggingface.co/lamm-mit/x-lora.
+ We provide an [ordering file](../orderings/xlora-paper-ordering.json) which contains the ordering for the X-LoRA model associated with [the paper](https://arxiv.org/abs/2402.07148) and the Huggingface repository: https://huggingface.co/lamm-mit/x-lora.

**Quantized X-LoRA or LoRA models**

@@ -61,3 +61,25 @@ Mistral.rs supports running quantized models with X-LoRA or LoRA. The X-LoRA or
The X-LoRA implementation supports non-granular scalings. This caches the scalings after `k` completion tokens are generated and they will be used for the remaining passes avoiding the scaling pass. The number of tokens to generate before caching is defined by setting `tgt_non_granular_index`. Setting `tgt_non_granular_index` will restrict the maximum running sequences to 1.

Please see [this page](NON_GRANULAR.md) for more details and examples.

## Adapter model dynamic adapter activation

We support dynamic adapter activation for LoRA models, allowing you to activate a set of adapters at runtime. There are Python, Rust, and HTTP APIs:

- Rust: [example](../mistralrs/examples/lora/main.rs)
- Python: [example](../examples/python/lora_zephyr.py)
- HTTP: [example](../examples/http.md)
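
As a minimal Python sketch combining the two patterns from the bundled examples, runtime activation and per-request selection (the model IDs mirror those examples; the adapter names are illustrative):

```python
from mistralrs import Runner, Which, ChatCompletionRequest

# A sketch based on the bundled Python examples; adapter names are illustrative.
runner = Runner(
    which=Which.LoraGGUF(
        tok_model_id=None,  # determined from the ordering file
        quantized_model_id="TheBloke/zephyr-7B-beta-GGUF",
        quantized_filename="zephyr-7b-beta.Q4_0.gguf",
        tokenizer_json=None,
        repeat_last_n=64,
        xlora_model_id="lamm-mit/x-lora",
        order="orderings/xlora-paper-ordering.json",
    )
)

# Make a preloaded adapter the active adapter at runtime...
runner.activate_adapters(["adapter_3"])

# ...or select adapters for a single request via the `adapters` field.
res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="mistral",
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=64,
        adapters=["adapter_4"],
    )
)
print(res.choices[0].message.content)
```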

To use this feature, you should add a `preload_adapters` key to your ordering file:
```diff
{
"order": ["..."],
"layers": {"...": "123"},
"base_model_id": "...",
+ "preload_adapters": [{"name": "...", "adapter_model_id": "..."}] # New field here
}
```

This allows mistral.rs to preload the adapter and enable runtime activation.
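
For reference, a complete ordering file with a preloaded adapter might look roughly like this (a sketch; the adapter names, `base_model_id`, and `adapter_model_id` are illustrative, and `layers` stands in for the auto-generated layer mapping):

```json
{
  "order": ["adapter_1"],
  "layers": {"...": "123"},
  "base_model_id": "HuggingFaceH4/zephyr-7b-beta",
  "preload_adapters": [
    {"name": "adapter_2", "adapter_model_id": "myuser/zephyr-lora-adapter"}
  ]
}
```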

We also provide a script to add this key to an existing ordering file: [`lora_add_preload_adapters.py`](../scripts/lora_add_preload_adapters.py).
2 changes: 1 addition & 1 deletion docs/ADDING_MODELS.md
@@ -153,7 +153,7 @@ fn new(
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
+ lora_config: &[(String, LoraConfig)],
+ lora_config: &[((String, String), LoraConfig)],
+ count: &mut usize,
+ ord: &Ordering,
) -> Result<Self> {
45 changes: 45 additions & 0 deletions examples/http.md
@@ -110,6 +110,23 @@ curl http://localhost:8080/v1/completions \
}'
```

## `POST`: `/activate_adapters`
Make the specified adapters the active adapters. Pass the names as a JSON object with the key `adapter_names` mapping to an array of strings (the adapter names).

Example with `curl`:
```bash
curl http://localhost:<port>/activate_adapters -H "Content-Type: application/json" -H "Authorization: Bearer EMPTY" -d '{"adapter_names":["adapter_2"]}'
```
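
The same call from Python using the `requests` library (a sketch; the port and adapter name are illustrative):

```python
import requests

# Activate "adapter_2"; mirrors the curl example above.
response = requests.post(
    "http://localhost:1234/activate_adapters",
    headers={"Authorization": "Bearer EMPTY"},
    json={"adapter_names": ["adapter_2"]},
)
print(response.status_code, response.text)
```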

## `POST`: `/re_isq`
Reapply ISQ to the model if possible. Pass the quantization level as a JSON object with the key `ggml_type` mapping to a string.

Example with `curl`:
```bash
curl http://localhost:<port>/re_isq -H "Content-Type: application/json" -H "Authorization: Bearer EMPTY" -d '{"ggml_type":"Q4K"}'
```
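
And an equivalent `requests` sketch for this endpoint (port illustrative):

```python
import requests

# Reapply ISQ at the Q4K level; mirrors the curl example above.
response = requests.post(
    "http://localhost:1234/re_isq",
    headers={"Authorization": "Bearer EMPTY"},
    json={"ggml_type": "Q4K"},
)
print(response.status_code, response.text)
```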


Streaming requests are not supported.

## Request
@@ -136,6 +153,34 @@ pub struct ChatCompletionRequest {
// Default -1 to consider all
pub top_k: Option<i64>,
pub stream: bool,
pub adapters: Option<Vec<String>>,
}
```
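
For example, a chat completion request body that selects adapters for a single request might look like this (a sketch; the adapter name is illustrative):

```json
{
  "model": "mistral",
  "messages": [{"role": "user", "content": "Hello!"}],
  "max_tokens": 64,
  "adapters": ["adapter_2"]
}
```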

### `CompletionRequest`
```rust
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
pub best_of: usize,
pub echo_prompt: bool,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub logit_bias: Option<HashMap<u32, f32>>,
// Default false
pub logprobs: Option<usize>,
pub max_tokens: Option<usize>,
// Default 1
pub n: usize,
pub stop_seqs: Option<StopTokens>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub suffix: Option<String>,

// mistral.rs additional
pub top_k: Option<usize>,
pub grammar: Option<Grammar>,
pub adapters: Option<Vec<String>>,
}
```
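
Likewise, a completion request body can pass `adapters` (a sketch; values are illustrative):

```json
{
  "model": "mistral",
  "prompt": "Say this is a test.",
  "max_tokens": 32,
  "adapters": ["adapter_2"]
}
```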

29 changes: 29 additions & 0 deletions examples/python/lora_activation.py
@@ -0,0 +1,29 @@
from mistralrs import Runner, Which, ChatCompletionRequest

runner = Runner(
which=Which.LoraGGUF(
tok_model_id=None, # Automatically determine from ordering file
quantized_model_id="TheBloke/zephyr-7B-beta-GGUF",
quantized_filename="zephyr-7b-beta.Q4_0.gguf",
tokenizer_json=None,
repeat_last_n=64,
xlora_model_id="lamm-mit/x-lora",
order="orderings/xlora-paper-ordering.json",
)
)

res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[
{"role": "user", "content": "Tell me a story about the Rust type system."}
],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
temperature=0.5,
adapters=["adapter_4"],
)
)
print(res.choices[0].message.content)
print(res.usage)
31 changes: 31 additions & 0 deletions examples/python/lora_zephyr.py
@@ -0,0 +1,31 @@
from mistralrs import Runner, Which, ChatCompletionRequest

runner = Runner(
which=Which.LoraGGUF(
tok_model_id=None, # Automatically determine from ordering file
quantized_model_id="TheBloke/zephyr-7B-beta-GGUF",
quantized_filename="zephyr-7b-beta.Q4_0.gguf",
tokenizer_json=None,
repeat_last_n=64,
xlora_model_id="lamm-mit/x-lora",
order="orderings/xlora-paper-ordering.json",
)
)

# Example: Make adapter_3 the active adapter
runner.activate_adapters(["adapter_3"])

res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[
{"role": "user", "content": "Tell me a story about the Rust type system."}
],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
temperature=0.5,
)
)
print(res.choices[0].message.content)
print(res.usage)
59 changes: 59 additions & 0 deletions examples/server/adapter_chat.py
@@ -0,0 +1,59 @@
import openai
import httpx
import textwrap, json


def log_response(response: httpx.Response):
request = response.request
print(f"Request: {request.method} {request.url}")
print(" Headers:")
for key, value in request.headers.items():
if key.lower() == "authorization":
value = "[...]"
if key.lower() == "cookie":
value = value.split("=")[0] + "=..."
print(f" {key}: {value}")
print(" Body:")
try:
request_body = json.loads(request.content)
print(textwrap.indent(json.dumps(request_body, indent=2), " "))
except json.JSONDecodeError:
print(textwrap.indent(request.content.decode(), " "))
print(f"Response: status_code={response.status_code}")
print(" Headers:")
for key, value in response.headers.items():
if key.lower() == "set-cookie":
value = value.split("=")[0] + "=..."
print(f" {key}: {value}")


openai.api_key = "EMPTY"
openai.base_url = "http://localhost:1234/v1/"

# Enable this to log requests and responses
# openai.http_client = httpx.Client(
# event_hooks={"request": [print], "response": [log_response]}
# )

messages = []
prompt = input("Enter system prompt >>> ")
if len(prompt) > 0:
messages.append({"role": "system", "content": prompt})


while True:
prompt = input(">>> ")
adapter = input("Active adapter >>> ")
messages.append({"role": "user", "content": prompt})
completion = openai.chat.completions.create(
model="mistral",
messages=messages,
max_tokens=256,
frequency_penalty=1.0,
top_p=0.1,
temperature=0,
extra_body={"adapters": [adapter]},
)
resp = completion.choices[0].message.content
print(resp)
messages.append({"role": "assistant", "content": resp})
14 changes: 8 additions & 6 deletions mistralrs-bench/src/main.rs
@@ -3,8 +3,8 @@ use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelKind,
- ModelSelected, Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource,
- Usage,
+ ModelSelected, NormalRequest, Request, RequestMessage, Response, SamplingParams,
+ SchedulerMethod, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
@@ -68,7 +68,7 @@ fn run_bench(
let sender = mistralrs.get_sender();
let (tx, mut rx) = channel(10_000);

- let req = Request {
+ let req = Request::Normal(NormalRequest {
id: mistralrs.next_request_id(),
messages: prompt,
sampling_params: sampling_params.clone(),
@@ -77,7 +77,8 @@ fn run_bench(
is_streaming: false,
constraint: Constraint::None,
suffix: None,
- };
+ adapters: None,
+ });

let mut usages = Vec::new();

@@ -225,7 +226,7 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
let sender = mistralrs.get_sender();
let (tx, mut rx) = channel(10_000);

- let req = Request {
+ let req = Request::Normal(NormalRequest {
id: mistralrs.next_request_id(),
messages: RequestMessage::Completion {
text: "Hello!".to_string(),
@@ -238,7 +239,8 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
is_streaming: false,
constraint: Constraint::None,
suffix: None,
- };
+ adapters: None,
+ });

sender
.blocking_send(req.clone())