| | import { z } from 'zod'; |
| | import { ViolationTypes, ErrorTypes } from 'librechat-data-provider'; |
| | import type { Agent, TModelsConfig } from 'librechat-data-provider'; |
| | import type { Request, Response } from 'express'; |
| |
|
| | |
/**
 * Avatar reference stored on an agent.
 *
 * Both `filepath` and `source` are required strings; no format validation
 * beyond "is a string" is applied here.
 */
export const agentAvatarSchema = z.object({ filepath: z.string(), source: z.string() });
| |
|
| | |
/**
 * Resource payload shared by the per-tool entries of `tool_resources`.
 *
 * - `file_ids`: optional list of string IDs.
 * - `files`: optional list whose entries are NOT validated (`z.any()`).
 */
export const agentBaseResourceSchema = z.object({
  file_ids: z.string().array().optional(),
  files: z.any().array().optional(),
});
| |
|
| | |
/**
 * Resource payload for the `file_search` tool: every base resource field plus
 * an optional list of vector store IDs.
 */
export const agentFileResourceSchema = agentBaseResourceSchema.extend({
  vector_store_ids: z.string().array().optional(),
});
| |
|
| | |
/**
 * Map of tool name -> resources for that tool.
 *
 * Every key is optional, and the whole object may be omitted. `file_search`
 * uses the extended file resource shape; the remaining tools use the base shape.
 */
export const agentToolResourcesSchema = z
  .object({
    image_edit: agentBaseResourceSchema.optional(),
    execute_code: agentBaseResourceSchema.optional(),
    file_search: agentFileResourceSchema.optional(),
    context: agentBaseResourceSchema.optional(),
    ocr: agentBaseResourceSchema.optional(),
  })
  .optional();
| |
|
| | |
/**
 * Optional support contact attached to an agent.
 *
 * `email` accepts either the empty string or a syntactically valid email
 * address; both fields, and the object itself, may be omitted.
 */
export const agentSupportContactSchema = z
  .object({
    name: z.string().optional(),
    email: z.literal('').or(z.string().email()).optional(),
  })
  .optional();
| |
|
| | |
/**
 * A single edge in an agent graph.
 *
 * `from` and `to` each accept one string or an array of strings (presumably
 * agent IDs — confirm against the graph builder). `edgeType` is restricted to
 * `'handoff' | 'direct'`, and `prompt` may be a string or a function.
 */
export const graphEdgeSchema = z.object({
  from: z.string().or(z.string().array()),
  to: z.string().or(z.string().array()),
  description: z.string().optional(),
  edgeType: z.enum(['handoff', 'direct']).optional(),
  prompt: z.string().or(z.function()).optional(),
  excludeResults: z.boolean().optional(),
  promptKey: z.string().optional(),
});
| |
|
| | |
/**
 * Fields shared by the agent create and update payloads.
 *
 * Every field is optional at this level. `agentCreateSchema` and
 * `agentUpdateSchema` extend this base to add or override fields.
 * `.nullish()` is zod's shorthand for `.nullable().optional()`.
 */
export const agentBaseSchema = z.object({
  name: z.string().nullish(),
  description: z.string().nullish(),
  instructions: z.string().nullish(),
  avatar: agentAvatarSchema.nullish(),
  model_parameters: z.record(z.unknown()).optional(),
  tools: z.string().array().optional(),
  agent_ids: z.string().array().optional(),
  edges: graphEdgeSchema.array().optional(),
  end_after_tools: z.boolean().optional(),
  hide_sequential_outputs: z.boolean().optional(),
  artifacts: z.string().optional(),
  recursion_limit: z.number().optional(),
  conversation_starters: z.string().array().optional(),
  // Both of these schemas already end in `.optional()` at their definitions.
  tool_resources: agentToolResourcesSchema,
  support_contact: agentSupportContactSchema,
  category: z.string().optional(),
});
| |
|
| | |
/**
 * Payload accepted when creating an agent.
 *
 * Extends the base fields with a required `provider` and a required-but-nullable
 * `model`. It also overrides `tools` so that an omitted list parses as `[]`.
 */
export const agentCreateSchema = agentBaseSchema.extend({
  provider: z.string(),
  model: z.string().nullable(),
  tools: z.string().array().optional().default([]),
});
| |
|
| | |
/**
 * Payload accepted when updating an agent.
 *
 * Every field is optional. `avatar` may be an avatar object or `null`, and
 * `model` may be `null`. `projectIds` / `removeProjectIds` are optional string
 * lists (presumably projects to add / remove — confirm with the update
 * handler).
 */
export const agentUpdateSchema = agentBaseSchema.extend({
  avatar: agentAvatarSchema.or(z.null()).optional(),
  provider: z.string().optional(),
  model: z.string().nullish(),
  projectIds: z.string().array().optional(),
  removeProjectIds: z.string().array().optional(),
  isCollaborative: z.boolean().optional(),
});
| |
|
/** Inputs to `validateAgentModel`. */
interface ValidateAgentModelParams {
  /** Incoming request; only forwarded to `logViolation`. */
  req: Request;
  /** Outgoing response; only forwarded to `logViolation`. */
  res: Response;
  /** Agent whose `model` is checked against the models for its `provider`. */
  agent: Agent;
  /** Available model names, indexed by the agent's `provider` value. */
  modelsConfig: TModelsConfig;
  /** Violation logger; awaited before an "illegal model" result is returned. */
  logViolation: (
    req: Request,
    res: Response,
    type: string,
    errorMessage: Record<string, unknown>,
    score?: number | string,
  ) => Promise<void>;
}
| |
|
/**
 * Outcome of `validateAgentModel`.
 *
 * `validateAgentModel` sets `error` only on failure (`isValid: false`). Its
 * `message` is a JSON-style string of the form `{ "type": ..., "info": ... }`.
 */
interface ValidateAgentModelResult {
  isValid: boolean;
  error?: { message: string };
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Checks that an agent's `model` is one of the models available for its
 * `provider` (used as the endpoint key into `modelsConfig`).
 *
 * Failure cases, in order of evaluation:
 * 1. no `model` set → `ErrorTypes.MISSING_MODEL` (info: endpoint);
 * 2. no `modelsConfig` → `ErrorTypes.MODELS_NOT_LOADED`;
 * 3. no model list for the endpoint → `ErrorTypes.ENDPOINT_MODELS_NOT_LOADED`
 *    (info: endpoint);
 * 4. model not in the endpoint's list → `ViolationTypes.ILLEGAL_MODEL_REQUEST`
 *    (info: `endpoint|model`).
 *
 * Only case 4 is reported through `logViolation`. Its score comes from the
 * `ILLEGAL_MODEL_REQ_SCORE` env var, defaulting to `1`.
 *
 * @param params - request/response pair, the agent, the loaded models config,
 *   and the violation logger.
 * @returns `{ isValid: true }` when the model is allowed; otherwise
 *   `{ isValid: false, error: { message } }`. `message` is a JSON string
 *   `{ "type": ..., "info": ... }`.
 */
export async function validateAgentModel(
  params: ValidateAgentModelParams,
): Promise<ValidateAgentModelResult> {
  const { req, res, agent, modelsConfig, logViolation } = params;
  const { model, provider: endpoint } = agent;

  /**
   * Builds a failure result. Values go through `JSON.stringify` so that quotes
   * or backslashes in the (user-configurable) model/endpoint names cannot
   * produce malformed JSON. The surrounding `{ "type": …, "info": … }` layout
   * is kept identical to the previous hand-written templates.
   */
  const invalid = (type: string, info?: string): ValidateAgentModelResult => {
    const infoPart = info === undefined ? '' : `, "info": ${JSON.stringify(info)}`;
    return {
      isValid: false,
      error: { message: `{ "type": ${JSON.stringify(type)}${infoPart} }` },
    };
  };

  if (!model) {
    return invalid(ErrorTypes.MISSING_MODEL, `${endpoint}`);
  }

  if (!modelsConfig) {
    return invalid(ErrorTypes.MODELS_NOT_LOADED);
  }

  const availableModels = modelsConfig[endpoint];
  if (!availableModels) {
    return invalid(ErrorTypes.ENDPOINT_MODELS_NOT_LOADED, `${endpoint}`);
  }

  if (availableModels.includes(model)) {
    return { isValid: true };
  }

  // Env vars are strings, so the score may be a string; the fallback is numeric.
  const { ILLEGAL_MODEL_REQ_SCORE: score = 1 } = process.env;
  const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
  const errorMessage = {
    type,
    model,
    endpoint,
  };

  await logViolation(req, res, type, errorMessage, score);

  return invalid(type, `${endpoint}|${model}`);
}
| |
|