Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prompts): add table to compare metrics across prompt versions, [promptName]/metrics #1991

Merged
merged 29 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
cf122ab
feat: add tabs on prompt page to switch editor vs table view
marliessophie May 6, 2024
ea8bf87
feat: add `metrics` prompt procedure
marliessophie May 7, 2024
8b49963
feat: add `verison-table` page incl. all but score data
marliessophie May 7, 2024
fd1025c
feat: add linked cell for prompt version
marliessophie May 7, 2024
8be092a
refactor: adjust label cell styling
marliessophie May 7, 2024
b220b43
feat: add average observation score to table
marliessophie May 8, 2024
a0c585d
feat: add average trace score to table
marliessophie May 8, 2024
bc7ab4f
docs: add comment to observation schema
marliessophie May 8, 2024
c6ad847
Merge branch 'main' into marlies/lfe-975
marliessophie May 8, 2024
5efa0d9
fix: metrics query
marliessophie May 8, 2024
5cc3afe
Merge branch 'main' into marlies/lfe-975
maxdeichmann May 9, 2024
0968ee4
fix: observation scores query and display on UI
marliessophie May 10, 2024
13cf680
fix: only check `generations` type
marliessophie May 10, 2024
7a86e73
feat: add traces query
marliessophie May 10, 2024
e98a0ae
fix: using wrong variable bc of typo
marliessophie May 10, 2024
a704079
fix: format latency with 3 decimals
marliessophie May 10, 2024
b5b9b7b
query update
marcklingen May 10, 2024
6783f46
feat: add toggle columns to show/hide cols
marliessophie May 10, 2024
4677394
fix: add gap between labels
marliessophie May 10, 2024
e8a3531
feat: add pagination for `prompts.metrics`
marliessophie May 10, 2024
7fbc25d
feat: add optional pagination for `prompts.allVersions`
marliessophie May 10, 2024
b6801c3
feat: set default version table pagination to 50
marliessophie May 10, 2024
991b9f6
fix: pagination
marliessophie May 10, 2024
3ddadb0
skip batch
marliessophie May 10, 2024
762197a
feat: add cell loading skeleton
marliessophie May 10, 2024
821ef3d
Merge branch 'main' into marlies/lfe-975
marliessophie May 10, 2024
a299611
rename table
marcklingen May 14, 2024
7918d20
Merge branch 'main' into marlies/lfe-975
marcklingen May 14, 2024
773a59d
rename to metrics
marcklingen May 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/shared/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ view TraceView {
}

// Update ObservationView below when making changes to this model!
// traceId is optional only due to timing during data injestion
// (traceId is not necessarily known at the time of observation creation)
model Observation {
id String @id @default(cuid())
traceId String? @map("trace_id")
Expand Down
21 changes: 17 additions & 4 deletions web/src/features/prompts/components/prompt-detail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
ChatMlArraySchema,
OpenAiMessageView,
} from "@/src/components/trace/IOPreview";
import { Tabs, TabsList, TabsTrigger } from "@/src/components/ui/tabs";
import { Badge } from "@/src/components/ui/badge";
import { Button } from "@/src/components/ui/button";
import { CodeView, JSONView } from "@/src/components/ui/CodeJsonViewer";
Expand Down Expand Up @@ -49,10 +50,10 @@ export const PromptDetail = () => {
{ enabled: Boolean(projectId) },
);
const prompt = currentPromptVersion
? promptHistory.data?.find(
? promptHistory.data?.promptVersions.find(
(prompt) => prompt.version === currentPromptVersion,
)
: promptHistory.data?.[0];
: promptHistory.data?.promptVersions[0];

const extractedVariables = prompt
? extractVariables(JSON.stringify(prompt.prompt))
Expand Down Expand Up @@ -143,14 +144,26 @@ export const PromptDetail = () => {
<DeletePromptVersion
promptVersionId={prompt.id}
version={prompt.version}
countVersions={promptHistory.data.length}
countVersions={promptHistory.data.totalCount}
/>
<DetailPageNav
key="nav"
currentId={promptName}
path={(name) => `/project/${projectId}/prompts/${name}`}
listKey="prompts"
/>
<Tabs value="editor">
<TabsList>
<TabsTrigger value="editor">Editor</TabsTrigger>
<TabsTrigger value="metrics" asChild>
<Link
href={`/project/${projectId}/prompts/${encodeURIComponent(promptName)}/metrics`}
>
Metrics
</Link>
</TabsTrigger>
</TabsList>
</Tabs>
</>
}
/>
Expand Down Expand Up @@ -229,7 +242,7 @@ export const PromptDetail = () => {
<div className="text-m px-3 font-medium">
<ScrollArea className="flex border-l pl-2">
<PromptHistoryNode
prompts={promptHistory.data}
prompts={promptHistory.data.promptVersions}
currentPromptVersion={prompt.version}
setCurrentPromptVersion={setCurrentPromptVersion}
/>
Expand Down
4 changes: 2 additions & 2 deletions web/src/features/prompts/components/prompt-history.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { type NextRouter, useRouter } from "next/router";

const PromptHistoryTraceNode = (props: {
index: number;
prompt: RouterOutputs["prompts"]["allVersions"][number];
prompt: RouterOutputs["prompts"]["allVersions"]["promptVersions"][number];
currentPromptVersion: number | undefined;
setCurrentPromptVersion: (version: number | undefined) => void;
router: NextRouter;
Expand Down Expand Up @@ -57,7 +57,7 @@ const PromptHistoryTraceNode = (props: {
};

export const PromptHistoryNode = (props: {
prompts: RouterOutputs["prompts"]["allVersions"];
prompts: RouterOutputs["prompts"]["allVersions"]["promptVersions"];
currentPromptVersion: number | undefined;
setCurrentPromptVersion: (id: number | undefined) => void;
}) => {
Expand Down
183 changes: 180 additions & 3 deletions web/src/features/prompts/server/routers/promptRouter.ts
marliessophie marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { type Prompt, Prisma } from "@langfuse/shared/src/db";
import { createPrompt } from "../actions/createPrompt";
import { orderByToPrismaSql } from "@/src/features/orderBy/server/orderByToPrisma";
import { promptsTableCols } from "@/src/server/api/definitions/promptsTable";
import { paginationZod } from "@/src/utils/zod";
import { optionalPaginationZod, paginationZod } from "@/src/utils/zod";
import {
orderBy,
singleFilter,
Expand Down Expand Up @@ -493,7 +493,13 @@ export const promptRouter = createTRPCRouter({
}
}),
allVersions: protectedProjectProcedure
.input(z.object({ projectId: z.string(), name: z.string() }))
.input(
z.object({
projectId: z.string(),
name: z.string(),
...optionalPaginationZod,
}),
)
.query(async ({ input, ctx }) => {
throwIfNoAccess({
session: ctx.session,
Expand All @@ -505,8 +511,19 @@ export const promptRouter = createTRPCRouter({
projectId: input.projectId,
name: input.name,
},
...(input.limit !== undefined && input.page !== undefined
? { take: input.limit, skip: input.page * input.limit }
: undefined),
orderBy: [{ version: "desc" }],
});

const totalCount = await ctx.prisma.prompt.count({
where: {
projectId: input.projectId,
name: input.name,
},
});

const userIds = prompts
.map((p) => p.createdBy)
.filter((id) => id !== "API");
Expand Down Expand Up @@ -538,7 +555,167 @@ export const promptRouter = createTRPCRouter({
creator: user?.name,
};
});
return joinedPromptAndUsers;
return { promptVersions: joinedPromptAndUsers, totalCount };
}),
metrics: protectedProjectProcedure
.input(
z.object({
projectId: z.string(),
promptIds: z.array(z.string()),
}),
)
.query(async ({ input, ctx }) => {
throwIfNoAccess({
session: ctx.session,
projectId: input.projectId,
scope: "prompts:read",
});

const metrics = await ctx.prisma.$queryRaw<
Array<{
id: string;
observationCount: number;
firstUsed: Date | null;
lastUsed: Date | null;
medianOutputTokens: number | null;
medianInputTokens: number | null;
medianTotalCost: number | null;
medianLatency: number | null;
}>
>(
Prisma.sql`
select p.id, p.version, observation_metrics.* from prompts p
LEFT JOIN LATERAL (
SELECT
count(*) AS "observationCount",
MIN(ov.start_time) AS "firstUsed",
MAX(ov.start_time) AS "lastUsed",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.completion_tokens) AS "medianOutputTokens",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.prompt_tokens) AS "medianInputTokens",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.calculated_total_cost) AS "medianTotalCost",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.latency) AS "medianLatency"
FROM
"observations_view" ov
WHERE
ov.prompt_id = p.id
AND "type" = 'GENERATION'
AND "project_id" = ${input.projectId}
) AS observation_metrics ON true
WHERE "project_id" = ${input.projectId}
AND p.id in (${Prisma.join(input.promptIds)})
ORDER BY version DESC
`,
);

const averageObservationScores = await ctx.prisma.$queryRaw<
Array<{
prompt_id: string;
scores: Record<string, number>;
}>
>(
Prisma.sql`
WITH avg_scores_by_prompt AS (
SELECT
o.prompt_id AS prompt_id,
s.name AS score_name,
AVG(s.value) AS average_score_value
FROM observations AS o
JOIN prompts AS p ON o.prompt_id = p.id AND p.project_id = ${input.projectId}
LEFT JOIN scores s ON o.trace_id = s.trace_id AND s.observation_id = o.id AND s.project_id = ${input.projectId}
WHERE
o.type = 'GENERATION'
AND o.prompt_id IS NOT NULL
AND o.project_id = ${input.projectId}
AND p.id IN (${Prisma.join(input.promptIds)})
GROUP BY 1,2
ORDER BY 1,2
),
json_avg_scores_by_prompt_id AS (
SELECT
prompt_id,
jsonb_object_agg(score_name,
average_score_value) AS scores
FROM
avg_scores_by_prompt AS avgs
WHERE
avgs.score_name IS NOT NULL
AND avgs.average_score_value IS NOT NULL
GROUP BY prompt_id
ORDER BY prompt_id
)
SELECT *
FROM json_avg_scores_by_prompt_id`,
);

const averageTraceScores = await ctx.prisma.$queryRaw<
Array<{
prompt_id: string;
scores: Record<string, number>;
}>
>(
Prisma.sql`
WITH traces_by_prompt_id AS (
SELECT
o.prompt_id,
o.trace_id
FROM
observations o
WHERE
o.prompt_id IS NOT NULL
AND o.type = 'GENERATION'
AND o.project_id = ${input.projectId}
AND o.prompt_id IN (${Prisma.join(input.promptIds)})
GROUP BY
o.prompt_id,
o.trace_id
), scores_by_trace AS (
SELECT
tp.prompt_id,
tp.trace_id,
p.version,
s.name AS score_name,
s.value AS score_value
FROM
traces_by_prompt_id tp
JOIN prompts AS p ON tp.prompt_id = p.id AND p.project_id = ${input.projectId}
LEFT JOIN scores s ON tp.trace_id = s.trace_id AND s.observation_id IS NULL AND s.project_id = ${input.projectId}
), average_scores_by_prompt AS (
SELECT
prompt_id,
score_name,
AVG(score_value) AS average_score_value
FROM
scores_by_trace
GROUP BY 1,2
), json_avg_scores_by_prompt_id AS (
SELECT
prompt_id,
jsonb_object_agg(score_name,
average_score_value) AS scores
FROM
average_scores_by_prompt
WHERE
score_name IS NOT NULL
AND average_score_value IS NOT NULL
GROUP BY
prompt_id
ORDER BY
prompt_id
)
SELECT *
FROM json_avg_scores_by_prompt_id
`,
);

return metrics.map((metric) => ({
...metric,
averageObservationScores: averageObservationScores.find(
(score) => score.prompt_id === metric.id,
)?.scores,
averageTraceScores: averageTraceScores.find(
(score) => score.prompt_id === metric.id,
)?.scores,
}));
}),
});

Expand Down
13 changes: 8 additions & 5 deletions web/src/features/tag/components/TagPromptDetailsPopover.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ export function TagPromptDetailsPopover({
{ projectId: projectId, name: promptName },
(oldQueryData: RouterOutput["prompts"]["allVersions"] | undefined) => {
return oldQueryData
? oldQueryData.map((prompt) => {
return prompt.name === promptName
? { ...prompt, tags }
: prompt;
})
? {
promptVersions: oldQueryData.promptVersions.map((prompt) => {
return prompt.name === promptName
? { ...prompt, tags }
: prompt;
}),
totalCount: oldQueryData.totalCount,
}
: undefined;
},
);
Expand Down