from typing import Dict, List, Optional
import plotly.io as pio
import pandas as pd
import plotly.graph_objects as go
import vizro.models as vm
from vizro import Vizro
from vizro.models.types import capture
import vizro.plotly.express as px
# Survey answers per pastry, written one row per pastry. The four numeric
# columns are the answer categories, ordered from most negative to most
# positive opinion.
pastries = pd.DataFrame(
    [
        ("Scones", 20, 30, 30, 20),
        ("Bagels", 30, 25, 25, 20),
        ("Muffins", 10, 20, 40, 30),
        ("Cakes", 5, 10, 40, 45),
        ("Donuts", 15, 20, 45, 20),
        ("Cookies", 5, 10, 40, 45),
        ("Croissants", 10, 15, 40, 35),
        ("Eclairs", 25, 30, 25, 20),
        ("Brownies", 8, 12, 40, 40),
        ("Tarts", 20, 30, 30, 20),
        ("Macarons", 5, 10, 45, 40),
        ("Pies", 10, 15, 35, 40),
    ],
    columns=["pastry", "Strongly Disagree", "Disagree", "Agree", "Strongly Agree"],
)

# Flip the sign of the two "disagree" columns so they are plotted on the
# negative side of the baseline.
pastries["Strongly Disagree"] = pastries["Strongly Disagree"].mul(-1)
pastries["Disagree"] = pastries["Disagree"].mul(-1)

# Long-form variant: one row per (pastry, opinion) pair, with the opinion name
# in "variable" and the (already sign-flipped) number in "value".
pastries_long = pastries.melt(
    id_vars=["pastry"],
    value_vars=["Strongly Disagree", "Disagree", "Agree", "Strongly Agree"],
)

# Wide-form variant indexed by pastry, with the column axis named "Opinion".
pastries_2 = pastries.set_index("pastry").rename_axis(columns="Opinion")
@capture("graph")
def diverging_stacked_bar2(data_frame: pd.DataFrame, **kwargs) -> go.Figure:
    """Create a diverging stacked bar chart on top of `px.bar`.

    This type of chart is a variant of the standard stacked bar chart, with bars aligned on a central baseline to
    show both positive and negative values. Each bar is segmented to represent different categories.

    The figure is built by `px.bar` and then post-processed:

    1. Every trace gets a `legendrank` equal to its original position, so the legend keeps the order produced by
       `px.bar` even after traces are reordered.
    2. Unless the caller passed `color_discrete_sequence` or `color_discrete_map`, traces are coloured by sampling
       the figure template's diverging colorscale between 0.2 and 0.8.
    3. Traces whose values are all non-positive have their relative order reversed. Bars are stacked in trace
       order, so this changes which negative category sits next to the baseline.
    4. A grey line is drawn at zero on the value axis.

    This function is not suitable for diverging stacked bar charts that include a neutral category.

    Inspired by: https://community.plotly.com/t/need-help-in-making-diverging-stacked-bar-charts/34023

    Args:
        data_frame (pd.DataFrame): The data frame for the chart. Values of "negative" categories must already be
            stored as non-positive numbers.
        **kwargs: Keyword arguments forwarded unchanged to `px.bar` (for example `x`, `y`, `color`, `orientation`,
            `labels`).

    Returns:
        go.Figure: The post-processed bar chart.

    """
    fig = px.bar(data_frame, **kwargs)
    if not fig.data:
        # Nothing to colour or reorder.
        return fig

    trace_count = len(fig.data)

    # Only impose the diverging palette when the caller has not chosen colours themselves.
    caller_set_colors = any(
        kwargs.get(option) is not None for option in ("color_discrete_sequence", "color_discrete_map")
    )
    diverging_scale = fig.layout.template.layout.colorscale.diverging
    if caller_set_colors or not diverging_scale:
        colors = [None] * trace_count
    else:
        # For an integer count sample_colorscale spaces the points using a division by (count - 1), which fails
        # for a single trace; in that case sample the midpoint of the 0.2-0.8 range explicitly.
        sample_points = trace_count if trace_count > 1 else [0.5]
        colors = px.colors.sample_colorscale([list(stop) for stop in diverging_scale], sample_points, 0.2, 0.8)

    for rank, (trace, color) in enumerate(zip(fig.data, colors)):
        # legendrank is assigned before any reordering so the legend keeps the px.bar order.
        trace.update(legendrank=rank)
        if color is not None:
            trace.update(marker_color=color)

    orientation = fig.data[0].orientation
    value_axis = "x" if orientation == "h" else "y"

    def _is_non_positive(trace) -> bool:
        values = getattr(trace, value_axis)
        return values is not None and all(value <= 0 for value in values)

    # Reverse the order of the non-positive traces among themselves; all other traces keep their position.
    negative_positions = [position for position, trace in enumerate(fig.data) if _is_non_positive(trace)]
    reordered = list(fig.data)
    for target, source in zip(negative_positions, reversed(negative_positions)):
        reordered[target] = fig.data[source]
    fig.data = reordered

    # Mark the central baseline on the value axis.
    if orientation == "h":
        fig.add_vline(x=0, line_width=2, line_color="grey")
    else:
        fig.add_hline(y=0, line_width=2, line_color="grey")
    return fig
# Text shared by the three example graphs. Each graph receives its own copy of
# the labels mapping.
_GRAPH_TITLE = "Would you recommend the pastry to your friends?"
_AXIS_LABELS = {"value": "Response count", "variable": "Opinion"}

# One page showing the same chart built from three differently shaped inputs.
page = vm.Page(
    title="Diverging stacked bar",
    layout=vm.Layout(grid=[[0, 1, 2]]),
    components=[
        # Wide-form data: one column per opinion, passed as a list to `x`.
        vm.Graph(
            title=_GRAPH_TITLE,
            figure=diverging_stacked_bar2(
                data_frame=pastries,
                x=["Strongly Disagree", "Disagree", "Agree", "Strongly Agree"],
                y="pastry",
                labels=dict(_AXIS_LABELS),
            ),
        ),
        # Long-form data: the opinion is a column used for colouring.
        vm.Graph(
            title=_GRAPH_TITLE,
            figure=diverging_stacked_bar2(
                data_frame=pastries_long,
                x="value",
                y="pastry",
                color="variable",
                labels=dict(_AXIS_LABELS),
            ),
        ),
        # Indexed wide-form data drawn with vertical bars.
        vm.Graph(
            title=_GRAPH_TITLE,
            figure=diverging_stacked_bar2(
                data_frame=pastries_2,
                orientation="v",
            ),
        ),
    ],
)

dashboard = vm.Dashboard(pages=[page])
app = Vizro().build(dashboard)
# To serve the dashboard directly from this script, call `app.run()`.