Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| from prophet import Prophet | |
| import logging | |
logging.basicConfig(level=logging.INFO)

########################################
# OKX endpoints & utility
########################################

OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"

# UI timeframe label -> OKX ``bar`` query value. Hour, day and week labels are
# sent with an upper-case unit letter (e.g. "1h" -> "1H"); minute labels are
# passed through unchanged.
TIMEFRAME_MAPPING = {
    label: label.upper() if label.endswith(("h", "d", "w")) else label
    for label in ("1m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "12h", "1d", "1w")
}
def fetch_okx_symbols():
    """
    Return the alphabetically sorted OKX spot instrument ids.

    Failures are never raised. Instead the function returns a one-element
    list whose only entry starts with "Error", so a caller can detect the
    problem by inspecting the first element.
    """
    logging.info("Fetching symbols from OKX Spot tickers...")
    try:
        response = requests.get(OKX_TICKERS_ENDPOINT, timeout=30)
        response.raise_for_status()
        payload = response.json()

        # OKX signals application-level failure with a non-"0" code string.
        if payload.get("code") != "0":
            logging.error(f"Non-zero code returned: {payload}")
            return ["Error: Could not fetch OKX symbols"]

        spot_ids = [
            ticker["instId"]
            for ticker in payload.get("data", [])
            if ticker.get("instType") == "SPOT"
        ]
        if not spot_ids:
            logging.warning("No spot symbols found.")
            return ["Error: No spot symbols found."]

        logging.info(f"Fetched {len(spot_ids)} OKX spot symbols.")
        return sorted(spot_ids)
    except Exception as exc:
        logging.error(f"Error fetching OKX symbols: {exc}")
        return [f"Error: {str(exc)}"]
def fetch_okx_candles(symbol, timeframe="1H", limit=500):
    """
    Fetch historical candle data for a symbol from OKX.

    OKX returns each candle as a list of strings:
        [ts, o, h, l, c, vol, volCcy, volCcyQuote, confirm]
    newest first, with ``ts`` in epoch milliseconds.

    Args:
        symbol: OKX instrument id, e.g. "BTC-USDT".
        timeframe: OKX ``bar`` value, e.g. "1H" or "1D".
        limit: number of candles requested. Values above the endpoint's
            documented maximum of 300 are capped.

    Returns:
        (DataFrame, str): on success, the candles in chronological order and
        an empty message. Columns are timestamp (datetime), open, high, low,
        close, vol, volCcy, volCcyQuote and confirm (all float). On failure,
        an empty DataFrame and an error/warning message.
    """
    # GET /api/v5/market/candles documents 300 as its maximum page size, so
    # never request (or log) more than the endpoint can return.
    limit = min(limit, 300)
    logging.info(f"Fetching {limit} candles for {symbol} @ {timeframe} from OKX...")
    params = {
        "instId": symbol,
        "bar": timeframe,
        "limit": limit,
    }
    try:
        resp = requests.get(OKX_CANDLE_ENDPOINT, params=params, timeout=30)
        resp.raise_for_status()
        json_data = resp.json()
        if json_data.get("code") != "0":
            msg = f"OKX returned code={json_data.get('code')}, msg={json_data.get('msg')}"
            logging.error(msg)
            return pd.DataFrame(), msg
        items = json_data.get("data", [])
        if not items:
            warning_msg = f"No candle data returned for {symbol}."
            logging.warning(warning_msg)
            return pd.DataFrame(), warning_msg
        # OKX returns newest first; flip to chronological order for Prophet.
        items = items[::-1]
        columns = [
            "ts", "o", "h", "l", "c", "vol",
            "volCcy", "volCcyQuote", "confirm",
        ]
        df = pd.DataFrame(items, columns=columns)
        df = df.rename(columns={
            "ts": "timestamp",
            "o": "open",
            "h": "high",
            "l": "low",
            "c": "close",
        })
        # ``ts`` arrives as a string of epoch milliseconds. Cast it to an
        # integer before converting: pandas deprecates combining ``unit=``
        # with string input.
        df["timestamp"] = pd.to_datetime(df["timestamp"].astype("int64"), unit="ms")
        numeric_cols = ["open", "high", "low", "close", "vol", "volCcy", "volCcyQuote", "confirm"]
        df[numeric_cols] = df[numeric_cols].astype(float)
        logging.info(f"Fetched {len(df)} rows for {symbol}.")
        return df, ""
    except Exception as e:
        err_msg = f"Error fetching candles for {symbol}: {e}"
        logging.error(err_msg)
        return pd.DataFrame(), err_msg
########################################
# Prophet pipeline
########################################
def prepare_data_for_prophet(df):
    """
    Reshape candle data into the two-column frame Prophet expects.

    ``timestamp`` is renamed to ``ds`` and ``close`` to ``y``; every other
    column is dropped. An empty input yields an empty frame that still has
    the ``ds``/``y`` columns.
    """
    if not df.empty:
        renamed = df.rename(columns={"timestamp": "ds", "close": "y"})
        return renamed[["ds", "y"]]

    logging.warning("Empty DataFrame, cannot prepare data for Prophet.")
    return pd.DataFrame(columns=["ds", "y"])
def prophet_forecast(df_prophet, periods=10, freq="h"):
    """
    Fit a default Prophet model and predict over history plus future periods.

    Args:
        df_prophet: Prophet input frame with ``ds`` and ``y`` columns.
        periods: number of future periods to append after the history.
        freq: pandas offset alias for the spacing of those future periods.

    Returns:
        (DataFrame, str): Prophet's full prediction frame and "" on success,
        or an empty DataFrame and an error message.
    """
    if not df_prophet.empty:
        try:
            model = Prophet()
            model.fit(df_prophet)
            horizon = model.make_future_dataframe(periods=periods, freq=freq)
            return model.predict(horizon), ""
        except Exception as exc:
            logging.error(f"Forecast error: {exc}")
            return pd.DataFrame(), f"Forecast error: {exc}"

    logging.warning("Prophet input is empty, no forecast can be generated.")
    return pd.DataFrame(), "No data to forecast."
def prophet_wrapper(df_prophet, forecast_steps, freq):
    """
    Forecast with Prophet and keep only the rows after the observed data.

    Args:
        df_prophet: Prophet input frame with ``ds`` and ``y`` columns.
        forecast_steps: number of future periods to forecast. Coerced to int
            because it is used as a period count.
        freq: pandas offset alias for the spacing of future periods.

    Returns:
        (DataFrame, str): on success, the future rows (columns ``ds``,
        ``yhat``, ``yhat_lower``, ``yhat_upper``) and "". Otherwise an empty
        DataFrame and an error message.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."
    full_forecast, err = prophet_forecast(df_prophet, int(forecast_steps), freq)
    if err:
        return pd.DataFrame(), err
    # Select future rows by timestamp, not by row position. The forecast
    # frame's history part need not line up one-to-one with the input rows
    # (e.g. when the input contains duplicate timestamps), so slicing at
    # len(df_prophet) could drop or include the wrong rows.
    last_observed = pd.to_datetime(df_prophet["ds"]).max()
    is_future = full_forecast["ds"] > last_observed
    future_only = full_forecast.loc[is_future, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
    return future_only, ""
########################################
# Main Gradio logic
########################################
# UI timeframe label -> pandas offset alias for the spacing between candles.
# Prophet extends the series with this frequency, so it must match the candle
# interval actually fetched from OKX.
PANDAS_FREQ_MAPPING = {
    "1m": "1min",
    "5m": "5min",
    "15m": "15min",
    "30m": "30min",
    "1h": "1h",
    "2h": "2h",
    "4h": "4h",
    "6h": "6h",
    "12h": "12h",
    "1d": "1D",
    # A fixed 7-day step, rather than the Sunday-anchored "W" alias, keeps
    # each forecast on the same weekday and time as the last candle.
    "1w": "7D",
}


def predict(symbol, timeframe, forecast_steps):
    """
    Orchestrate candle fetch + Prophet forecast for one request.

    Args:
        symbol: OKX instrument id.
        timeframe: UI timeframe label such as "15m", "4h" or "1d". Labels not
            in the mapping fall back to hourly candles.
        forecast_steps: number of future candles to forecast; coerced to int.

    Returns:
        (DataFrame, str): the future forecast rows and "" on success,
        otherwise an empty DataFrame and an error message.
    """
    # Convert user timeframe to OKX bar param
    okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
    df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, limit=500)
    if err:
        return pd.DataFrame(), err
    df_prophet = prepare_data_for_prophet(df_raw)
    # Same fallback as the OKX lookup above ("1H" -> hourly), so the forecast
    # grid always matches the candles that were fetched.
    freq = PANDAS_FREQ_MAPPING.get(timeframe, "h")
    future_df, err2 = prophet_wrapper(df_prophet, int(forecast_steps), freq)
    if err2:
        return pd.DataFrame(), err2
    return future_df, ""
def display_forecast(symbol, timeframe, forecast_steps):
    """
    Gradio click handler: return the forecast table, or a table holding the
    error message.

    The output component is a ``gr.Dataframe``. Gradio's Dataframe treats a
    returned ``str`` as a path to a CSV file rather than as text to show, so
    errors are wrapped in a one-row DataFrame to render in the table.

    Returns:
        pd.DataFrame: forecast rows on success, or a single ``error`` column
        containing "Error: <message>".
    """
    logging.info(f"User requested: symbol={symbol}, timeframe={timeframe}, steps={forecast_steps}")
    forecast_df, error = predict(symbol, timeframe, forecast_steps)
    if error:
        return pd.DataFrame({"error": [f"Error: {error}"]})
    return forecast_df
def main():
    """
    Build the Gradio Blocks UI for OKX price forecasting.

    The OKX spot symbol list is fetched once, when the UI is built. If that
    fetch fails, the symbol dropdown falls back to a single placeholder entry.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Fetch OKX symbols
    symbols = fetch_okx_symbols()
    # fetch_okx_symbols reports failure as a one-element list whose entry
    # begins with "Error"; test that prefix exactly.
    if not symbols or symbols[0].startswith("Error"):
        symbols = ["No symbols available"]
    with gr.Blocks() as demo:
        gr.Markdown("# OKX Price Forecasting with Prophet")
        gr.Markdown(
            "This app pulls spot-market candles from OKX, trains a simple Prophet model, "
            "and displays only future predictions. If you see errors or no data, try another symbol/timeframe."
        )
        symbol_dd = gr.Dropdown(
            label="Symbol",
            choices=symbols,
            value=symbols[0] if symbols else None,
        )
        timeframe_dd = gr.Dropdown(
            label="Timeframe",
            choices=["1m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "12h", "1d", "1w"],
            value="1h",
        )
        # The slider value is used as a count of future periods, so pin the
        # step to 1 to keep it on whole numbers.
        steps_slider = gr.Slider(
            label="Forecast Steps",
            minimum=1,
            maximum=100,
            value=10,
            step=1,
        )
        forecast_btn = gr.Button("Generate Forecast")
        output_df = gr.Dataframe(
            label="Future Forecast Only",
            headers=["ds", "yhat", "yhat_lower", "yhat_upper"],
        )
        forecast_btn.click(
            fn=display_forecast,
            inputs=[symbol_dd, timeframe_dd, steps_slider],
            outputs=output_df,
        )
        gr.Markdown(
            "Need more tools? Check out this "
            "[crypto trading bot](https://www.gunbot.com)."
        )
    return demo


if __name__ == "__main__":
    app = main()
    app.launch()