Skip to content

Commit

Permalink
fix: model filter issue
Browse files Browse the repository at this point in the history
  • Loading branch information
tjbck committed May 18, 2024
1 parent 3890ea1 commit 3aa6b0f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
8 changes: 4 additions & 4 deletions backend/apps/litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ async def lifespan(app: FastAPI):
litellm_config = yaml.safe_load(file)


app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config

Expand Down Expand Up @@ -151,10 +155,6 @@ async def shutdown_litellm_background():
background_process = None


app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


@app.get("/")
async def get_status():
return {"status": True}
Expand Down
14 changes: 8 additions & 6 deletions backend/apps/ollama/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@

app.state.config = AppConfig()

app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
Expand Down Expand Up @@ -178,11 +178,12 @@ async def get_ollama_tags(
if url_idx == None:
models = await get_all_models()

if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
Expand Down Expand Up @@ -1046,11 +1047,12 @@ async def get_openai_models(
if url_idx == None:
models = await get_all_models()

if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
Expand Down
9 changes: 5 additions & 4 deletions backend/apps/openai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@
allow_headers=["*"],
)


app.state.config = AppConfig()

app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST


app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
Expand Down Expand Up @@ -259,11 +260,11 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None:
models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"],
)
)
Expand Down
8 changes: 4 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,11 @@ async def update_model_filter_config(
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models

ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST

openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST

litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
Expand Down

0 comments on commit 3aa6b0f

Please sign in to comment.