perf(scores): remove fk constraint from traceId and observationId on scores, change to projectId #2012

Merged: 33 commits, May 16, 2024
Changes from 26 commits
Commits
90d605d
add project id column
marcklingen May 8, 2024
f7ec422
add projectid to scores created via api
marcklingen May 8, 2024
21748a1
add projectid to scores created by eval service
marcklingen May 8, 2024
282f3d1
add projectid to review scores via trpc router
marcklingen May 8, 2024
fce3f62
chore: backfill projectid on score
marcklingen May 8, 2024
7517a74
add index on project id
marcklingen May 8, 2024
8d363a4
chore: add fk to projectid on scores and set non null
marcklingen May 8, 2024
1dd66eb
Merge branch 'main' into marc/lfe-1116-fk-projectid-scores
marliessophie May 10, 2024
bfd5798
revert changes on this file, eslint uninted auto changes
marliessophie May 10, 2024
b08d1de
Merge branch 'main' into marc/lfe-1116-fk-projectid-scores
marcklingen May 13, 2024
d97cb2a
move migration
marcklingen May 13, 2024
a6c033d
fix prisma fk queries
marcklingen May 13, 2024
900a309
chore(scores): drop fk constraint on traces/observations (#2014)
marcklingen May 13, 2024
85ae235
fix misc
marcklingen May 13, 2024
60557e5
ingestion api fix
marcklingen May 13, 2024
8153ead
fixes
marcklingen May 13, 2024
c810805
public api
marcklingen May 13, 2024
69215df
trpc update
marcklingen May 13, 2024
10ca67f
fix post to use score access level
marcklingen May 13, 2024
7982da3
misc
marcklingen May 13, 2024
c109357
Merge branch 'main' into marc/lfe-1116-fk-projectid-scores
marcklingen May 13, 2024
6a8d58f
move migrations
marcklingen May 13, 2024
4e784ec
update worker types while this is necessary (hotfix for ci)
marcklingen May 13, 2024
30da76d
fix total items in get API
marcklingen May 13, 2024
ef761f9
delete score when deleting traces
marcklingen May 13, 2024
759f9f5
cleanup
marcklingen May 13, 2024
d6d95d4
fix api access check
marcklingen May 13, 2024
dbf8b33
concurrent unique index creation
marcklingen May 15, 2024
7576298
changes based on review
marcklingen May 15, 2024
76eba39
Merge branch 'main' into marc/lfe-1116-fk-projectid-scores
marcklingen May 15, 2024
372a2b5
fix test
marcklingen May 15, 2024
e212bad
remove tests which do not appply anymore
marcklingen May 15, 2024
135310e
Merge branch 'main' into marc/lfe-1116-fk-projectid-scores
marcklingen May 16, 2024
2 changes: 1 addition & 1 deletion packages/shared/prisma/generated/types.ts
@@ -342,7 +342,7 @@ export type Prompt = {
 export type Score = {
 id: string;
 timestamp: Generated<Timestamp>;
-project_id: string | null;
+project_id: string;
 name: string;
 value: number;
 source: ScoreSource;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
+-- DropForeignKey
+ALTER TABLE "scores" DROP CONSTRAINT "scores_observation_id_fkey";
+
+-- DropForeignKey
+ALTER TABLE "scores" DROP CONSTRAINT "scores_trace_id_fkey";
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
+/*
+Warnings:
+
+- Made the column `project_id` on table `scores` required. This step will fail if there are existing NULL values in that column.
+
+*/
+-- AlterTable
+ALTER TABLE "scores" ALTER COLUMN "project_id" SET NOT NULL;
+
+-- AddForeignKey
+ALTER TABLE "scores" ADD CONSTRAINT "scores_project_id_fkey" FOREIGN KEY ("project_id") REFERENCES "projects"("id") ON DELETE CASCADE ON UPDATE CASCADE;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
+/*
+Warnings:
+
+- A unique constraint covering the columns `[id,project_id]` on the table `scores` will be added. If there are existing duplicate values, this will fail.
+
+*/
+-- DropIndex
+DROP INDEX "scores_id_trace_id_key";
+
+-- CreateIndex
+CREATE UNIQUE INDEX "scores_id_project_id_key" ON "scores"("id", "project_id");
10 changes: 4 additions & 6 deletions packages/shared/prisma/schema.prisma
@@ -113,6 +113,7 @@ model Project {
 JobExecution JobExecution[]
 LlmApiKeys LlmApiKeys[]
 PosthogIntegration PosthogIntegration[]
+Score Score[]

 @@map("projects")
 }
@@ -230,7 +231,6 @@ model Trace {
 createdAt DateTime @default(now()) @map("created_at")
 updatedAt DateTime @default(now()) @updatedAt @map("updated_at")

-scores Score[]
 DatasetRunItems DatasetRunItems[]
 DatasetItem DatasetItem[]
 JobExecution JobExecution[]
@@ -305,7 +305,6 @@ model Observation {
 outputCost Decimal? @map("output_cost")
 totalCost Decimal? @map("total_cost")
 completionStartTime DateTime? @map("completion_start_time")
-scores Score[]
 project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
 derivedDatasetItems DatasetItem[]
 datasetRunItems DatasetRunItems[]
@@ -392,18 +391,17 @@ enum ObservationLevel {
 model Score {
 id String @id @default(cuid())
 timestamp DateTime @default(now())
-projectId String? @map("project_id")
+projectId String @map("project_id")
+project Project @relation(fields: [projectId], references: [id], onDelete: Cascade)
 name String
 value Float
 source ScoreSource
 comment String?
 traceId String @map("trace_id")
-trace Trace @relation(fields: [traceId], references: [id], onDelete: Cascade)
 observationId String? @map("observation_id")
-observation Observation? @relation(fields: [observationId], references: [id], onDelete: SetNull)
 JobExecution JobExecution[]

-@@unique([id, traceId]) // used for upsert
+@@unique([id, projectId]) // used for upserts via prisma
 @@index(timestamp)
 @@index([value])
 @@index([projectId])
1 change: 1 addition & 0 deletions web/src/__tests__/ingestion.servertest.ts
@@ -281,6 +281,7 @@ describe("/api/public/ingestion API Endpoint", () => {
 expect(dbScore?.name).toBe("score-name");
 expect(dbScore?.value).toBe(100.5);
 expect(dbScore?.observationId).toBeNull();
+expect(dbScore?.projectId).toBe("7a88fb47-b4e2-43b8-a06c-a5ce950dc53a");
 });
 });

1 change: 1 addition & 0 deletions web/src/__tests__/scores.servertest.ts
@@ -60,6 +60,7 @@ describe("/api/public/scores API Endpoint", () => {
 expect(dbScore?.observationId).toBeNull();
 expect(dbScore?.comment).toBe("comment");
 expect(dbScore?.source).toBe("API");
+expect(dbScore?.projectId).toBe("7a88fb47-b4e2-43b8-a06c-a5ce950dc53a");
 });

 it("should create score for a trace with int", async () => {
48 changes: 34 additions & 14 deletions web/src/features/datasets/server/dataset-router.ts
@@ -147,6 +147,7 @@ export const datasetRouter = createTRPCRouter({
 JOIN scores s
 ON s.trace_id = ri.trace_id
 AND (ri.observation_id IS NULL OR s.observation_id = ri.observation_id) -- only include scores that are linked to the observation if observation is linked
+AND s.project_id = ${input.projectId}
 JOIN traces t ON t.id = s.trace_id
 WHERE t.project_id = ${input.projectId}
 GROUP BY
@@ -553,13 +554,11 @@ export const datasetRouter = createTRPCRouter({
 observation: {
 select: {
 id: true,
-scores: true,
 },
 },
 trace: {
 select: {
 id: true,
-scores: true,
 },
 },
 },
@@ -570,6 +569,27 @@ export const datasetRouter = createTRPCRouter({
 skip: input.page * input.limit,
 });

+const traceScores = await ctx.prisma.score.findMany({
+where: {
+projectId: ctx.session.projectId,
+traceId: {
+in: runItems
+.filter((ri) => !!!ri.observationId) // only include trace scores if run is not linked to an observation
+.map((ri) => ri.traceId),
+},
+},
+});
+const observationScores = await ctx.prisma.score.findMany({
+where: {
+projectId: ctx.session.projectId,
+observationId: {
+in: runItems
+.filter((ri) => !!ri.observationId)
+.map((ri) => ri.observationId) as string[],
+},
+},
+});
+
 const totalRunItems = await ctx.prisma.datasetRunItems.count({
 where: {
 datasetRunId: input.datasetRunId,
@@ -582,13 +602,12 @@ export const datasetRouter = createTRPCRouter({
 },
 });

-const observationIds = runItems
-.map((ri) => ri.observation?.id)
-.filter(Boolean) as string[];
 const observations = await ctx.prisma.observationView.findMany({
 where: {
 id: {
-in: observationIds,
+in: runItems
+.map((ri) => ri.observationId)
+.filter(Boolean) as string[],
 },
 },
 select: {
@@ -598,13 +617,10 @@ export const datasetRouter = createTRPCRouter({
 },
 });

-const traceIds = runItems
-.map((ri) => ri.trace?.id)
-.filter(Boolean) as string[];
 const traces = await ctx.prisma.traceView.findMany({
 where: {
 id: {
-in: traceIds,
+in: runItems.map((ri) => ri.traceId).filter(Boolean) as string[],
 },
 projectId: ctx.session.projectId,
 },
@@ -621,10 +637,14 @@ export const datasetRouter = createTRPCRouter({
 datasetItemId: ri.datasetItemId,
 observation: observations.find((o) => o.id === ri.observationId),
 trace: traces.find((t) => t.id === ri.traceId),
-scores:
-// use observation scores if run is linked to an observation, otherwise use all trace scores
-(!!ri.observationId ? ri.observation?.scores : ri.trace?.scores) ??
-[],
+scores: [
+...traceScores.filter((s) => s.traceId === ri.traceId),
+...observationScores.filter(
+(s) =>
+s.observationId === ri.observationId &&
+s.traceId === ri.traceId,
+),
+],
 };
 });

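The in-memory join in this file, which takes over from the removed `scores` relation includes, can be sketched in isolation. This is a minimal illustration rather than the router's code: the `RunItem` and `ScoreRow` shapes and the `attachScores` helper are hypothetical, and the two array filters stand in for the two `prisma.score.findMany` calls in the diff.

```typescript
// Hypothetical minimal shapes; the real Prisma models carry more fields.
type RunItem = { id: string; traceId: string; observationId: string | null };
type ScoreRow = { id: string; traceId: string; observationId: string | null };

// Assumption: `allScores` is already restricted to a single project,
// which the diff achieves with `projectId: ctx.session.projectId`.
function attachScores(runItems: RunItem[], allScores: ScoreRow[]) {
  // Traces of run items that are NOT linked to an observation.
  const traceOnlyIds = new Set(
    runItems.filter((ri) => !ri.observationId).map((ri) => ri.traceId),
  );
  // Observations of run items that ARE linked to one.
  const observationIds = new Set(
    runItems
      .filter((ri) => !!ri.observationId)
      .map((ri) => ri.observationId as string),
  );

  // Stand-ins for the two batched score queries.
  const traceScores = allScores.filter((s) => traceOnlyIds.has(s.traceId));
  const observationScores = allScores.filter(
    (s) => s.observationId !== null && observationIds.has(s.observationId),
  );

  // Same merge expression as in the diff.
  return runItems.map((ri) => ({
    ...ri,
    scores: [
      ...traceScores.filter((s) => s.traceId === ri.traceId),
      ...observationScores.filter(
        (s) => s.observationId === ri.observationId && s.traceId === ri.traceId,
      ),
    ],
  }));
}
```

With the relations gone from the schema, scores have to be fetched separately and matched by `traceId` and `observationId`; batching them into two lookups keeps the number of queries independent of the page size.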
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ export function ManualScoreButton({

 const handleDelete = async () => {
 if (score) {
-await mutDeleteScore.mutateAsync(score.id);
+await mutDeleteScore.mutateAsync({ id: score.id, projectId });
 capture("score:delete", {
 type: type,
 source: source,
@@ -130,6 +130,7 @@ export function ManualScoreButton({
 id: score.id,
 value: values.score,
 comment: values.comment,
+projectId,
 });
 capture("score:update_form_submit", {
 type: type,
@@ -142,6 +143,7 @@ export function ManualScoreButton({
 comment: values.comment,
 traceId,
 observationId,
+projectId,
 });
 capture("score:create_form_submit", {
 type: type,
2 changes: 1 addition & 1 deletion web/src/features/public-api/server/apiScope.ts
@@ -59,7 +59,7 @@ async function isResourceInProject(resource: Resource, projectId: string) {
 (await prisma.score.count({
 where: {
 id: resource.id,
-trace: { projectId },
+projectId,
 },
 })) === 1;
 if (!scoreCheck)
17 changes: 8 additions & 9 deletions web/src/pages/api/cron/ingestion-metrics.ts
@@ -212,18 +212,17 @@ export default async function handler(
 }>
 >`
 SELECT
-t.project_id,
+s.project_id,
 count(s.*)::integer count_scores
 FROM
 scores s
-JOIN traces t ON t.id = s.trace_id
-WHERE
-s.timestamp < ${endTimeframe}
-${
-startTimeframe
-? Prisma.sql`AND s.timestamp >= ${startTimeframe}`
-: Prisma.empty
-}
+WHERE
+s.timestamp < ${endTimeframe}
+${
+startTimeframe
+? Prisma.sql`AND s.timestamp >= ${startTimeframe}`
+: Prisma.empty
+}
 GROUP BY
 1
 `;
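The simplified metrics query groups scores by their own `project_id` and no longer joins `traces`. The sketch below reproduces the same filter and grouping over an in-memory array, purely as an illustration: the `ScoreRow` shape and the `countScoresByProject` helper are hypothetical and not part of the PR.

```typescript
// Hypothetical row shape; only the two columns the query touches.
type ScoreRow = { projectId: string; timestamp: Date };

// Counts scores per project with timestamp < endTimeframe and,
// if startTimeframe is given, timestamp >= startTimeframe.
function countScoresByProject(
  scores: ScoreRow[],
  endTimeframe: Date,
  startTimeframe?: Date,
): Map<string, number> {
  const counts = new Map<string, number>();
  for (const s of scores) {
    const t = s.timestamp.getTime();
    if (t >= endTimeframe.getTime()) continue; // upper bound is exclusive
    if (startTimeframe && t < startTimeframe.getTime()) continue; // lower bound is inclusive
    counts.set(s.projectId, (counts.get(s.projectId) ?? 0) + 1);
  }
  return counts;
}
```

Because every score row now carries its project, the grouping key is available without a lookup through the trace.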
49 changes: 19 additions & 30 deletions web/src/pages/api/public/scores.ts
@@ -20,15 +20,6 @@ import { isPrismaException } from "@/src/utils/exceptions";

 const operators = ["<", ">", "<=", ">=", "!=", "="] as const;

-const prismaOperators: Record<(typeof operators)[number], string> = {
-"<": "lt",
-">": "gt",
-"<=": "lte",
-">=": "gte",
-"!=": "not",
-"=": "equals",
-};
-
 const ScoresGetSchema = z.object({
 ...paginationZod,
 userId: z.string().nullish(),
@@ -141,8 +132,8 @@ export default async function handler(
 s.observation_id as "observationId",
 json_build_object('userId', t.user_id) as "trace"
 FROM "scores" AS s
-JOIN "traces" AS t ON t.id = s.trace_id
-WHERE t.project_id = ${authCheck.scope.projectId}
+JOIN "traces" AS t ON t.id = s.trace_id AND t.project_id = ${authCheck.scope.projectId}
+WHERE s.project_id = ${authCheck.scope.projectId}
 ${userCondition}
 ${nameCondition}
 ${sourceCondition}
@@ -151,25 +142,23 @@ export default async function handler(
 ORDER BY t."timestamp" DESC
 LIMIT ${obj.limit} OFFSET ${skipValue}
 `);
-const totalItems = await prisma.score.count({
-where: {
-name: obj.name ? obj.name : undefined,
-source: obj.source ? obj.source : undefined,
-timestamp: obj.fromTimestamp
-? { gte: new Date(obj.fromTimestamp) }
-: undefined,
-value:
-obj.operator && obj.value
-? {
-[prismaOperators[obj.operator]]: obj.value,
-}
-: undefined,
-trace: {
-projectId: authCheck.scope.projectId,
-userId: obj.userId ? obj.userId : undefined,
-},
-},
-});
+
+const totalItemsRes = await prisma.$queryRaw<{ count: bigint }[]>(
+Prisma.sql`
+SELECT COUNT(*) as count
+FROM "scores" AS s
+JOIN "traces" AS t ON t.id = s.trace_id AND t.project_id = ${authCheck.scope.projectId}
+WHERE s.project_id = ${authCheck.scope.projectId}
+${userCondition}
+${nameCondition}
+${sourceCondition}
+${fromTimestampCondition}
+${valueCondition}
+`,
+);
+
+const totalItems =
+totalItemsRes[0] !== undefined ? Number(totalItemsRes[0].count) : 0;

 return res.status(200).json({
 data: scores,
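The raw count query is typed as returning `{ count: bigint }[]`, and a `bigint` cannot be passed to `JSON.stringify` directly, so the handler converts it with `Number(...)` before building the response. A minimal sketch of that conversion follows; the `CountRow` type and the `toTotalItems` helper are hypothetical names used only for illustration.

```typescript
// Hypothetical row type matching the generic passed to $queryRaw in the diff.
type CountRow = { count: bigint };

// Returns the count as a plain number, or 0 when the query yields no row.
// Assumption: the count stays within Number.MAX_SAFE_INTEGER, so the
// conversion does not lose precision.
function toTotalItems(rows: CountRow[]): number {
  const first = rows[0];
  return first !== undefined ? Number(first.count) : 0;
}

// Usage: the converted value serializes like any other number.
const payload = JSON.stringify({ totalItems: toTotalItems([{ count: BigInt(42) }]) });
```

Running the count as raw SQL lets it reuse the same condition fragments as the data query, which keeps the total consistent with the filtered page.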
12 changes: 3 additions & 9 deletions web/src/pages/api/public/scores/[scoreId].ts
@@ -40,9 +40,7 @@ export default async function handler(
 },
 where: {
 id: scoreId,
-trace: {
-projectId: authCheck.scope.projectId,
-},
+projectId: authCheck.scope.projectId,
 },
 });

@@ -55,9 +53,7 @@ export default async function handler(
 await prisma.score.delete({
 where: {
 id: scoreId,
-trace: {
-projectId: authCheck.scope.projectId,
-},
+projectId: authCheck.scope.projectId,
 },
 });

@@ -107,9 +103,7 @@ export default async function handler(
 const score = await prisma.score.findUnique({
 where: {
 id: scoreId,
-trace: {
-projectId: authCheck.scope.projectId,
-},
+projectId: authCheck.scope.projectId,
 },
 });

2 changes: 1 addition & 1 deletion web/src/pages/api/public/traces.ts
@@ -139,7 +139,7 @@ export default async function handler(
 COALESCE(ARRAY_AGG(DISTINCT s.id) FILTER (WHERE s.id IS NOT NULL), ARRAY[]::text[]) AS "scores"
 FROM "traces" AS t
 LEFT JOIN "observations_view" AS o ON t.id = o.trace_id AND o.project_id = ${authCheck.scope.projectId}
-LEFT JOIN "scores" AS s ON t.id = s.trace_id
+LEFT JOIN "scores" AS s ON t.id = s.trace_id AND s.project_id = ${authCheck.scope.projectId}
 WHERE t.project_id = ${authCheck.scope.projectId}
 ${userCondition}
 ${nameCondition}