|
import solara |
|
import anywidget |
|
import traitlets |
|
from vega_datasets import data |
|
import altair as alt |
|
import matplotlib.pyplot as plt |
|
from matplotlib.figure import Figure |
|
import pandas as pd |
|
from typing import List |
|
from typing_extensions import TypedDict |
|
|
|
class ChartWidget(anywidget.AnyWidget):
    """Render a Vega-Lite spec with vega-embed and sync its first parameter.

    ``spec`` is the Vega-Lite JSON dict to render. Whenever the signal of the
    spec's first top-level parameter (``spec["params"][0]``) changes in the
    browser, its value is written back to ``selection``.
    """

    # Vega-Lite specification to render (Python -> JS).
    spec = traitlets.Dict().tag(sync=True)
    # Latest value of the first parameter's signal (JS -> Python).
    selection = traitlets.Dict().tag(sync=True)

    _esm = """
import embed from "https://cdn.jsdelivr.net/npm/vega-embed@6/+esm";

async function render({ model, el }) {
  const spec = model.get("spec");
  const api = await embed(el, spec);

  // Only attach a listener when the spec actually declares a parameter;
  // otherwise spec.params[0] would throw and abort rendering.
  const param = (spec.params || [])[0];
  if (param) {
    api.view.addSignalListener(param.name, (_, update) => {
      // The synced trait is a Dict, so never send null/undefined.
      model.set("selection", update ?? {});
      model.save_changes();
    });
  }

  // Release the Vega view (listeners, timers) when this widget view is removed.
  return () => api.finalize();
}

export default { render };
"""
|
|
|
class BrushSelection(TypedDict, total=False):
    """Brush extents keyed by encoded field name, each as ``[lower, upper]``.

    ``total=False`` because a selection may contain only some (or none) of
    the fields; Page applies whichever keys are present.
    """

    Horsepower: List[float]
    Miles_per_Gallon: List[float]


# Current brush extents. ChartWidget writes into this via
# ``on_selection=selected.set`` in Page; Page filters the data with it.
selected: solara.Reactive[BrushSelection] = solara.reactive(
    {"Horsepower": [45, 231], "Miles_per_Gallon": [8, 47]}
)
|
|
|
@solara.component_vue("viewlistener.vue")
def ViewListener(view_data=None, on_view_data=None, children=[], style={}):
    """Python handle for the Vue component defined in ``viewlistener.vue``.

    The body is intentionally empty; the markup and behaviour live in the Vue
    file. In this module, Plot passes its size dict as ``view_data`` and the
    matching setter as ``on_view_data``, presumably so the Vue side can report
    its rendered width/height back to Python — confirm against the .vue file.

    NOTE(review): ``children=[]`` and ``style={}`` are mutable defaults. They
    are kept as-is because ``component_vue`` presumably derives the widget's
    traits from this exact signature.
    """
|
|
|
@solara.component
def Plot(sub, column: str = "Weight_in_lbs"):
    """Render a matplotlib histogram of ``sub[column]`` sized to its container.

    Args:
        sub: DataFrame to plot (Page passes the currently brushed subset).
        column: Name of the numeric column to histogram; also used as the
            x-axis label. Defaults to ``"Weight_in_lbs"``.
    """
    dpi = 100
    default_width, default_height = 800, 600

    # ViewListener presumably reports the rendered container size back
    # through ``on_view_data``; the figure is then rebuilt at that pixel size.
    view_data = solara.use_reactive({"width": default_width, "height": default_height})
    width = view_data.value.get("width") or 0
    height = view_data.value.get("height") or 0
    # The reported size is not validated anywhere in this file; fall back to
    # the defaults instead of building a figure with a missing, zero or
    # negative size (e.g. if the container has not been laid out yet).
    if width <= 0 or height <= 0:
        width, height = default_width, default_height

    fig = Figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    ax = fig.subplots()
    ax.hist(sub[column], edgecolor="#26a269", facecolor="#57e389")
    ax.set_xlabel(column)
    ax.set_ylabel("Count of Records")

    with ViewListener(
        view_data=view_data.value,
        on_view_data=view_data.set,
        style={"width": "100%", "height": "55vh"},
    ):
        solara.FigureMatplotlib(fig)
|
|
|
# NOTE(review): `counter` is not read or written anywhere else in this file —
# candidate for removal unless another module imports it.
counter: solara.Reactive[int] = solara.reactive(0)
|
def _build_chart(df: pd.DataFrame) -> alt.VConcatChart:
    """Return a brushable scatter plot stacked above a brush-filtered bar chart."""
    # A fixed parameter name keeps the serialized spec stable. ChartWidget
    # attaches its listener to the first top-level parameter's name.
    brush = alt.selection_interval(name="brush")

    points = (
        alt.Chart(df)
        .mark_point()
        .encode(
            x="Horsepower",
            y="Miles_per_Gallon",
            color=alt.condition(brush, "Origin", alt.value("lightgray")),
            tooltip=["Horsepower", "Miles_per_Gallon"],
        )
        .add_params(brush)
    )
    bars = (
        alt.Chart(df)
        .mark_bar()
        .encode(y="Origin", color="Origin", x="count(Origin)")
        .transform_filter(brush)
    )
    return points & bars


def _filter_by_selection(df: pd.DataFrame, selection) -> pd.DataFrame:
    """Return the rows of *df* inside every ``field -> [a, b]`` interval.

    Bounds are inclusive and may arrive in either order. This matches Vega's
    interval test, so the table, histogram and CSV agree with the linked bar
    chart. An empty selection keeps every row. Rows that are NaN in a
    filtered field are excluded.
    """
    mask = pd.Series(True, index=df.index)
    for field, bounds in selection.items():
        lower, upper = sorted(bounds)
        mask &= df[field].between(lower, upper)
    return df[mask]


@solara.component
def Page():
    """Cars explorer: brushable chart, histogram, table and CSV download.

    Brushing the scatter plot updates ``selected`` (ChartWidget's
    ``on_selection``), which re-renders this component with the filtered rows.
    """
    title = "AnyWidget+Solara: Cars dataset"

    # Every brush update re-renders Page. Build the dataset and the chart spec
    # once per component instance instead of reloading, rebuilding and
    # re-sending the whole spec to the widget on each update.
    df = solara.use_memo(lambda: data.cars(), dependencies=[])
    spec = solara.use_memo(lambda: _build_chart(df).to_dict(), dependencies=[])

    sub = _filter_by_selection(df, selected.value)

    with solara.Head():
        solara.Title(title)

    with solara.AppBar():
        solara.lab.ThemeToggle(enable_auto=False)

    with solara.Column(style={"padding": "30px"}):
        with solara.Row():
            ChartWidget.element(spec=spec, on_selection=selected.set)
            with solara.Column(style={"margin": "0"}):
                if not sub.empty:
                    Plot(sub)

        solara.DataFrame(sub, items_per_page=10)

        # The payload is CSV text, so advertise it as CSV rather than as an
        # Excel workbook.
        file_object = sub.to_csv(index=False)
        with solara.FileDownload(file_object, "cars_subset.csv", mime_type="text/csv"):
            solara.Button("Download selection", icon_name="mdi-cloud-download-outline", color="primary")
|
# Instantiate the top-level element at import time.
# NOTE(review): presumably this is so the app is displayed when the code runs
# as the last expression of a Jupyter cell; under `solara run` the `Page`
# component is normally discovered by name — confirm whether this is needed.
Page()
|
|