File size: 3,888 Bytes
6510698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env bash
# Prestart hook: installs Python dependencies and prepares the runtime
# (GPU: flash-attn; CPU: model revision selection) before the app starts.
# Strict mode: exit on the first failing command (-e), on use of an unset
# variable (-u), and when any command in a pipeline fails (pipefail).
set -euo pipefail

echo "🚀 Starting Phi-3.5-MoE prestart setup..."

# Print a message prefixed with the current local timestamp.
# All arguments are joined with the first character of IFS (a space by
# default), so `log "a" "b"` prints "a b" instead of silently dropping "b";
# calling it with no arguments prints just the timestamp.
log() {
    printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
}

# Succeed (status 0) when "$1" is something `command -v` can resolve
# (e.g. an executable on PATH, a shell builtin, or a function); fail otherwise.
# Prints nothing: both output streams are discarded.
command_exists() {
    local candidate="$1"
    command -v "$candidate" &>/dev/null
}

# Ensure Python is available
if ! command_exists python; then
    log "❌ Python not found"
    exit 1
fi

# 2>&1: Python releases before 3.4 print `--version` output on stderr.
log "✅ Python found: $(python --version 2>&1)"

# Export every KEY=VALUE pair from a dotenv-style file.
# The file is parsed, never executed: values are taken literally (no command or
# variable expansion, so a '$' inside a secret survives and `set -u` cannot trip
# on it). Handled: blank and comment lines (even indented), CRLF line endings,
# a missing final newline, an optional leading "export ", one pair of matching
# surrounding quotes, and " # ..." inline comments on unquoted values.
# Lines that are not KEY=VALUE with a valid shell identifier as KEY are skipped
# with a warning that reports only the line number (the content may be secret).
load_env_file() {
    local file="$1" line key value lineno=0
    local skip_re='^[[:space:]]*(#|$)'
    local key_re='^[A-Za-z_][A-Za-z0-9_]*$'
    local dq_re='^"(.*)"$'
    local sq_re="^'(.*)'\$"

    while IFS= read -r line || [[ -n "$line" ]]; do
        lineno=$((lineno + 1))
        line="${line%$'\r'}"
        [[ "$line" =~ $skip_re ]] && continue
        line="${line#"${line%%[![:space:]]*}"}"    # strip leading whitespace
        line="${line#export }"
        key="${line%%=*}"
        value="${line#*=}"
        if [[ "$line" != *=* || ! "$key" =~ $key_re ]]; then
            log "⚠️ Skipping malformed line ${lineno} in ${file}"
            continue
        fi
        if [[ "$value" =~ $dq_re || "$value" =~ $sq_re ]]; then
            value="${BASH_REMATCH[1]}"
        else
            value="${value%%[[:space:]]#*}"              # drop " # comment"
            value="${value%"${value##*[![:space:]]}"}"   # strip trailing whitespace
        fi
        export "$key=$value"
    done < "$file"
}

# Load environment variables if .env exists
if [ -f .env ]; then
    log "📄 Loading environment variables from .env"
    load_env_file .env
fi

# Run dependency installation and environment setup in an embedded Python
# program read from the here-document below. Its failure is handled right on
# this line: under `set -e` a bare failing command ends the script immediately,
# so without the `||` branch the failure would never be reported.
python - <<'PY' || { log "❌ Prestart setup failed"; exit 1; }
import os
import sys
import subprocess
import logging
import torch

# Log (to stderr, the logging.basicConfig default) with timestamp and level.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

def run_pip_install(packages, description="packages"):
    """Install *packages* into the running interpreter with ``pip install``.

    Args:
        packages: Iterable of pip arguments -- requirement specifiers and/or
            pip flags such as ``"--no-build-isolation"`` -- passed verbatim.
        description: Human-readable label used in the log messages.

    Returns:
        True if pip exited with status 0, otherwise False. A failure is
        logged together with pip's captured output instead of being raised.
    """
    packages = list(packages)
    cmd = [sys.executable, "-m", "pip", "install", *packages]
    logger.info("Installing %s: %s", description, " ".join(packages))
    # subprocess.run drains both pipes while it waits; check_call() with
    # stdout/stderr=PIPE never reads them and can deadlock once pip's output
    # fills the OS pipe buffer (e.g. during a long source build).
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        logger.error(
            "❌ Failed to install %s: pip exited with status %d",
            description,
            result.returncode,
        )
        # pip's own diagnostics were captured rather than streamed; surface them.
        for stream_name, output in (("stdout", result.stdout), ("stderr", result.stderr)):
            if output and output.strip():
                logger.error("pip %s:\n%s", stream_name, output.rstrip())
        return False
    logger.info("✅ Successfully installed %s", description)
    return True

def _install_gpu_extras():
    """Best-effort install of flash-attn on a CUDA runtime; only logs on failure."""
    logger.info("🚀 GPU runtime detected - installing flash-attn for optimal performance")
    # --no-build-isolation is forwarded to pip so the build can use packages
    # already installed in this environment instead of an isolated build env.
    flash_attn_packages = ["flash-attn>=2.6.0", "--no-build-isolation"]
    if run_pip_install(flash_attn_packages, "flash-attn (GPU optimization)"):
        logger.info("✅ GPU environment fully configured")
    else:
        logger.warning("⚠️ Flash-attn installation failed, continuing without it")

def _configure_cpu_revision():
    """Run scripts/select_revision.py on a CPU runtime (best-effort, 300 s timeout)."""
    logger.info("💻 CPU runtime detected - configuring for CPU-only operation")
    logger.info("ℹ️  Skipping flash-attn installation (not needed for CPU)")
    logger.info("🔍 Checking for CPU-safe model revision...")

    try:
        # Relative path: resolved against the current working directory
        # (presumably the repository root -- confirm for the deploy environment).
        result = subprocess.run(
            [sys.executable, "scripts/select_revision.py"],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except subprocess.TimeoutExpired:
        logger.warning("⚠️ Revision selection timed out, continuing with default")
        return
    except Exception as e:  # deliberate best-effort: never block startup here
        logger.warning("⚠️ Error running revision selector: %s", e)
        return

    if result.returncode == 0:
        logger.info("✅ CPU-safe revision configured")
    else:
        logger.warning("⚠️ Revision selector returned %s", result.returncode)
        logger.warning("stdout: %s", result.stdout)
        logger.warning("stderr: %s", result.stderr)

def main():
    """Install core dependencies, then apply GPU- or CPU-specific setup.

    Exits with status 1 if the core dependencies cannot be installed. The
    GPU step (flash-attn) and the CPU step (revision selection) are
    best-effort: their failures are logged as warnings and setup continues.
    """
    logger.info("🔍 Checking environment and installing dependencies...")

    # Always ensure core dependencies are present
    core_deps = ["einops>=0.7.0", "transformers==4.46.0", "accelerate>=0.31.0"]
    if not run_pip_install(core_deps, "core dependencies"):
        logger.error("❌ Failed to install core dependencies")
        sys.exit(1)

    cuda_available = torch.cuda.is_available()
    logger.info("🖥️  CUDA available: %s", cuda_available)

    if cuda_available:
        _install_gpu_extras()
    else:
        _configure_cpu_revision()

    logger.info("🎉 Prestart setup completed successfully!")

if __name__ == "__main__":
    main()
PY

# Reaching this point means the setup program exited 0: with `set -e`, a
# failing setup step stops the script before this line, so no separate
# `$?` check is needed (one placed here could never observe a failure).
log "✅ Prestart setup completed successfully!"
log "🚀 Ready to start the application!"