Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(models): add sorting and filtering #1451

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 33 additions & 3 deletions web/src/components/table/use-cases/models.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { DataTable } from "@/src/components/table/data-table";
import { DataTableToolbar } from "@/src/components/table/data-table-toolbar";
import { type LangfuseColumnDef } from "@/src/components/table/types";
import { Button } from "@/src/components/ui/button";
import useColumnVisibility from "@/src/features/column-visibility/hooks/useColumnVisibility";
import { useQueryFilterState } from "@/src/features/filters/hooks/useFilterState";
import { useOrderByState } from "@/src/features/orderBy/hooks/useOrderByState";
import { useHasAccess } from "@/src/features/rbac/utils/checkAccess";
import { modelsTableCols } from "@/src/server/api/definitions/modelsTable";
import { api } from "@/src/utils/api";
import { usdFormatter } from "@/src/utils/numbers";
import { type Prisma, type Model } from "@langfuse/shared/src/db";
Expand Down Expand Up @@ -47,20 +51,26 @@ export default function ModelTable({ projectId }: { projectId: string }) {
pageIndex: withDefault(NumberParam, 0),
pageSize: withDefault(NumberParam, 50),
});

const [filterState, setFilterState] = useQueryFilterState([], "models");
const [orderByState, setOrderByState] = useOrderByState({
column: "modelName",
order: "DESC",
});
const models = api.models.all.useQuery({
page: paginationState.pageIndex,
limit: paginationState.pageSize,
projectId,
filter: filterState,
orderBy: orderByState,
});
const totalCount = models.data?.totalCount ?? 0;

const columns: LangfuseColumnDef<ModelTableRow>[] = [
{
accessorKey: "maintainer",
id: "maintainer",
enableColumnFilter: true,
header: "Maintainer",
enableSorting: true,
},
{
accessorKey: "modelName",
Expand All @@ -69,6 +79,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.modelName,
},
enableSorting: true,
},
{
accessorKey: "startDate",
Expand All @@ -77,6 +88,8 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.startDate,
},
enableHiding: true,
enableSorting: true,
cell: ({ row }) => {
const value: Date | undefined = row.getValue("startDate");

Expand All @@ -94,6 +107,8 @@ export default function ModelTable({ projectId }: { projectId: string }) {
description: modelConfigDescriptions.matchPattern,
},
header: "Match Pattern",
enableHiding: true,
enableSorting: true,
cell: ({ row }) => {
const value: string = row.getValue("matchPattern");

Expand All @@ -118,6 +133,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.inputPrice,
},
enableSorting: true,
cell: ({ row }) => {
const value: Decimal | undefined = row.getValue("inputPrice");

Expand All @@ -144,6 +160,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
</>
);
},
enableSorting: true,
cell: ({ row }) => {
const value: Decimal | undefined = row.getValue("outputPrice");

Expand All @@ -170,6 +187,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.totalPrice,
},
enableSorting: true,
cell: ({ row }) => {
const value: Decimal | undefined = row.getValue("totalPrice");

Expand All @@ -189,6 +207,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.unit,
},
enableSorting: true,
enableHiding: true,
},
{
Expand All @@ -198,6 +217,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
headerTooltip: {
description: modelConfigDescriptions.tokenizerId,
},
enableSorting: true,
enableHiding: true,
},
{
Expand Down Expand Up @@ -235,7 +255,7 @@ export default function ModelTable({ projectId }: { projectId: string }) {
];

const [columnVisibility, setColumnVisibility] =
useColumnVisibility<ModelTableRow>("scoresColumnVisibility", columns);
useColumnVisibility<ModelTableRow>("modelsColumnVisibility", columns);

const convertToTableRow = (model: Model): ModelTableRow => {
return {
Expand All @@ -257,6 +277,14 @@ export default function ModelTable({ projectId }: { projectId: string }) {

return (
<div>
<DataTableToolbar
columns={columns}
filterColumnDefinition={modelsTableCols}
filterState={filterState}
setFilterState={setFilterState}
columnVisibility={columnVisibility}
setColumnVisibility={setColumnVisibility}
/>
<DataTable
columns={columns}
data={
Expand All @@ -279,6 +307,8 @@ export default function ModelTable({ projectId }: { projectId: string }) {
onChange: setPaginationState,
state: paginationState,
}}
orderBy={orderByState}
setOrderBy={setOrderByState}
columnVisibility={columnVisibility}
onColumnVisibilityChange={setColumnVisibility}
/>
Expand Down
19 changes: 14 additions & 5 deletions web/src/features/filters/components/filter-builder.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,20 @@ function FilterBuilderForm({
<SelectValue placeholder="" />
</SelectTrigger>
<SelectContent>
{["true", "false"].map((option) => (
<SelectItem key={option} value={option}>
{option}
</SelectItem>
))}
{filter.column === "Maintainer"
RichardKruemmel marked this conversation as resolved.
Show resolved Hide resolved
? ["User", "Langfuse"].map((option) => (
<SelectItem
key={option}
value={option === "User" ? "true" : "false"}
>
{option}
</SelectItem>
))
: ["true", "false"].map((option) => (
<SelectItem key={option} value={option}>
{option}
</SelectItem>
))}
</SelectContent>
</Select>
) : (
Expand Down
2 changes: 2 additions & 0 deletions web/src/features/filters/hooks/useFilterState.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { type FilterState, type TableName } from "@/src/features/filters/types";
import { modelsTableCols } from "@/src/server/api/definitions/modelsTable";
import { observationsTableCols } from "@/src/server/api/definitions/observationsTable";
import { scoresTableCols } from "@/src/server/api/definitions/scoresTable";
import { sessionsViewCols } from "@/src/server/api/definitions/sessionsView";
Expand Down Expand Up @@ -102,6 +103,7 @@ const tableCols = {
traces: tracesTableCols,
sessions: sessionsViewCols,
scores: scoresTableCols,
models: modelsTableCols,
dashboard: [{ id: "traceName", name: "traceName" }],
};

Expand Down
1 change: 1 addition & 0 deletions web/src/features/filters/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ export type TableName =
| "generations"
| "sessions"
| "scores"
| "models"
| "dashboard";
53 changes: 53 additions & 0 deletions web/src/server/api/definitions/modelsTable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { type ColumnDefinition } from "@/src/server/api/interfaces/tableDefinition";

export const modelsTableCols: ColumnDefinition[] = [
{
name: "Maintainer",
id: "maintainer",
type: "boolean",
internal: '(m."project_id" IS NOT NULL)',
},
{
name: "Model Name",
id: "modelName",
type: "string",
internal: 'm."model_name"',
},
{
name: "Match Pattern",
id: "matchPattern",
type: "string",
internal: 'm."match_pattern"',
},
{
name: "Start Date",
id: "startDate",
type: "datetime",
internal: 'm."start_date"',
},
{
name: "Input Price",
id: "inputPrice",
type: "number",
internal: 'm."input_price"',
},
{
name: "Output Price",
id: "outputPrice",
type: "number",
internal: 'm."output_price"',
},
{
name: "Total Price",
id: "totalPrice",
type: "number",
internal: 'm."total_price"',
},
{ name: "Unit", id: "unit", type: "string", internal: 'm."unit"' },
{
name: "Tokenizer",
id: "tokenizerId",
type: "string",
internal: 'm."tokenizer_id"',
},
];
1 change: 1 addition & 0 deletions web/src/server/api/interfaces/tableDefinition.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export const tableNames = [
"traces_metrics",
"traces_parent_observation_scores",
"sessions",
"models",
] as const;

export type TableNames = (typeof tableNames)[number];
Expand Down
96 changes: 72 additions & 24 deletions web/src/server/api/routers/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,73 @@ import {
protectedProjectProcedure,
} from "@/src/server/api/trpc";
import { paginationZod } from "@/src/utils/zod";
import { type Model, Prisma } from "@langfuse/shared/src/db";
import { TRPCError } from "@trpc/server";
import { auditLog } from "@/src/features/audit-logs/auditLog";
import { Prisma } from "@langfuse/shared/src/db";
import { singleFilter } from "@/src/server/api/interfaces/filters";
import { tableColumnsToSqlFilterAndPrefix } from "@/src/features/filters/server/filterToPrisma";
import { modelsTableCols } from "@/src/server/api/definitions/modelsTable";
import { orderBy } from "@/src/server/api/interfaces/orderBy";
import { orderByToPrismaSql } from "@/src/features/orderBy/server/orderByToPrisma";

const ModelAllOptions = z.object({
projectId: z.string(),
filter: z.array(singleFilter),
orderBy: orderBy,
...paginationZod,
});

export const modelRouter = createTRPCRouter({
all: protectedProjectProcedure
.input(ModelAllOptions)
.query(async ({ input, ctx }) => {
const models = await ctx.prisma.model.findMany({
where: {
OR: [{ projectId: input.projectId }, { projectId: null }],
},
skip: input.page * input.limit,
orderBy: [
{ modelName: "asc" },
{ unit: "asc" },
{
startDate: {
sort: "desc",
nulls: "last",
},
},
],
take: input.limit,
});
const filterCondition = tableColumnsToSqlFilterAndPrefix(
input.filter,
modelsTableCols,
"models",
);

const totalAmount = await ctx.prisma.model.count({
where: {
OR: [{ projectId: input.projectId }, { projectId: null }],
},
});
const orderByCondition = orderByToPrismaSql(
input.orderBy,
modelsTableCols,
);

const models = await ctx.prisma.$queryRaw<Array<Model>>(
generateModelsQuery(
Prisma.sql`
m.id,
m.project_id as "projectId",
m.model_name as "modelName",
m.match_pattern as "matchPattern",
m.start_date as "startDate",
m.input_price as "inputPrice",
m.output_price as "outputPrice",
m.total_price as "totalPrice",
m.unit,
m.tokenizer_id as "tokenizerId"`,
input.projectId,
filterCondition,
orderByCondition,
input.limit,
input.page,
),
);
const totalAmount = await ctx.prisma.$queryRaw<
Array<{ totalCount: bigint }>
>(
generateModelsQuery(
Prisma.sql` count(*) AS "totalCount"`,
input.projectId,
filterCondition,
Prisma.empty,
1, // limit
0, // page
),
);
return {
models,
totalCount: totalAmount,
totalCount:
totalAmount.length > 0 ? Number(totalAmount[0]?.totalCount) : 0,
};
}),
modelNames: protectedProjectProcedure
Expand Down Expand Up @@ -158,3 +186,23 @@ export const modelRouter = createTRPCRouter({
return createdModel;
}),
});

const generateModelsQuery = (
select: Prisma.Sql,
projectId: string,
filterCondition: Prisma.Sql,
orderCondition: Prisma.Sql,
limit: number,
page: number,
) => {
return Prisma.sql`
SELECT
${select}
FROM models m
WHERE (m.project_id = ${projectId} OR m.project_id IS NULL)
${filterCondition}
${orderCondition}
LIMIT ${limit}
OFFSET ${page * limit}
`;
};