Skip to content

👸 Altex

Submitted by Arnaud Miribel

Summary

A simple wrapper on top of Altair to make Streamlit charts with an express API. If you're lazy and/or familiar with Altair, this is probably a good fit! Inspired by plost and plotly-express.

Functions

_chart

Create an Altair chart with a simple API. Supported charts include line, bar, point, area, histogram, sparkline, sparkbar, sparkarea.

Parameters:

Name Type Description Default
mark_function str

Altair mark function, example line/bar/point

required
data DataFrame

Dataframe to use for the chart

required
x Union[X, str]

Column for the x axis

required
y Union[Y, str]

Column for the y axis

required
color Optional[Union[Color, str]]

Color a specific group of your data. Defaults to None.

None
opacity Optional[Union[value, float]]

Change opacity of marks. Defaults to None.

None
column Optional[Union[Column, str]]

Group by a specific column. Defaults to None.

None
rolling Optional[int]

Rolling average window size. Defaults to None.

None
title Optional[str]

Title of the chart. Defaults to None.

None
width Optional[int]

Width of the chart. Defaults to None.

None
height Optional[int]

Height of the chart. Defaults to None.

None
spark bool

Whether to make a spark chart, i.e. a chart without axes, ticks, or legend. Defaults to False.

False
autoscale_y bool

Whether or not to autoscale the y axis. Defaults to False.

False

Returns:

Type Description
Chart

alt.Chart: Altair chart

Source code in src/streamlit_extras/altex/__init__.py
# Vega-Lite encoding type codes that may trail an encoding shorthand, e.g. "price:Q".
_SHORTHAND_TYPE_CODES = frozenset(
    {"Q", "N", "O", "T", "G", "quantitative", "nominal", "ordinal", "temporal", "geojson"}
)


def _strip_type_code(shorthand: Optional[str]) -> Optional[str]:
    """Return *shorthand* without a trailing Vega-Lite type code.

    ``"price:Q"`` becomes ``"price"``. Strings without a recognised type
    suffix, ``None`` and non-string values are returned unchanged.
    """
    if not isinstance(shorthand, str):
        return shorthand
    field, sep, type_code = shorthand.rpartition(":")
    return field if sep and type_code in _SHORTHAND_TYPE_CODES else shorthand


@extra
def _chart(
    mark_function: str,
    data: pd.DataFrame,
    x: Union[alt.X, str],
    y: Union[alt.Y, str],
    color: Optional[Union[alt.Color, str]] = None,
    opacity: Optional[Union[alt.value, float]] = None,
    column: Optional[Union[alt.Column, str]] = None,
    rolling: Optional[int] = None,
    title: Optional[str] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    spark: bool = False,
    autoscale_y: bool = False,
) -> alt.Chart:
    """Create an Altair chart with a simple API.
    Supported charts include line, bar, point, area, histogram, sparkline, sparkbar, sparkarea.

    Args:
        mark_function (str): Altair mark function, example line/bar/point
        data (pd.DataFrame): Dataframe to use for the chart
        x (Union[alt.X, str]): Column for the x axis
        y (Union[alt.Y, str]): Column for the y axis
        color (Optional[Union[alt.Color, str]], optional): Color a specific group of your data. Defaults to None.
        opacity (Optional[Union[alt.value, float]], optional): Change opacity of marks. A plain
            number (int or float) is wrapped in ``alt.value``. Defaults to None.
        column (Optional[Union[alt.Column, str]], optional): Group by a specific column. Defaults to None.
        rolling (Optional[int], optional): Rolling average window size, i.e. the number of rows
            averaged (the current row and the ``rolling - 1`` rows before it). When ``color`` is
            set, the average is computed separately for each color group. Defaults to None.
        title (Optional[str], optional): Title of the chart. Defaults to None.
        width (Optional[int], optional): Width of the chart. Defaults to None.
        height (Optional[int], optional): Height of the chart. Defaults to None.
        spark (bool, optional): Whether to make a spark chart, i.e. a chart without axes, ticks, or legend. Defaults to False.
        autoscale_y (bool, optional): Whether to autoscale the y axis. Defaults to False.

    Returns:
        alt.Chart: Altair chart

    Raises:
        ValueError: If ``rolling`` is given but is smaller than 1.
    """

    x_ = _get_shorthand(x)
    y_ = _get_shorthand(y)
    color_ = _get_shorthand(color)

    tooltip_config = _drop_nones([x_, y_, color_])

    chart_config = _drop_nones(
        {
            "data": data,
            "title": title,
            "mark": mark_function,
            "width": width,
            "height": height,
        }
    )

    chart = alt.Chart(**chart_config)

    if rolling is not None:
        if rolling < 1:
            raise ValueError(f"rolling must be a positive integer, got {rolling!r}")
        # Window ops reference raw field names, so drop any ":Q"-style type suffix;
        # otherwise "mean(price:Q)" would point at a non-existent field.
        y_field = _strip_type_code(y_)
        rolling_column = f"{y_field} ({rolling}-average)"
        y = f"{rolling_column}:Q"
        transform_config = {
            rolling_column: f"mean({y_field})",
            # Vega-Lite frames are inclusive on both ends: [-(rolling - 1), 0]
            # averages exactly `rolling` rows (the current one plus rolling - 1 before it).
            "frame": [-(rolling - 1), 0],
        }
        color_field = _strip_type_code(color_)
        if color_field is not None:
            # Average each colored series independently instead of across all series.
            transform_config["groupby"] = [color_field]
        chart = chart.transform_window(**transform_config)

    if spark:
        chart = chart.configure_view(strokeWidth=0).configure_axis(
            grid=False, domain=False
        )
        x_axis = _update_axis_config(x, alt.X, {"axis": None})
        y_axis = _update_axis_config(y, alt.Y, {"axis": None})
    else:
        x_axis = x
        y_axis = y

    if autoscale_y:
        y_axis = _update_axis_config(y_axis, alt.Y, {"scale": alt.Scale(zero=False)})

    # A bare number is a constant opacity, not a field name; bool is excluded on purpose.
    is_constant_opacity = isinstance(opacity, (int, float)) and not isinstance(opacity, bool)

    encode_config = _drop_nones(
        {
            "x": x_axis,
            "y": y_axis,
            "color": color,
            "tooltip": tooltip_config,
            "opacity": alt.value(opacity) if is_constant_opacity else opacity,
            "column": column,
        }
    )

    chart = chart.encode(**encode_config)

    return chart

Import:

from streamlit_extras.altex import _chart # (1)!
  1. You should add this to the top of your .py file 🛠

scatter_chart

Source code in src/streamlit_extras/altex/__init__.py
@extra
def scatter_chart(**kwargs):
    """Scatter-plot variant of ``chart`` that uses the ``point`` mark.

    Every keyword argument is forwarded to ``chart`` unchanged. ``mark_function``
    and ``__name__`` are fixed here, so passing either one yourself raises a
    ``TypeError`` (duplicate keyword argument).
    """
    fixed_options = {"mark_function": "point", "__name__": "scatter_chart"}
    return chart(**fixed_options, **kwargs)

Import:

from streamlit_extras.altex import scatter_chart # (1)!
  1. You should add this to the top of your .py file 🛠

Examples

example_line

@cache_data
def example_line():
    """Example: single-series line chart of GOOG ``price`` against ``date``."""
    goog = get_stocks_data().query("symbol == 'GOOG'")
    line_chart(
        data=goog,
        title="A beautiful simple line chart",
        x="date",
        y="price",
    )

example_multi_line

@cache_data
def example_multi_line():
    """Example: one line per stock, split and colored by the ``symbol`` column."""
    all_stocks = get_stocks_data()
    line_chart(
        data=all_stocks,
        title="A beautiful multi line chart",
        color="symbol",
        x="date",
        y="price",
    )

example_bar

@cache_data
def example_bar():
    """Example: bar chart of GOOG ``price`` per ``date``."""
    goog_prices = get_stocks_data().query("symbol == 'GOOG'")
    bar_chart(data=goog_prices, x="date", y="price", title="A beautiful bar chart")

example_hist

@cache_data
def example_hist():
    """Example: histogram of stock prices rounded to zero decimal places."""
    stocks = get_stocks_data()
    rounded_stocks = stocks.assign(price=stocks.price.round(0))
    hist_chart(data=rounded_stocks, x="price", title="A beautiful histogram")

example_scatter

@cache_data
def example_scatter():
    """Example: scatter of ``wind`` vs ``temp_min`` using custom axis titles."""
    weather = get_weather_data()
    x_encoding = alt.X("wind:Q", title="Custom X title")
    y_encoding = alt.Y("temp_min:Q", title="Custom Y title")
    scatter_chart(
        data=weather,
        x=x_encoding,
        y=y_encoding,
        title="A beautiful scatter chart",
    )

example_sparkline

@cache_data
def example_sparkline():
    """Example: GOOG price spark line smoothed with ``rolling=7``, 150 high."""
    goog = get_stocks_data().query("symbol == 'GOOG'")
    sparkline_chart(
        data=goog,
        x="date",
        y="price",
        rolling=7,
        height=150,
        title="A beautiful sparkline chart",
    )

example_minisparklines

@cache_data
def example_minisparklines():
    """Example: three side-by-side columns, each with a metric and a mini spark line.

    Each column covers one stock symbol. It shows ``st.metric`` with the symbol's
    mean price truncated to ``int``, followed by a spark line drawn with
    ``height=80`` and ``autoscale_y=True``.
    """
    stocks = get_stocks_data()
    symbols = ("GOOG", "MSFT", "AAPL")

    # One loop body instead of three copy-pasted ones; columns map to symbols left-to-right.
    for column, symbol in zip(st.columns(len(symbols)), symbols):
        with column:
            data = stocks.query(f"symbol == '{symbol}'")
            st.metric(symbol, int(data["price"].mean()))
            sparkline_chart(
                data=data,
                x="date",
                y="price:Q",
                height=80,
                autoscale_y=True,
            )

example_sparkbar

@cache_data
def example_sparkbar():
    """Example: spark bar chart of GOOG ``price`` per ``date``, 150 high."""
    goog = get_stocks_data().query("symbol == 'GOOG'")
    sparkbar_chart(
        title="A beautiful sparkbar chart",
        height=150,
        data=goog,
        x="date",
        y="price",
    )

example_sparkarea

@cache_data
def example_sparkarea():
    """Example: spark area chart of columns ``a``..``g`` of the random data.

    The wide frame is melted to long format (``variable``/``value`` columns)
    so each original column becomes its own colored area at 0.6 opacity.
    """
    long_df = pd.melt(
        get_random_data(),
        id_vars="index",
        value_vars=list("abcdefg"),
    )

    series_color = alt.Color("variable", legend=None)
    sparkarea_chart(
        data=long_df,
        x="index",
        y="value",
        color=series_color,
        opacity=alt.value(0.6),
        height=200,
        title="A beautiful (also probably useless) sparkarea chart",
    )

example_hist_time

@cache_data
def example_hist_time():
    """Example: week-by-day grid colored by median ``temp_max`` (legend hidden)."""
    weather = get_weather_data()
    median_temp_color = alt.Color("median(temp_max):Q", legend=None)
    hist_chart(
        data=weather,
        x="week(date):T",
        y="day(date):T",
        color=median_temp_color,
        title="A beautiful time hist chart",
    )

example_bar_sorted

@cache_data
def example_bar_sorted():
    """Example: the 25 highest ``temp_max`` rows as bars ordered by descending y."""
    weather = get_weather_data()
    hottest = weather.sort_values(by="temp_max", ascending=False)
    bar_chart(
        data=hottest.head(25),
        x=alt.X("date", sort="-y"),
        y=alt.Y("temp_max:Q"),
        title="A beautiful sorted-by-value bar chart",
    )

example_bar_normalized

@cache_data
def example_bar_normalized():
    """Example: per-variety bars of summed ``yield``, stacked by site and normalized."""
    barley = get_barley_data()
    variety_axis = alt.X("variety:N", title="Variety")
    normalized_yield = alt.Y("sum(yield):Q", stack="normalize")
    bar_chart(
        data=barley,
        x=variety_axis,
        y=normalized_yield,
        color="site:N",
        title="A beautiful normalized stacked bar chart",
    )

example_bar_grouped

@cache_data
def example_bar_grouped():
    """Example: summed ``yield`` per ``year``, one small bar chart per ``site``.

    ``width=90`` sizes each facet; ``use_container_width=False`` is passed
    through so the fixed width is kept.
    """
    barley = get_barley_data()
    bar_chart(
        data=barley,
        column="site:N",
        x="year:O",
        y="sum(yield):Q",
        color="year:N",
        width=90,
        use_container_width=False,
        title="A beautiful grouped bar charts",
    )

example_bar_horizontal

@cache_data
def example_bar_horizontal():
    """Example: horizontal bars of ``temp_max`` for the first 15 weather rows.

    The quantitative temperature goes on x and the ordinal ``date`` on y, so
    bars extend horizontally.
    """
    weather = get_weather_data()
    bar_chart(
        data=weather.head(15),
        # Title each axis after the field it actually shows (same style as example_bar_log).
        x=alt.X("temp_max:Q", title="Temperature"),
        y=alt.Y("date:O", title="Date"),
        title="A beautiful horizontal bar chart",
    )

example_bar_log

@cache_data
def example_bar_log():
    """Example: record counts per ``temp_max`` on a symlog-scaled y axis."""
    weather = get_weather_data()
    temperature_axis = alt.X("temp_max:Q", title="Temperature")
    count_axis = alt.Y(
        "count()",
        title="Count of records",
        scale=alt.Scale(type="symlog"),
    )
    bar_chart(
        data=weather,
        x=temperature_axis,
        y=count_axis,
        title="A beautiful histogram... with log scale",
    )

example_scatter_opacity

@cache_data
def example_scatter_opacity():
    """Example: ``wind`` vs ``temp_min`` scatter drawn with a constant opacity of 0.2."""
    weather = get_weather_data()
    scatter_chart(
        data=weather,
        opacity=0.2,
        x=alt.X("wind:Q", title="Custom X title"),
        y=alt.Y("temp_min:Q", title="Custom Y title"),
        title="A beautiful scatter chart with custom opacity",
    )

example_bar_normalized_custom

@cache_data
def example_bar_normalized_custom():
    """Example: summed ``yield`` per variety stacked by site, custom color scheme, no legend."""
    barley = get_barley_data()
    site_color = alt.Color(
        "site",
        scale=alt.Scale(scheme="lighttealblue"),
        legend=None,
    )
    bar_chart(
        data=barley,
        x=alt.X("variety", title="Variety"),
        y="sum(yield)",
        color=site_color,
        title="A beautiful stacked bar chart (without legend, custom colors)",
    )