Source code for pandas_survey_toolkit.vis

import textwrap
from typing import List, Tuple

import altair as alt
import pandas as pd


def _sentiment_colors(pos: float, neg: float) -> Tuple[str, str]:
    """Return the ``(background, text)`` colours for one heatmap cell.

    Parameters
    ----------
    pos : float
        Share of positive responses in the cell (0-1).
    neg : float
        Share of negative responses in the cell (0-1).

    Returns
    -------
    Tuple[str, str]
        Background colour (hex) and a contrasting text colour name.
    """
    # NOTE(review): the positive thresholds are tested before the negative
    # ones, so a cell with e.g. 45% positive / 55% negative is coloured as
    # "moderate positive". Kept as-is because the cell text shows the
    # positive share -- confirm this precedence is intended.
    if pos > 0.6:
        return "#1a9641", "white"  # Strong positive (green)
    if pos > 0.4:
        return "#a6d96a", "black"  # Moderate positive (light green)
    if pos > neg:
        return "#ffffbf", "black"  # Slightly positive (light yellow)
    if neg > 0.6:
        return "#d7191c", "white"  # Strong negative (red)
    if neg > 0.4:
        return "#fdae61", "black"  # Moderate negative (orange)
    if neg > pos:
        return "#f4a582", "black"  # Slightly negative (light red)
    return "#f7f7f7", "black"  # Neutral (light gray)


def _cluster_share(
    flags: pd.DataFrame, clusters: pd.Series, cluster_col: str, value_name: str
) -> pd.DataFrame:
    """Average 0/1 ``flags`` per cluster and melt to one row per (cluster, question)."""
    return (
        flags.groupby(clusters)
        .mean()
        .reset_index()
        .melt(id_vars=cluster_col, var_name="question", value_name=value_name)
    )


def cluster_heatmap_plot(
    df: pd.DataFrame, x: str, y: List[str], max_width: int = 75
) -> "alt.VConcatChart":
    """
    Create a heatmap visualization of Likert scale responses grouped by clusters.

    The returned chart stacks two views that share the same cluster ordering:

    1. A bar chart showing the number of respondents in each cluster.
    2. A heatmap with one cell per (question, cluster) showing the share of
       positive responses, coloured by the positive/negative balance.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame containing the cluster column and the encoded Likert
        columns. It is not modified.
    x : str
        The name of the column containing cluster IDs
        (e.g., 'question_cluster_id').
    y : List[str]
        Column names containing the encoded Likert responses:

        * 1 for positive responses
        * 0 for neutral responses
        * -1 for negative responses
    max_width : int, default=75
        Maximum line width used when wrapping question labels.

    Returns
    -------
    alt.VConcatChart
        The bar chart of cluster sizes stacked above the sentiment heatmap.

    Raises
    ------
    ValueError
        If ``y`` is empty, because there would be no question to plot.

    Notes
    -----
    * Any value other than 1 or -1 (including missing values) is counted
      neither as positive nor as negative, so it ends up in the neutral share.
    * Clusters are ordered left to right by their mean positive share across
      all questions, highest first.
    * Question labels are derived from the column names by replacing
      underscores with spaces, removing the text "likert encoded" and trimming
      surrounding whitespace.

    Examples
    --------
    >>> # Assuming df has been processed with cluster_questions
    >>> likert_columns = [f"likert_encoded_{q}" for q in questions]
    >>> heatmap = cluster_heatmap_plot(df, x="question_cluster_id", y=likert_columns)
    >>> display(heatmap)
    """
    if len(y) == 0:
        raise ValueError("`y` must contain at least one encoded Likert column.")

    # Flag positive (1) and negative (-1) answers as 0/1 so that the per-cluster
    # mean is the share of respondents giving that answer.
    positive_flags = df[y].eq(1).astype(int)
    negative_flags = df[y].eq(-1).astype(int)

    heatmap_data = pd.merge(
        _cluster_share(positive_flags, df[x], x, "percent_positive"),
        _cluster_share(negative_flags, df[x], x, "percent_negative"),
        on=[x, "question"],
    )
    heatmap_data["percent_neutral"] = (
        1 - heatmap_data["percent_positive"] - heatmap_data["percent_negative"]
    )

    # Order clusters from most to least positive overall.
    cluster_order = (
        heatmap_data.groupby(x)["percent_positive"]
        .mean()
        .sort_values(ascending=False)
        .index.tolist()
    )

    # Turn column names into readable labels. The final strip removes the
    # leading space left behind once the "likert encoded" prefix is dropped.
    heatmap_data["question"] = (
        heatmap_data["question"]
        .str.replace("_", " ", regex=False)
        .str.replace("likert encoded", "", regex=False)
        .str.strip()
    )

    # Wrap long question labels; the list order also fixes the row order.
    unique_labels = heatmap_data["question"].unique()
    wrapped_labels = [
        textwrap.fill(label, width=max_width) for label in unique_labels
    ]
    heatmap_data["wrapped_question"] = heatmap_data["question"].map(
        dict(zip(unique_labels, wrapped_labels))
    )

    # Pick the cell colours. A plain comprehension avoids a row-wise
    # DataFrame.apply and also works when there are no rows.
    cell_colors = [
        _sentiment_colors(pos, neg)
        for pos, neg in zip(
            heatmap_data["percent_positive"], heatmap_data["percent_negative"]
        )
    ]
    heatmap_data["background_color"] = [background for background, _ in cell_colors]
    heatmap_data["text_color"] = [foreground for _, foreground in cell_colors]

    # Chart dimensions: the heatmap grows with the number of questions.
    chart_width = 600
    row_height = 30
    heatmap_height = len(wrapped_labels) * row_height
    bar_chart_height = 100

    # Heatmap cells. scale=None makes Altair use the hex codes literally.
    heatmap = (
        alt.Chart(heatmap_data)
        .mark_rect()
        .encode(
            x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
            y=alt.Y("wrapped_question:O", title=None, sort=wrapped_labels),
            color=alt.Color("background_color:N", scale=None),
            tooltip=[
                alt.Tooltip(f"{x}:O", title="Cluster ID"),
                alt.Tooltip("question:O", title="Question"),
                alt.Tooltip("percent_positive:Q", title="% Positive", format=".2%"),
                alt.Tooltip("percent_negative:Q", title="% Negative", format=".2%"),
                alt.Tooltip("percent_neutral:Q", title="% Neutral", format=".2%"),
            ],
        )
        .properties(
            width=chart_width,
            height=heatmap_height,
            title="Cluster Heatmap: Sentiment Distribution",
        )
    )

    # Positive share printed inside each cell.
    text = heatmap.mark_text(baseline="middle").encode(
        text=alt.Text("percent_positive:Q", format=".0%"),
        color=alt.Color("text_color:N", scale=None),
    )

    # Cluster sizes, sorted in the same order as the heatmap columns.
    cluster_counts = df[x].value_counts().reset_index()
    cluster_counts.columns = [x, "count"]
    cluster_counts[x] = pd.Categorical(
        cluster_counts[x], categories=cluster_order, ordered=True
    )
    cluster_counts = cluster_counts.sort_values(x)

    bar_chart = (
        alt.Chart(cluster_counts)
        .mark_bar()
        .encode(
            x=alt.X(f"{x}:O", title="Cluster ID", sort=cluster_order),
            y=alt.Y("count:Q", title="Count"),
            tooltip=[
                alt.Tooltip(f"{x}:O", title="Cluster ID"),
                alt.Tooltip("count:Q", title="Count"),
            ],
        )
        .properties(width=chart_width, height=bar_chart_height, title="Cluster Sizes")
    )

    # Count printed just above each bar.
    bar_text = bar_chart.mark_text(align="center", baseline="bottom", dy=-5).encode(
        text="count:Q"
    )

    # Stack the bar chart above the heatmap.
    combined_chart = (
        alt.vconcat((bar_chart + bar_text), (heatmap + text))
        .configure_view(strokeWidth=0)
        .configure_axis(
            labelLimit=350  # Increase label limit to show full wrapped text
        )
    )

    return combined_chart