Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
tianjing-li committed May 7, 2024
2 parents 8d60343 + 57691ad commit 53003e9
Show file tree
Hide file tree
Showing 81 changed files with 878 additions and 157 deletions.
5 changes: 4 additions & 1 deletion .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ AZURE_CHAT_ENDPOINT_URL=<ENDPOINT URL>
USE_EXPERIMENTAL_LANGCHAIN=False

# Community features
USE_COMMUNITY_FEATURES='True'
USE_COMMUNITY_FEATURES='True'

# Auth session
SESSION_SECRET_KEY=<GENERATE_A_SECRET_KEY>
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ reset-db:
docker volume rm cohere_toolkit_db
setup:
poetry install --only setup --verbose
poetry run python3 src/backend/cli/main.py
poetry run python3 cli/main.py
lint:
poetry run black .
poetry run isort .
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Toolkit is a collection of prebuilt components enabling users to quickly build a
- [How to guides](/docs/how_to_guides.md)
- [How to set up command model providers](/docs/command_model_providers.md)
- [How to add tools](/docs/custom_tool_guides/tool_guide.md)
- [How to deploy toolkit services](/docs/deployments.md)
- [How to deploy toolkit services](/docs/service_deployments.md)
- [How to contribute](#contributing)
- [Try Cohere's Command Showcase](https://coral.cohere.com/)

Expand Down
170 changes: 157 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ py-expression-eval = "^0.3.14"
tavily-python = "^0.3.3"
arxiv = "^2.1.0"
xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
Expand Down
4 changes: 2 additions & 2 deletions src/backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from sqlalchemy import engine_from_config, pool

# Need to import Models - note they will be unused but are required for Alembic to detect
from backend.models import *
from backend.models.base import Base
from backend.database_models import *
from backend.database_models.base import Base

load_dotenv()

Expand Down
34 changes: 34 additions & 0 deletions src/backend/alembic/versions/b88f00283a27_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message
Revision ID: b88f00283a27
Revises: 2853273872ca
Create Date: 2024-05-02 19:19:52.608062
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b88f00283a27"
down_revision: Union[str, None] = "2853273872ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("hashed_password", sa.LargeBinary(), nullable=True)
)
op.create_unique_constraint("unique_user_email", "users", ["email"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_user_email", "users", type_="unique")
op.drop_column("users", "hashed_password")
# ### end Alembic commands ###
26 changes: 26 additions & 0 deletions src/backend/alembic/versions/c15b848babe3_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""empty message
Revision ID: c15b848babe3
Revises: 6553b76de6ca, b88f00283a27
Create Date: 2024-05-07 15:59:05.436751
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c15b848babe3"
down_revision: Union[str, None] = ("6553b76de6ca", "b88f00283a27")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
4 changes: 2 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from fastapi import HTTPException

from backend.chat.base import BaseChat
from backend.chat.custom.model_deployments.base import BaseDeployment
from backend.chat.custom.model_deployments.deployment import get_deployment
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_deployment
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger
Expand Down
9 changes: 0 additions & 9 deletions src/backend/chat/custom/model_deployments/__init__.py

This file was deleted.

9 changes: 9 additions & 0 deletions src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from backend.services.auth import BasicAuthentication

# Modify this to enable auth strategies.
ENABLED_AUTH_STRATEGIES = []

# Define the mapping from Auth strategy name to class obj.
# Does not need to be manually modified.
# Ex: {"Basic": BasicAuthentication}
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls for cls in ENABLED_AUTH_STRATEGIES}
8 changes: 5 additions & 3 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from distutils.util import strtobool
from enum import StrEnum

from backend.chat.custom.model_deployments.azure import AzureDeployment
from backend.chat.custom.model_deployments.cohere_platform import CohereDeployment
from backend.chat.custom.model_deployments.sagemaker import SageMakerDeployment
from backend.model_deployments import (
AzureDeployment,
CohereDeployment,
SageMakerDeployment,
)
from backend.schemas.deployment import Deployment


Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/citation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.citation import Citation
from backend.database_models.citation import Citation


def create_citation(db: Session, citation: Citation) -> Citation:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/conversation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.conversation import Conversation
from backend.database_models.conversation import Conversation
from backend.schemas.conversation import UpdateConversation


Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.document import Document
from backend.database_models.document import Document


def create_document(db: Session, document: Document) -> Document:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.file import File
from backend.database_models.file import File
from backend.schemas.file import UpdateFile


Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.message import Message
from backend.database_models.message import Message
from backend.schemas.message import UpdateMessage


Expand Down
4 changes: 2 additions & 2 deletions src/backend/crud/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.models.user import User
from backend.database_models.user import User
from backend.schemas.user import UpdateUser


Expand Down Expand Up @@ -62,7 +62,7 @@ def update_user(db: Session, user: User, new_user: UpdateUser) -> User:
Returns:
User: Updated user.
"""
for attr, value in new_user.model_dump().items():
for attr, value in new_user.model_dump(exclude_none=True).items():
setattr(user, attr, value)
db.commit()
db.refresh(user)
Expand Down
8 changes: 8 additions & 0 deletions src/backend/database_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from backend.database_models.base import *
from backend.database_models.citation import *
from backend.database_models.conversation import *
from backend.database_models.database import *
from backend.database_models.document import *
from backend.database_models.file import *
from backend.database_models.message import *
from backend.database_models.user import *
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import Mapped, mapped_column, relationship

from backend.models.base import Base
from backend.models.document import Document
from backend.database_models.base import Base
from backend.database_models.document import Document

citation_documents = Table(
"citation_documents",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sqlalchemy import Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from backend.models.base import Base
from backend.models.file import File
from backend.models.message import Message
from backend.database_models.base import Base
from backend.database_models.file import File
from backend.database_models.message import Message


class Conversation(Base):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy import JSON, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column

from backend.models.base import Base
from backend.database_models.base import Base


class Document(Base):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy import ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column

from backend.models.base import Base
from backend.database_models.base import Base


class File(Base):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from sqlalchemy import Boolean, Enum, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from backend.models.base import Base
from backend.models.citation import Citation
from backend.models.document import Document
from backend.models.file import File
from backend.database_models.base import Base
from backend.database_models.citation import Citation
from backend.database_models.document import Document
from backend.database_models.file import File


class MessageAgent(StrEnum):
Expand Down
16 changes: 16 additions & 0 deletions src/backend/database_models/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional

from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base


class User(Base):
__tablename__ = "users"

fullname: Mapped[str] = mapped_column()
email: Mapped[Optional[str]] = mapped_column()
hashed_password: Mapped[Optional[bytes]] = mapped_column()

__table_args__ = (UniqueConstraint("email", name="unique_user_email"),)
36 changes: 32 additions & 4 deletions src/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
from contextlib import asynccontextmanager

from alembic.command import upgrade
from alembic.config import Config
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
from backend.routers.conversation import router as conversation_router
from backend.routers.deployment import router as deployment_router
Expand All @@ -15,32 +19,56 @@

load_dotenv()

ORIGINS = ["*"]


@asynccontextmanager
async def lifespan(app: FastAPI):
yield


origins = ["*"]


def create_app():
app = FastAPI(lifespan=lifespan)

# Add routers
app.include_router(auth_router)
app.include_router(chat_router)
app.include_router(user_router)
app.include_router(conversation_router)
app.include_router(tool_router)
app.include_router(deployment_router)
app.include_router(experimental_feature_router)

# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

if ENABLED_AUTH_STRATEGY_MAPPING:
secret_key = os.environ.get("SESSION_SECRET_KEY", None)

if not secret_key:
raise ValueError(
"Missing SESSION_SECRET_KEY environment variable to enable Authentication."
)

# Handle User sessions and Auth
app.add_middleware(
SessionMiddleware,
secret_key=secret_key,
)

# Add auth
for auth in ENABLED_AUTH_STRATEGY_MAPPING.values():
if auth.SHOULD_ATTACH_TO_APP:
# TODO: Add app attachment logic for eg OAuth:
# https://docs.authlib.org/en/latest/client/fastapi.html
pass

return app


Expand Down
9 changes: 9 additions & 0 deletions src/backend/model_deployments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from backend.model_deployments.azure import AzureDeployment
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.model_deployments.sagemaker import SageMakerDeployment

__all__ = [
"AzureDeployment",
"CohereDeployment",
"SageMakerDeployment",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cohere
from cohere.types import StreamedChatResponse

from backend.chat.custom.model_deployments.base import BaseDeployment
from backend.model_deployments.base import BaseDeployment
from backend.schemas.cohere_chat import CohereChatRequest


Expand Down

0 comments on commit 53003e9

Please sign in to comment.