Spaces:
Sleeping
Sleeping
| import { Chess, Move } from 'chess.js' | |
| import { pipeline } from '@xenova/transformers' | |
/**
 * Chooses chess moves with a small causal language model loaded through the
 * `@xenova/transformers` text-generation pipeline. It falls back to a simple
 * heuristic whenever the model is not loaded, fails, times out, or proposes a
 * move that is not legal in the current position.
 */
export class ChessAI {
  // Text-generation pipeline returned by `pipeline()`; null until a load succeeds.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- the pipeline is invoked as an untyped callable
  private model: any = null
  // True only while a load is in flight (reported by isModelLoading/getModelInfo).
  private isLoading: boolean = false
  // In-flight load shared by concurrent initialize() callers; null when idle.
  private loadPromise: Promise<void> | null = null
  private modelId: string

  /**
   * @param modelId Hugging Face model id passed to the text-generation pipeline.
   */
  constructor(modelId: string = 'mlabonne/chesspythia-70m') {
    this.modelId = modelId
  }

  /**
   * Loads the model unless it is already loaded.
   *
   * Concurrent calls share a single load, and every caller resolves only after
   * that load has finished. A failed load is logged rather than thrown and
   * leaves the model unloaded, so a later call retries.
   */
  async initialize(): Promise<void> {
    if (this.model) return
    let pending = this.loadPromise
    if (!pending) {
      pending = this.loadModel().finally(() => {
        this.loadPromise = null
      })
      this.loadPromise = pending
    }
    await pending
  }

  /** Performs one pipeline load; logs and swallows failures, so it never rejects. */
  private async loadModel(): Promise<void> {
    this.isLoading = true
    try {
      this.model = await pipeline('text-generation', this.modelId)
    } catch (error) {
      console.error('Failed to load chess model:', error)
      this.model = null
    } finally {
      this.isLoading = false
    }
  }

  /**
   * Picks a move for the side to move in `chess`.
   *
   * The model is used when it is loaded. If it is not loaded, or if generation
   * fails, exceeds `timeLimit`, or yields no legal move, the heuristic fallback
   * move is returned instead.
   *
   * @param chess Current game. Only read here; it is not modified.
   * @param timeLimit Milliseconds to wait for the model before falling back.
   * @returns The chosen legal move, or null when the side to move has no legal moves.
   */
  async getMove(chess: Chess, timeLimit: number = 10000): Promise<Move | null> {
    if (!this.model) {
      console.log('Using fallback AI (model not loaded)')
      return this.getFallbackMove(chess)
    }
    let timer: ReturnType<typeof setTimeout> | undefined
    try {
      const legalMoves: Move[] = chess.moves({ verbose: true })
      if (legalMoves.length === 0) return null
      const prompt = this.createChessPrompt(chess)
      const startTime = Date.now()
      const timeout = new Promise<null>((_, reject) => {
        timer = setTimeout(() => reject(new Error('Timeout')), timeLimit)
      })
      // On timeout the generation keeps running in the background; only its result is ignored.
      const result = await Promise.race([this.generateMove(prompt, legalMoves), timeout])
      const elapsedTime = Date.now() - startTime
      console.log(`AI move generated in ${elapsedTime}ms`)
      return result ?? this.getFallbackMove(chess)
    } catch (error) {
      console.error('Error generating AI move:', error)
      return this.getFallbackMove(chess)
    } finally {
      // Without this, the timer outlives the race and can keep a Node process alive.
      clearTimeout(timer)
    }
  }

  /**
   * Runs the model on `prompt` and maps its continuation to a legal move.
   *
   * @returns The recognised legal move, or null if generation throws or no legal move is recognised.
   */
  private async generateMove(prompt: string, legalMoves: Move[]): Promise<Move | null> {
    if (!this.model) return null
    try {
      const output = await this.model(prompt, {
        max_new_tokens: 10,
        // NOTE(review): temperature presumably only matters when sampling (do_sample) is enabled — confirm.
        temperature: 0.7
      })
      const generatedText = Array.isArray(output) ? output[0]?.generated_text : output?.generated_text
      console.log('Model output:', generatedText)
      return this.parseMove(generatedText, legalMoves, prompt)
    } catch (error) {
      console.error('Error in model generation:', error)
      return null
    }
  }

  /**
   * Builds the model prompt: the game so far in numbered SAN, followed by the cue
   * for the next move.
   *
   * - White to move: the prompt ends with `N.` (for example `1.e4 e5 2.`).
   * - Black to move: the prompt ends right after White's move (for example `1.e4`).
   * - No recorded history (for example a FEN setup with no moves played): the
   *   prompt is just `N.` or `N...`.
   *
   * NOTE(review): the `1.e4 e5 2.Nf3` layout (no space after the number) is assumed
   * to match the model's training data — confirm against the model card.
   */
  private createChessPrompt(chess: Chess): string {
    const turn = chess.turn()
    const moveNumber = chess.moveNumber()
    const history: string[] = chess.history()
    if (history.length === 0) {
      return turn === 'w' ? `${moveNumber}.` : `${moveNumber}...`
    }
    // Half-move index (0 = White's first move) of the next ply, and of the first recorded ply.
    const nextPly = 2 * (moveNumber - 1) + (turn === 'w' ? 0 : 1)
    const firstPly = Math.max(0, nextPly - history.length)
    const parts = history.map((san, i) => {
      const ply = firstPly + i
      const moveNo = Math.floor(ply / 2) + 1
      if (ply % 2 === 0) return `${moveNo}.${san}`
      // A record that starts with a Black move needs the `N...` marker.
      return i === 0 ? `${moveNo}...${san}` : san
    })
    if (turn === 'w') parts.push(`${moveNumber}.`)
    return parts.join(' ')
  }

  /** Removes trailing check/mate/annotation marks (`+`, `#`, `!`, `?`) from a SAN string. */
  private static stripAnnotations(san: string): string {
    return san.replace(/[+#!?]+$/, '')
  }

  /**
   * Extracts the model's proposed move from its generated text.
   *
   * Steps:
   * 1. Drop the prompt if the pipeline echoed it at the start.
   * 2. Skip move-number tokens such as `12.` or `12...`.
   * 3. Consider only the FIRST remaining token. Later tokens are the model's
   *    guesses for future plies, not the move to play now.
   * 4. Match that token against each legal move's SAN (ignoring trailing
   *    `+ # ! ?`) and LAN.
   *
   * @param generatedText Raw model output.
   * @param legalMoves Legal moves in the current position.
   * @param prompt Prompt that produced `generatedText`; stripped when echoed.
   * @returns The matching legal move, or null when no move is proposed or it is not legal.
   */
  private parseMove(generatedText: string, legalMoves: Move[], prompt: string = ''): Move | null {
    if (!generatedText) return null
    const continuation =
      prompt && generatedText.startsWith(prompt) ? generatedText.slice(prompt.length) : generatedText
    const text = continuation.trim()
    const candidate = text
      .split(/\s+/)
      .map(token => ChessAI.stripAnnotations(token.replace(/^\d+\.+/, '')))
      .find(token => token.length > 0)
    if (candidate) {
      const match = legalMoves.find(
        move => ChessAI.stripAnnotations(move.san) === candidate || move.lan === candidate
      )
      if (match) return match
    }
    console.log(`Could not parse move "${candidate ?? text}" from legal moves:`, legalMoves.map(m => m.san))
    return null
  }

  /**
   * Heuristic move used when the model cannot supply one. It picks at random
   * from the first non-empty tier:
   * 1. checkmating moves
   * 2. captures
   * 3. checking moves
   * 4. any legal move
   *
   * Check and mate are read from the `+` / `#` suffix of each move's SAN, so
   * `chess` is not modified.
   *
   * @returns A legal move, or null when there are none.
   */
  private getFallbackMove(chess: Chess): Move | null {
    const legalMoves: Move[] = chess.moves({ verbose: true })
    const tiers: Move[][] = [
      legalMoves.filter(move => move.san.endsWith('#')),
      legalMoves.filter(move => move.captured),
      legalMoves.filter(move => /[+#]$/.test(move.san)),
      legalMoves
    ]
    const pool = tiers.find(tier => tier.length > 0)
    if (!pool) return null
    return pool[Math.floor(Math.random() * pool.length)] ?? null
  }

  /** @returns Whether a model has been loaded successfully. */
  isModelLoaded(): boolean {
    return this.model !== null
  }

  /** @returns Whether a model load is currently in progress. */
  isModelLoading(): boolean {
    return this.isLoading
  }

  /** @returns A human-readable status string for the configured model. */
  getModelInfo(): string {
    if (this.isLoading) return 'Loading...'
    if (this.model) return `${this.modelId} (Loaded)`
    return `${this.modelId} (Not loaded)`
  }
}