Skip to content

Commit

Permalink
feat(prompts): add table to compare metrics across prompt versions, `…
Browse files Browse the repository at this point in the history
…[promptName]/metrics` (#1991)



---------

Co-authored-by: Marc Klingen <git@marcklingen.com>
  • Loading branch information
marliessophie and marcklingen committed May 14, 2024
1 parent ad10e6f commit 2990767
Show file tree
Hide file tree
Showing 8 changed files with 646 additions and 14 deletions.
2 changes: 2 additions & 0 deletions packages/shared/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ view TraceView {
}

// Update ObservationView below when making changes to this model!
// traceId is optional only due to timing during data injestion
// (traceId is not necessarily known at the time of observation creation)
model Observation {
id String @id @default(cuid())
traceId String? @map("trace_id")
Expand Down
21 changes: 17 additions & 4 deletions web/src/features/prompts/components/prompt-detail.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
ChatMlArraySchema,
OpenAiMessageView,
} from "@/src/components/trace/IOPreview";
import { Tabs, TabsList, TabsTrigger } from "@/src/components/ui/tabs";
import { Badge } from "@/src/components/ui/badge";
import { Button } from "@/src/components/ui/button";
import { CodeView, JSONView } from "@/src/components/ui/CodeJsonViewer";
Expand Down Expand Up @@ -49,10 +50,10 @@ export const PromptDetail = () => {
{ enabled: Boolean(projectId) },
);
const prompt = currentPromptVersion
? promptHistory.data?.find(
? promptHistory.data?.promptVersions.find(
(prompt) => prompt.version === currentPromptVersion,
)
: promptHistory.data?.[0];
: promptHistory.data?.promptVersions[0];

const extractedVariables = prompt
? extractVariables(JSON.stringify(prompt.prompt))
Expand Down Expand Up @@ -143,14 +144,26 @@ export const PromptDetail = () => {
<DeletePromptVersion
promptVersionId={prompt.id}
version={prompt.version}
countVersions={promptHistory.data.length}
countVersions={promptHistory.data.totalCount}
/>
<DetailPageNav
key="nav"
currentId={promptName}
path={(name) => `/project/${projectId}/prompts/${name}`}
listKey="prompts"
/>
<Tabs value="editor">
<TabsList>
<TabsTrigger value="editor">Editor</TabsTrigger>
<TabsTrigger value="metrics" asChild>
<Link
href={`/project/${projectId}/prompts/${encodeURIComponent(promptName)}/metrics`}
>
Metrics
</Link>
</TabsTrigger>
</TabsList>
</Tabs>
</>
}
/>
Expand Down Expand Up @@ -229,7 +242,7 @@ export const PromptDetail = () => {
<div className="text-m px-3 font-medium">
<ScrollArea className="flex border-l pl-2">
<PromptHistoryNode
prompts={promptHistory.data}
prompts={promptHistory.data.promptVersions}
currentPromptVersion={prompt.version}
setCurrentPromptVersion={setCurrentPromptVersion}
/>
Expand Down
4 changes: 2 additions & 2 deletions web/src/features/prompts/components/prompt-history.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { type NextRouter, useRouter } from "next/router";

const PromptHistoryTraceNode = (props: {
index: number;
prompt: RouterOutputs["prompts"]["allVersions"][number];
prompt: RouterOutputs["prompts"]["allVersions"]["promptVersions"][number];
currentPromptVersion: number | undefined;
setCurrentPromptVersion: (version: number | undefined) => void;
router: NextRouter;
Expand Down Expand Up @@ -57,7 +57,7 @@ const PromptHistoryTraceNode = (props: {
};

export const PromptHistoryNode = (props: {
prompts: RouterOutputs["prompts"]["allVersions"];
prompts: RouterOutputs["prompts"]["allVersions"]["promptVersions"];
currentPromptVersion: number | undefined;
setCurrentPromptVersion: (id: number | undefined) => void;
}) => {
Expand Down
183 changes: 180 additions & 3 deletions web/src/features/prompts/server/routers/promptRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { type Prompt, Prisma } from "@langfuse/shared/src/db";
import { createPrompt } from "../actions/createPrompt";
import { orderByToPrismaSql } from "@/src/features/orderBy/server/orderByToPrisma";
import { promptsTableCols } from "@/src/server/api/definitions/promptsTable";
import { paginationZod } from "@/src/utils/zod";
import { optionalPaginationZod, paginationZod } from "@/src/utils/zod";
import {
orderBy,
singleFilter,
Expand Down Expand Up @@ -493,7 +493,13 @@ export const promptRouter = createTRPCRouter({
}
}),
allVersions: protectedProjectProcedure
.input(z.object({ projectId: z.string(), name: z.string() }))
.input(
z.object({
projectId: z.string(),
name: z.string(),
...optionalPaginationZod,
}),
)
.query(async ({ input, ctx }) => {
throwIfNoAccess({
session: ctx.session,
Expand All @@ -505,8 +511,19 @@ export const promptRouter = createTRPCRouter({
projectId: input.projectId,
name: input.name,
},
...(input.limit !== undefined && input.page !== undefined
? { take: input.limit, skip: input.page * input.limit }
: undefined),
orderBy: [{ version: "desc" }],
});

const totalCount = await ctx.prisma.prompt.count({
where: {
projectId: input.projectId,
name: input.name,
},
});

const userIds = prompts
.map((p) => p.createdBy)
.filter((id) => id !== "API");
Expand Down Expand Up @@ -538,7 +555,167 @@ export const promptRouter = createTRPCRouter({
creator: user?.name,
};
});
return joinedPromptAndUsers;
return { promptVersions: joinedPromptAndUsers, totalCount };
}),
metrics: protectedProjectProcedure
.input(
z.object({
projectId: z.string(),
promptIds: z.array(z.string()),
}),
)
.query(async ({ input, ctx }) => {
throwIfNoAccess({
session: ctx.session,
projectId: input.projectId,
scope: "prompts:read",
});

const metrics = await ctx.prisma.$queryRaw<
Array<{
id: string;
observationCount: number;
firstUsed: Date | null;
lastUsed: Date | null;
medianOutputTokens: number | null;
medianInputTokens: number | null;
medianTotalCost: number | null;
medianLatency: number | null;
}>
>(
Prisma.sql`
select p.id, p.version, observation_metrics.* from prompts p
LEFT JOIN LATERAL (
SELECT
count(*) AS "observationCount",
MIN(ov.start_time) AS "firstUsed",
MAX(ov.start_time) AS "lastUsed",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.completion_tokens) AS "medianOutputTokens",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.prompt_tokens) AS "medianInputTokens",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.calculated_total_cost) AS "medianTotalCost",
PERCENTILE_CONT(0.5) WITHIN GROUP(ORDER BY ov.latency) AS "medianLatency"
FROM
"observations_view" ov
WHERE
ov.prompt_id = p.id
AND "type" = 'GENERATION'
AND "project_id" = ${input.projectId}
) AS observation_metrics ON true
WHERE "project_id" = ${input.projectId}
AND p.id in (${Prisma.join(input.promptIds)})
ORDER BY version DESC
`,
);

const averageObservationScores = await ctx.prisma.$queryRaw<
Array<{
prompt_id: string;
scores: Record<string, number>;
}>
>(
Prisma.sql`
WITH avg_scores_by_prompt AS (
SELECT
o.prompt_id AS prompt_id,
s.name AS score_name,
AVG(s.value) AS average_score_value
FROM observations AS o
JOIN prompts AS p ON o.prompt_id = p.id AND p.project_id = ${input.projectId}
LEFT JOIN scores s ON o.trace_id = s.trace_id AND s.observation_id = o.id AND s.project_id = ${input.projectId}
WHERE
o.type = 'GENERATION'
AND o.prompt_id IS NOT NULL
AND o.project_id = ${input.projectId}
AND p.id IN (${Prisma.join(input.promptIds)})
GROUP BY 1,2
ORDER BY 1,2
),
json_avg_scores_by_prompt_id AS (
SELECT
prompt_id,
jsonb_object_agg(score_name,
average_score_value) AS scores
FROM
avg_scores_by_prompt AS avgs
WHERE
avgs.score_name IS NOT NULL
AND avgs.average_score_value IS NOT NULL
GROUP BY prompt_id
ORDER BY prompt_id
)
SELECT *
FROM json_avg_scores_by_prompt_id`,
);

const averageTraceScores = await ctx.prisma.$queryRaw<
Array<{
prompt_id: string;
scores: Record<string, number>;
}>
>(
Prisma.sql`
WITH traces_by_prompt_id AS (
SELECT
o.prompt_id,
o.trace_id
FROM
observations o
WHERE
o.prompt_id IS NOT NULL
AND o.type = 'GENERATION'
AND o.project_id = ${input.projectId}
AND o.prompt_id IN (${Prisma.join(input.promptIds)})
GROUP BY
o.prompt_id,
o.trace_id
), scores_by_trace AS (
SELECT
tp.prompt_id,
tp.trace_id,
p.version,
s.name AS score_name,
s.value AS score_value
FROM
traces_by_prompt_id tp
JOIN prompts AS p ON tp.prompt_id = p.id AND p.project_id = ${input.projectId}
LEFT JOIN scores s ON tp.trace_id = s.trace_id AND s.observation_id IS NULL AND s.project_id = ${input.projectId}
), average_scores_by_prompt AS (
SELECT
prompt_id,
score_name,
AVG(score_value) AS average_score_value
FROM
scores_by_trace
GROUP BY 1,2
), json_avg_scores_by_prompt_id AS (
SELECT
prompt_id,
jsonb_object_agg(score_name,
average_score_value) AS scores
FROM
average_scores_by_prompt
WHERE
score_name IS NOT NULL
AND average_score_value IS NOT NULL
GROUP BY
prompt_id
ORDER BY
prompt_id
)
SELECT *
FROM json_avg_scores_by_prompt_id
`,
);

return metrics.map((metric) => ({
...metric,
averageObservationScores: averageObservationScores.find(
(score) => score.prompt_id === metric.id,
)?.scores,
averageTraceScores: averageTraceScores.find(
(score) => score.prompt_id === metric.id,
)?.scores,
}));
}),
});

Expand Down
13 changes: 8 additions & 5 deletions web/src/features/tag/components/TagPromptDetailsPopover.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ export function TagPromptDetailsPopover({
{ projectId: projectId, name: promptName },
(oldQueryData: RouterOutput["prompts"]["allVersions"] | undefined) => {
return oldQueryData
? oldQueryData.map((prompt) => {
return prompt.name === promptName
? { ...prompt, tags }
: prompt;
})
? {
promptVersions: oldQueryData.promptVersions.map((prompt) => {
return prompt.name === promptName
? { ...prompt, tags }
: prompt;
}),
totalCount: oldQueryData.totalCount,
}
: undefined;
},
);
Expand Down

0 comments on commit 2990767

Please sign in to comment.