import json
from typing import Any, Dict, Optional

from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

from .prompt_templates import PROMPT_MAP, EVALUATION_SYSTEM_MESSAGE
from config import settings

class EvaluationChains:
    """Builds per-metric LLM evaluation chains on top of the prompts in PROMPT_MAP."""

    def __init__(self, model_name: Optional[str] = None, api_provider: str = "groq"):
        self.api_provider = api_provider
        self.model_name = model_name or (
            settings.DEFAULT_GROQ_MODEL if api_provider == "groq" 
            else settings.DEFAULT_OPENAI_MODEL
        )
        
        if api_provider == "groq":
            self.llm = ChatGroq(
                model=self.model_name,
                temperature=0.1,
                max_tokens=500
            )
        elif api_provider == "openai":
            self.llm = ChatOpenAI(
                model=self.model_name,
                temperature=0.1,
                max_tokens=500
            )
        else:
            raise ValueError(f"Unsupported API provider: {api_provider}")
    
    def create_evaluation_chain(self, metric: str):
        """Create a LangChain chain for a specific evaluation metric."""
        prompt = PROMPT_MAP.get(metric)
        if not prompt:
            raise ValueError(f"Unknown metric: {metric}")
        
        # Every metric shares the same prompt -> LLM -> JSON-parsing pipeline.
        chain = (
            prompt
            | self.llm
            | StrOutputParser()
            | self._parse_json_response
        )
        
        # Context-based metrics first normalize the context field.
        if metric in ["context_precision", "context_recall"]:
            chain = self._prepare_context_input | chain
        
        return chain
    
    def _prepare_context_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure context-based metrics always receive a context string."""
        # Copy so the caller's dict is not mutated, then fall back to a
        # placeholder when no context (or an empty one) was supplied.
        prepared = dict(input_data)
        if not prepared.get("context"):
            prepared["context"] = "No context provided for evaluation."
        return prepared
    
    def _parse_json_response(self, response: str) -> Dict[str, Any]:
        """Parse the JSON object embedded in an LLM response."""
        try:
            # Models often wrap JSON in prose or code fences, so slice out the
            # outermost {...} span before parsing.
            json_start = response.find('{')
            json_end = response.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                return json.loads(response[json_start:json_end])
            return {"score": 0, "explanation": "Invalid response format"}
        except json.JSONDecodeError:
            return {"score": 0, "explanation": "Failed to parse JSON response"}