import html
from datetime import datetime
from typing import Iterable

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import yfinance as yf
from gradio.themes import Base
from gradio.themes.utils import colors, fonts, sizes


# Custom Seafoam theme definition
class Seafoam(Base):
    """Gradio theme with emerald/blue gradients and Quicksand typography."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="linear-gradient(to bottom right, *primary_50, *primary_100, *primary_200)",
            body_background_fill_dark="linear-gradient(to bottom right, *primary_900, *primary_800, *primary_700)",
            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
            button_primary_text_color="white",
            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
            slider_color="*secondary_300",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_shadow="*shadow_drop_lg",
            button_large_padding="32px",
        )


seafoam = Seafoam()

# Multipliers applied to the fitted log trend to draw the inner/outer bands.
INNER_BAND_FACTOR = 2
OUTER_BAND_FACTOR = 4

# Range-selector buttons shared by every chart.
_RANGE_SELECTOR = dict(
    buttons=[
        dict(count=1, label="1m", step="month", stepmode="backward"),
        dict(count=6, label="6m", step="month", stepmode="backward"),
        dict(count=1, label="YTD", step="year", stepmode="todate"),
        dict(count=1, label="1y", step="year", stepmode="backward"),
        dict(step="all"),
    ]
)


def download_stock_data(ticker, start_date, end_date):
    """Return the daily price history of *ticker* between the given dates.

    The dates are passed straight to ``yfinance.Ticker.history``. Its ``end``
    bound is exclusive, so no bar is returned for *end_date* itself.
    """
    stock = yf.Ticker(ticker)
    return stock.history(start=start_date, end=end_date)


def _apply_common_layout(fig, title, yaxis_title, **layout_overrides):
    """Apply the shared title, colours, legend and range controls to *fig*."""
    fig.update_layout(
        title={
            "text": title,
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        },
        xaxis_title="Date",
        yaxis_title=yaxis_title,
        height=800,
        legend=dict(x=0.01, y=0.99, bgcolor="rgba(255, 255, 255, 0.8)"),
        hovermode="x unified",
        plot_bgcolor="#F5F9F8",
        paper_bgcolor="#F5F9F8",
        font=dict(family="Quicksand, sans-serif", size=12, color="#313D38"),
        **layout_overrides,
    )
    fig.update_xaxes(rangeslider_visible=True, rangeselector=_RANGE_SELECTOR)
    return fig


def plot_chart(ticker, start_date, end_date, chart_type):
    """Download data for *ticker* and build the requested chart.

    Returns a 4-tuple ``(figure, current_price, log_price, percent_diff)``.
    On failure, the first element is a human-readable message string
    instead of a figure, and the other three are ``None``.
    """
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return "Please enter a stock ticker.", None, None, None
    try:
        data = download_stock_data(ticker, start_date, end_date)
        if data.empty:
            return "No data available for the specified date range.", None, None, None
        if chart_type == "Log":
            return plot_logarithmic_chart(data, ticker)
        return plot_candlestick_chart(data, ticker)
    except Exception as e:  # UI boundary: surface any failure to the user
        return f"An error occurred: {str(e)}", None, None, None


def plot_logarithmic_chart(data, ticker, future_days=365 * 10):
    """Plot closing prices with an exponential trend fitted in log space.

    The trend is ``exp(intercept + slope * day)``, where ``day`` is the
    calendar-day offset from the first data point. It is projected
    *future_days* calendar days past the last data point and drawn with
    bands at x/÷ INNER_BAND_FACTOR and x/÷ OUTER_BAND_FACTOR.

    Returns ``(figure, current_price, trend_price_on_last_day, percent_diff)``,
    or a message plus three ``None`` values when there is too little usable
    data to fit a line.
    """
    close = data["Close"].dropna()
    close = close[close > 0]  # log() is undefined for non-positive prices
    if len(close) < 2:
        return "Not enough price data to fit a log trend.", None, None, None

    # Use naive local calendar dates. Tz-aware midnight timestamps can be
    # 23h/25h apart across DST, which makes `.days` off by one.
    dates = pd.DatetimeIndex(close.index)
    if dates.tz is not None:
        dates = dates.tz_localize(None)
    dates = dates.normalize()

    day_offsets = np.asarray((dates - dates[0]).days, dtype=float)
    log_close = np.log(close.to_numpy(dtype=float))
    slope, intercept = np.polyfit(day_offsets, log_close, 1)

    # all_days is indexed by calendar-day offset. This keeps it aligned with
    # the fit and with the daily extended_dates below.
    last_day = int(day_offsets[-1])
    all_days = np.arange(last_day + 1 + future_days)
    log_trend = np.exp(intercept + slope * all_days)
    extended_dates = pd.date_range(start=dates[0], periods=len(all_days), freq="D")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=dates, y=close.to_numpy(), mode="lines",
                             name="Close Price", line=dict(color="blue")))
    trend_lines = [
        ("Log Trend", log_trend, "red"),
        ("Inner Upper Band", log_trend * INNER_BAND_FACTOR, "#6FB1A7"),
        ("Inner Lower Band", log_trend / INNER_BAND_FACTOR, "#6FB1A7"),
        ("Outer Upper Band", log_trend * OUTER_BAND_FACTOR, "#FFC2A5"),
        ("Outer Lower Band", log_trend / OUTER_BAND_FACTOR, "#FFC2A5"),
    ]
    for name, values, color in trend_lines:
        fig.add_trace(go.Scatter(x=extended_dates, y=values, mode="lines",
                                 name=name, line=dict(color=color)))

    _apply_common_layout(
        fig,
        f"Stock Log Chart: {ticker}",
        "Price (Log Scale)",
        yaxis_type="log",
    )

    current_price = float(close.iloc[-1])
    # Use the trend value on the last trading day, not the end of the
    # future projection.
    log_price = float(log_trend[last_day])
    percent_diff = (current_price - log_price) / log_price * 100
    return fig, current_price, log_price, percent_diff


def plot_candlestick_chart(data, ticker):
    """Plot an OHLC candlestick chart.

    Returns ``(figure, current_price, None, None)``. ``current_price`` is the
    last non-missing close, or ``None`` if there is none.
    """
    fig = go.Figure(data=[go.Candlestick(x=data.index,
                                         open=data["Open"],
                                         high=data["High"],
                                         low=data["Low"],
                                         close=data["Close"])])
    _apply_common_layout(fig, f"Stock Candlestick Chart: {ticker}", "Price")

    close = data["Close"].dropna()
    current_price = float(close.iloc[-1]) if not close.empty else None
    return fig, current_price, None, None


def format_price_info(current_price, log_price, percent_diff):
    """Render the price summary as HTML for the ``gr.HTML`` component.

    When a trend price is available, the summary is tinted green (above the
    trend) or red (at or below it). Opacity grows with the percentage gap
    and is capped at 1 for gaps of 100% or more.
    """
    if current_price is None:
        return "Unable to retrieve price information."
    if log_price is None or percent_diff is None:
        return f"Current Price: ${current_price:.2f}"

    color = "green" if percent_diff > 0 else "red"
    intensity = min(abs(percent_diff) / 100, 1)  # Normalize to 0-1
    bg_color = f"rgba({255 if color == 'red' else 0}, {255 if color == 'green' else 0}, 0, {intensity})"
    direction = "above" if percent_diff > 0 else "below"
    return (
        f'<div style="background-color: {bg_color}; padding: 12px; border-radius: 8px;">'
        f"<p>Current Price: ${current_price:.2f}</p>"
        f"<p>Log Price: ${log_price:.2f}</p>"
        f'<p style="color: {color}; font-weight: 600;">'
        f"{percent_diff:.2f}% {direction} log price</p>"
        "</div>"
    )


def _today():
    """Return today's date as ``YYYY-MM-DD``."""
    return datetime.now().strftime("%Y-%m-%d")


# Date at import time. The End Date box uses _today instead, so its default
# is refreshed on every page load rather than frozen at server start.
current_date = _today()

# Custom CSS for button and input hover effects
custom_css = """
#generate-button:hover {
    background-color: #FFB3BA !important; /* Pastel red */
}
#ticker-input input:hover, #start-date-input input:hover, #end-date-input input:hover {
    background-color: #FFB3BA !important; /* Pastel red */
}
"""

with gr.Blocks(theme=seafoam, title="Stock Charts", css=custom_css) as iface:
    gr.Markdown("# Stock Charts")
    gr.Markdown("Enter a stock ticker and date range to generate a chart.")

    with gr.Row():
        ticker = gr.Textbox(label="Stock Ticker", value="MSFT", elem_id="ticker-input")
        start_date = gr.Textbox(label="Start Date", value="2015-01-01", elem_id="start-date-input")
        # A callable value is re-evaluated by Gradio each time the app loads.
        end_date = gr.Textbox(label="End Date", value=_today, elem_id="end-date-input")

    with gr.Accordion("Chart Options", open=False):
        chart_type = gr.Radio(["Log", "Candlestick"], label="Chart Type", value="Log")

    submit_button = gr.Button("Generate Chart", elem_id="generate-button")

    with gr.Row():
        chart = gr.Plot(label="Stock Chart")
        price_info = gr.HTML(label="Price Information")

    def update_chart_and_info(ticker, start_date, end_date, chart_type):
        """Build the chart and price summary, routing errors to the summary."""
        chart_data, current_price, log_price, percent_diff = plot_chart(
            ticker, start_date, end_date, chart_type
        )
        if isinstance(chart_data, str):
            # plot_chart returns a message instead of a figure on failure;
            # a string is not a valid gr.Plot value, so show it as text.
            return None, f"<p>{html.escape(chart_data)}</p>"
        price_info_html = format_price_info(current_price, log_price, percent_diff)
        return chart_data, price_info_html

    submit_button.click(
        update_chart_and_info,
        inputs=[ticker, start_date, end_date, chart_type],
        outputs=[chart, price_info],
    )

# Launch the app
iface.launch()