nitinmagima/decoded-futures (hosted on Py.Cafe)

Decoded Futures - Admissions Probability & College Clustering

Files:
  • app.py
  • kmeans_college.pkl
  • logreg_admissions.pkl
  • requirements.txt
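
`app.py` loads both `.pkl` files by positional tuple unpacking, so the order in which the training script dumped the objects is part of the contract between training and the app. The sketch below checks that contract when loading. It assumes the bundles are plain tuples, which is what the unpacking implies. The helper `load_bundle` and the placeholder values are hypothetical; the real bundle holds fitted scikit-learn objects rather than strings:

```python
import os
import pickle
import tempfile

# Field order app.py expects when it unpacks logreg_admissions.pkl.
LOGREG_FIELDS = (
    "logreg_cv", "scaler", "training_columns", "num_cols", "ethnicity_values",
    "col_mapping_logreg", "min_gpa", "max_gpa", "min_sat", "max_sat",
)


def load_bundle(path, field_names):
    """Load a pickled tuple and return it as a dict keyed by field name.

    Raises ValueError when the tuple length does not match the expected
    fields. This message is clearer than a bare unpacking error.
    """
    with open(path, "rb") as f:
        bundle = pickle.load(f)
    if len(bundle) != len(field_names):
        raise ValueError(
            f"{path}: expected {len(field_names)} objects, got {len(bundle)}"
        )
    return dict(zip(field_names, bundle))


# Usage with hypothetical placeholder objects (a real bundle contains fitted models).
placeholder = ("model", "scaler", ["gpa"], ["gpa"], ["A", "B"], {}, 60, 100, 800, 1600)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "logreg_admissions.pkl")
    with open(path, "wb") as f:
        pickle.dump(placeholder, f)
    loaded = load_bundle(path, LOGREG_FIELDS)
```

Only unpickle files you trust, because `pickle.load` can run arbitrary code embedded in the file.
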
app.py
import streamlit as st
import pandas as pd
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegressionCV
import plotly.graph_objects as go
import plotly.express as px
import math


# Set the title and subheader with emojis
st.title("Student Leadership Network's Post Secondary Education Recommendation Engine 🏫")
st.subheader("Supported by Decoded Futures Project")
st.write("AI Volunteer Expert - Nitin Magima")

# Create an expander for instructions
with st.expander("Click for Guidance & Tips ✨"):
    st.markdown("""

    ## **Welcome to our Postsecondary Education Recommendation Engine!**

    We use "A How-To Guide to High Quality, Data-Driven Advising" by Ryan Hoch, CEO of Overgrad, as the guide for building a Post Secondary Education Recommendation Model with ChatGPT. This recommendation system combines machine learning with personalized guidance to support counselors as they guide students on their higher education journey.
    
    The Recommendation Engine has been built using a subset of anonymized student data for demonstration purposes.

    ### Admissions Probability Model
    - Uses logistic regression trained on admissions data
    - Estimates the probability of acceptance at each school for the student's profile
    - Categorizes each school as "likely," "match," or "reach" using adjustable probability thresholds

    ### Interactive Post Secondary Education Recommendation Engine
    - Integrates machine learning outputs with ChatGPT for a conversational interface
    - Facilitates human-in-the-loop interaction between students and advisors
    - Enables personalized discussions to develop optimal application strategies

    ### **How to Navigate This App**:

    1. **Fill** out your student profile in **"Admissions Probability Based on Student Profile"**:
       - **GPA**: Use the slider (0–100 scale) ⚙️  
       - **SAT**: Use the slider (400–1600 composite) 📈  
       - **Ethnicity**, **Free/Reduced Lunch** & **First Gen**: Select the options that describe the student 🍎

    2. **Adjust** the **Likely** and **Reach** threshold sliders:
       - **Likely** threshold: schools at or above this predicted probability are Likely; raise it to keep only very high-confidence fits 🤝  
       - **Reach** threshold: schools below this predicted probability are Reach (schools between the two thresholds are Match); raise it to label more schools as Reach, lower it to label fewer 🚀

    3. **Click** the **"Get Recommendations"** button:
       - See the colleges grouped into **Likely/Match/Reach**, each list sorted by predicted probability of acceptance 🎉

    4. **Check** the **"College Cluster Summaries & Lookup - Based on K-Means"** section:
       - Observe how colleges cluster and compare each cluster's average metrics 💡

    5. **Explore** the **Likely/Match/Reach College Selection** section (optional):
       - Handpick which schools you’d label as Likely, Match, and Reach.
       - **Download** your final picks as a convenient CSV 💾

    6. **Integrate** with Internal Custom GPT 🤖:
       - Upload the CSV file to your custom GPT model.
       - This allows you to ask questions about the colleges, such as their locations, available scholarships, and course offerings, while engaging with students in a conversational manner 💬

    ---
    **Thank you for your dedication to guiding students through their post-secondary journey. We hope this tool empowers you with data-driven insights to support informed, personalized advising. Best of luck in helping students navigate their future with confidence!** 🎓

    """)


##################################################
# 1) LOAD MODELS
##################################################
try:
    with open("logreg_admissions.pkl", "rb") as f:
        (
            logreg_cv,
            scaler,
            training_columns,
            num_cols,
            ethnicity_values,
            col_mapping_logreg,
            min_gpa,
            max_gpa,
            min_sat,
            max_sat
        ) = pickle.load(f)

    with open("kmeans_college.pkl", "rb") as f:
        (
            kmeans_model,
            kmeans_scaler,
            df_colleges_final,
            col_mapping_kmeans
        ) = pickle.load(f)

except FileNotFoundError as e:
    st.error(f"Could not find pickled model file: {e}")
    st.stop()


##################################################
# 2) The 'recommend_colleges_for_student' function
##################################################
def recommend_colleges_for_student(
    model,
    scaler,
    training_columns,
    numeric_cols,
    colleges_list,
    student_gpa,
    student_sat,
    free_lunch_eligible,
    ethnicity,
    first_gen,
    threshold_likely=0.7,
    threshold_reach=0.3
):
    """
    Builds scenario DataFrame for each college, one-hot encodes, scales,
    calls model.predict_proba, classifies as L/M/R, returns sorted DataFrame.
    """
    df_scenario = pd.DataFrame({
        "college": colleges_list,
        "gpa_(0-100_scale)": [student_gpa]*len(colleges_list),
        "highest_sat_score_(composite)": [student_sat]*len(colleges_list),
        "free_and_reduced_lunch_eligible": [free_lunch_eligible]*len(colleges_list),
        "ethnicity": [ethnicity]*len(colleges_list),
        "first_gen_college": [first_gen]*len(colleges_list)
    })

    cat_cols = []
    for col in ["college", "ethnicity"]:
        if col in df_scenario.columns and df_scenario[col].dtype == object:
            cat_cols.append(col)

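    # drop_first=False: every scenario row shares one ethnicity, so drop_first=True would
    # drop that ethnicity's dummy. Baseline dummies the model never saw are removed by the
    # column alignment that follows.
    df_scenario_encoded = pd.get_dummies(df_scenario, columns=cat_cols, drop_first=False)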

    # Align columns with the training layout: add missing dummies as 0,
    # drop dummies the model never saw, and match the training column order
    df_scenario_encoded = df_scenario_encoded.reindex(columns=training_columns, fill_value=0)

    # Scale numeric
    df_scenario_encoded[numeric_cols] = scaler.transform(df_scenario_encoded[numeric_cols])

    # Predict acceptance
    probs = model.predict_proba(df_scenario_encoded)[:,1]

    # Classify into L/M/R
    categories = []
    for p in probs:
        if p >= threshold_likely:
            categories.append("Likely")
        elif p < threshold_reach:
            categories.append("Reach")
        else:
            categories.append("Match")

    df_results = pd.DataFrame({
        "college": colleges_list,
        "prob_of_acceptance": probs,
        "category": categories
    }).sort_values("prob_of_acceptance", ascending=False)
    return df_results


##################################################
# 3) LOGISTIC REGRESSION UI
##################################################
st.markdown("# Admissions Probability Based on Student Profile 🧑‍🎓")

# Sliders for GPA & SAT
student_gpa = st.slider("GPA (0-100 Scale)", min_value=min_gpa, max_value=max_gpa, value=(min_gpa+max_gpa)//2)
student_sat = st.slider("Highest SAT (Composite)", min_value=min_sat, max_value=max_sat, value=(min_sat+max_sat)//2)

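# Let the user pick from the pickled `ethnicity_values`, falling back to free text if none were saved.
# len() instead of truthiness, so this also works if the values were pickled as a NumPy array.
if ethnicity_values is not None and len(ethnicity_values) > 0:
    ethnicity_str = st.selectbox("Ethnicity", ethnicity_values)
else:
    ethnicity_str = st.text_input("Ethnicity", "Hispanic, Latino, or Spanish Origin")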

free_lunch_eligible = st.selectbox("Free/Reduced Lunch Eligible?", [0,1], index=0)
first_gen = st.selectbox("First Gen College?", [0,1], index=0)

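threshold_likely = st.slider("Likely Threshold", 0.0, 1.0, 0.7, 0.05)
threshold_reach = st.slider("Reach Threshold", 0.0, 1.0, 0.3, 0.05)
# With Reach above Likely the Match band is empty, so warn the user
if threshold_reach > threshold_likely:
    st.warning("The Reach threshold is above the Likely threshold, so no college can be labeled Match.")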

if st.button("Get Recommendations"):
    # Use logistic regression
    colleges_list = df_colleges_final["college"].unique()
    df_recs = recommend_colleges_for_student(
        logreg_cv,
        scaler,
        training_columns,
        num_cols,
        colleges_list,
        student_gpa,
        student_sat,
        free_lunch_eligible,
        ethnicity_str,
        first_gen,
        threshold_likely,
        threshold_reach
    )
    # Rename columns for display using col_mapping_logreg
    df_recs_renamed = df_recs.rename(columns=col_mapping_logreg)

    # Split by category (the column may have been renamed above)
    category_col = col_mapping_logreg.get("category", "category")
    df_likely = df_recs_renamed[df_recs_renamed[category_col] == "Likely"]
    df_match = df_recs_renamed[df_recs_renamed[category_col] == "Match"]
    df_reach = df_recs_renamed[df_recs_renamed[category_col] == "Reach"]

    st.write("## Likely Colleges")
    st.dataframe(df_likely)
    st.write("## Match Colleges")
    st.dataframe(df_match)
    st.write("## Reach Colleges")
    st.dataframe(df_reach)

st.write("")
st.markdown("---")
st.write("")

##################################################
# 4) K-MEANS CLUSTER LOOKUP & TABLE
##################################################
st.markdown("# College Cluster Summaries & Lookup - Based on K-Means 📚")

cluster_summary = df_colleges_final.groupby("kmeans_cluster").mean(numeric_only=True)
# 1) Subset only the columns we have a mapping for
cols_to_keep = list(col_mapping_kmeans.keys())

# Intersect with whatever columns actually exist in cluster_summary 
# to avoid KeyErrors if some columns in the mapping aren't in the DataFrame:
cols_to_keep = [c for c in cols_to_keep if c in cluster_summary.columns]

cluster_summary_subset = cluster_summary[cols_to_keep]

# 2) Rename them
cluster_summary_renamed = cluster_summary_subset.rename(columns=col_mapping_kmeans)

# 3) Column names used by the student-profile donut charts
STUDENT_PROFILE_FRL = "Mean Free/Reduced Lunch"
STUDENT_PROFILE_FIRSTGEN = "Mean First Gen"

def create_concentric_donut_all_clusters(cluster_summary_renamed, col_name="Mean Free/Reduced Lunch"):
    """
    Creates a multi-ring donut for one metric across all clusters.
    Each cluster = one ring.
    Each ring has 2 slices => [ fraction, 1 - fraction ].
    The remainder slice is dark gray (rgb(18, 18, 18)); the fraction slice
    is left to Plotly's default colors.

    cluster_summary_renamed: DataFrame with index => cluster labels,
        columns => must include col_name
    col_name: e.g. "Mean Free/Reduced Lunch" or "Mean First Gen"

    Returns a Plotly Figure, or None if no rings are made.
    """

    # Without the metric column there is nothing to draw
    if col_name not in cluster_summary_renamed.columns:
        return None

    clusters = cluster_summary_renamed.index.tolist()
    fig = go.Figure()

    thickness = 0.15   # ring thickness
    spacing = 0.15     # spacing between rings

    ring_count = 0
    for i, cluster_label in enumerate(clusters):
        # Attempt to get the value
        val = cluster_summary_renamed.loc[cluster_label, col_name]
        # Skip values that are not numeric or are missing (None/NaN)
        if not pd.api.types.is_number(val) or pd.isna(val):
            continue
        
        val = float(val)
        # If val > 1, interpret as out of 100 => convert to fraction in [0..1]
        if val > 1.0:
            val /= 100.0
        val = max(0.0, min(val, 1.0))  # clamp to [0..1]

        # 2-slice ring => [visible fraction, remainder]
        labels = [str(cluster_label), ""]
        values = [val, 1.0 - val]

        hole_start = i*spacing
        hole_end = hole_start + thickness
        # If ring extends beyond radius=1.0, we skip or break
        if hole_end > 1.0:
            break

        fig.add_trace(go.Pie(
            labels=labels,
            values=values,
            hole=hole_end,             # inner radius
            textinfo="label+percent",    # show cluster name & percent
            hoverinfo="label+percent",
            sort=False,
            marker=dict(
                colors=["", "rgb(18, 18, 18)"],
#                line=dict(color="#000", width=1)  # black outline for each slice
            ),
            domain=dict(x=[0,1], y=[0,1]),
            name=str(cluster_label)
        ))
        ring_count += 1

    # If no rings were created => None
    if ring_count == 0:
        return None

    fig.update_layout(
        title=f"All Clusters - {col_name}",
        showlegend=False,
        height=500,
        template="plotly_dark"
    )
    return fig

def display_student_profile_tab_concentric_2(cluster_summary_renamed):
    """
    Creates exactly TWO Plotly figures:
      1) Multi-ring donut for 'Mean Free/Reduced Lunch' across all clusters
      2) Multi-ring donut for 'Mean First Gen' across all clusters
    Then displays them side by side in Streamlit.
    """

    st.subheader("Student Profiles Across Clusters")

    # 1) Concentric Donut for Free Lunch
    fig_frl = create_concentric_donut_all_clusters(
        cluster_summary_renamed,
        col_name=STUDENT_PROFILE_FRL
    )

    # 2) Concentric Donut for First Gen
    fig_fg = create_concentric_donut_all_clusters(
        cluster_summary_renamed,
        col_name=STUDENT_PROFILE_FIRSTGEN
    )

    # Display them in two columns
    col1, col2 = st.columns(2)
    if fig_frl:
        col1.plotly_chart(fig_frl, use_container_width=True)
    else:
        col1.write("No Free/Reduced Lunch data for any cluster.")

    if fig_fg:
        col2.plotly_chart(fig_fg, use_container_width=True)
    else:
        col2.write("No First Gen data for any cluster.")



college_stats_cols = [
    "Mean Total College Apps",
    "Mean Acceptance Rate",
    "Avg GPA (0-100 Scale)"
]

def create_college_stats_grouped_bar_clusters_legend(cluster_summary_renamed):
    """
    Builds a grouped bar chart with the 3 metrics on the x-axis:
      - "Mean Total College Apps"
      - "Mean Acceptance Rate" (converted to % if <= 1)
      - "Avg GPA (0-100 Scale)"

    Each cluster is a separate bar trace (=> clusters in the legend).
    So the x-axis => [metric1, metric2, metric3], 
    each cluster has one bar for each metric.
    
    cluster_summary_renamed:
      index => cluster labels
      columns => at least the 3 above
    Returns: Plotly Figure with grouped bars, cluster-based legend.
    """

    clusters = cluster_summary_renamed.index.tolist()
    if not clusters:
        return None

    # We define a fixed x-axis of the 3 metrics
    x_axis_metrics = college_stats_cols

    fig = go.Figure()

    # For each cluster => one trace
    for cluster_label in clusters:
        row = cluster_summary_renamed.loc[cluster_label]
        
        x_data = []
        y_data = []
        text_data = []

        for metric in x_axis_metrics:
            val = 0.0
            if metric in row.index and pd.notnull(row[metric]):
                val = float(row[metric])
                # If "Mean Acceptance Rate" <= 1 => treat as fraction => multiply by 100
                if metric == "Mean Acceptance Rate" and val <= 1.0:
                    val *= 100.0
            x_data.append(metric)        # metric as x label
            y_data.append(val)          # numeric value
            text_data.append(f"{val:.1f}")  # label above bar

        # Add a bar trace for this cluster
        fig.add_trace(go.Bar(
            x=x_data,
            y=y_data,
            name=str(cluster_label),     # cluster name in legend
            text=text_data,
            textposition='outside',
            cliponaxis=False
        ))

    # Group them side by side
    fig.update_layout(
        barmode='group',        # group clusters next to each metric
        xaxis_title="Metric",
        yaxis_title="Value",
        height=500,
        template="plotly_dark"
    )
    return fig

def display_college_stats_tab_grouped_bar(cluster_summary_renamed):
    """
    Creates a single grouped bar chart with x-axis = the 3 metrics,
    each cluster is a separate bar trace => cluster in legend.
    """
    st.subheader("College Stats - Grouped Bar Chart of All Clusters")
    if cluster_summary_renamed.empty:
        st.write("No data to display.")
        return

    fig = create_college_stats_grouped_bar_clusters_legend(cluster_summary_renamed)
    if fig:
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.write("No valid data for the grouped bar chart.")



sat_score_cols = [
    "Avg SAT Score (Composite)",
    "Avg SAT Score (Super Score)",
    "Avg SAT Math Score",
    "Avg SAT Reading Score"
]

def create_sat_score_bars_all_clusters(cluster_summary_renamed):
    """
    Creates a single grouped bar chart that compares all clusters side by side
    for each SAT metric in sat_score_cols.

    cluster_summary_renamed:
      Rows: cluster labels/index
      Columns: includes the sat_score_cols with numeric values

    Returns: A Plotly Figure with grouped bars.
    """

    # 1) Prepare a Plotly figure
    fig = go.Figure()

    # 2) For each cluster, create a Bar trace
    #    x = the 4 SAT metrics, y = the average values
    cluster_labels = cluster_summary_renamed.index.tolist()

    for cluster_label in cluster_labels:
        row = cluster_summary_renamed.loc[cluster_label]

        x_data = []
        y_data = []
        for metric in sat_score_cols:
            if metric in row.index and pd.notnull(row[metric]):
                val = float(row[metric])
            else:
                val = 0.0  # fallback if missing
            x_data.append(metric)
            y_data.append(val)

        # Add one Bar trace for this cluster
        fig.add_trace(go.Bar(
            x=x_data,
            y=y_data,
            name=str(cluster_label)  # cluster name in the legend
        ))

    # 3) Make it a grouped bar chart
    fig.update_layout(
        barmode='group',
#        title="SAT Scores - All Clusters",
        xaxis_title="SAT Metric",
        yaxis_title="Score",
        height=500,
        template="plotly_dark"
    )

    return fig

def display_sat_score_tab_all_clusters(cluster_summary_renamed):
    """
    Displays a 'SAT Score' tab with a single grouped bar chart
    for all clusters side by side.
    """
    st.subheader("SAT Scores - Grouped Bar Chart of All Clusters")
    
    if cluster_summary_renamed.empty:
        st.write("No data to display.")
        return

    fig = create_sat_score_bars_all_clusters(cluster_summary_renamed)
    st.plotly_chart(fig, use_container_width=True)

ethnicity_cols = [
    "Mean % American Indian/Alaskan Native",
    "Mean % Asian/Native Hawaiian/Other PI",
    "Mean % Black (Non-Hispanic)",
    "Mean % Hispanic/Latino/Spanish Origin",
    "Mean % Multi-Racial",
    "Mean % White (Non-Hispanic)"
]

def create_ethnicity_spider_chart(cluster_summary_renamed):
    """
    Builds a single spider/radar chart for the 6 ethnicity columns, 
    with each cluster as a separate radial trace.

    If every ethnicity value in the table is <= 1.0, the data are treated as
    fractions and multiplied by 100; otherwise they are assumed to be
    percentages already. The scale is decided once for the whole table, so a
    small genuine percentage (e.g. 0.8 %) is not mistaken for a fraction.
    The chart's radial axis is 0..100 by default.

    Returns a Plotly Figure.
    """

    fig = go.Figure()

    # We'll store these categories in a list for the radial axis
    categories = ethnicity_cols  # 6 categories

    # A spider chart's polygon is closed by repeating the first value
    # at the end; this is done per trace below.

    # For each cluster => one Scatterpolar trace
    clusters_list = cluster_summary_renamed.index.tolist()
    if not clusters_list:
        return None  # no data?

    # Decide fraction vs. percentage once, from all available ethnicity columns
    present_cols = [c for c in categories if c in cluster_summary_renamed.columns]
    values_are_fractions = bool(present_cols) and (
        cluster_summary_renamed[present_cols].max(numeric_only=True).max() <= 1.0
    )

    for cluster_label in clusters_list:
        row = cluster_summary_renamed.loc[cluster_label]

        # Build an array of 6 numeric values
        r_values = []
        skip_trace = False
        for col in categories:
            if col not in row.index or pd.isnull(row[col]):
                skip_trace = True
                break
            val = float(row[col])
            # Fractions => percentages (scale decided once, above)
            if values_are_fractions:
                val *= 100.0
            # clamp to [0..100]
            val = max(0.0, min(val, 100.0))
            r_values.append(val)

        if skip_trace:
            # if any category is missing, skip cluster
            continue

        # In a spider chart, to "close" the polygon, we repeat the first value at the end.
        r_values.append(r_values[0])
        theta = categories + [categories[0]]

        fig.add_trace(go.Scatterpolar(
            r=r_values,
            theta=theta,
            name=str(cluster_label), 
            mode='lines+markers',  # or 'lines', or 'markers'
            fill='none'            # 'toself' can fill the polygon, but can get messy with multiple clusters
        ))

    if len(fig.data) == 0:
        return None

    # Let's define radialaxis range => 0..100
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100]
            )
        ),
        showlegend=True,
        title="Ethnicity Distribution Spider Chart",
        template="plotly_dark"
    )
    return fig

def display_ethnicity_spider_tab(cluster_summary_renamed):
    """
    Displays a single tab with the ethnicity spider chart 
    (all clusters together).
    """
    st.subheader("Ethnicity Distribution (Spider)")
    if cluster_summary_renamed.empty:
        st.write("No data to display.")
        return

    fig = create_ethnicity_spider_chart(cluster_summary_renamed)
    if fig:
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.write("No valid ethnicity data found.")    

########################
# Putting Tabs Together
########################
def display_cluster_tabs(cluster_summary_renamed):
    """Creates four tabs: 'College Stats', 'SAT Scores Stats', 'Ethnicity Stats' and 'Student Profile'."""
    tab_college, tab_sat, tab_ethnicity, tab_profile = st.tabs(["College Stats","SAT Scores Stats", "Ethnicity Stats", "Student Profile"])

    # 1) Student Profile Tab
    with tab_profile:
        display_student_profile_tab_concentric_2(cluster_summary_renamed)

    # 2) College Stats Tab
    with tab_college:
        display_college_stats_tab_grouped_bar(cluster_summary_renamed)

    # 3) SAT Stats Tab
    with tab_sat:
        display_sat_score_tab_all_clusters(cluster_summary_renamed)

    # 4) Ethnicity Tab
    with tab_ethnicity:
        display_ethnicity_spider_tab(cluster_summary_renamed)    

display_cluster_tabs(cluster_summary_renamed)        

selected_colleges = st.multiselect(
    "Select colleges to highlight in the cluster table:",
    options=df_colleges_final["college"].unique()
)

if selected_colleges:
    for college in selected_colleges:
        row = df_colleges_final[df_colleges_final["college"] == college]
        if row.empty:
            st.write(f"College **{college}** not found in df_colleges_final.")
        else:
            clust = row.iloc[0]["kmeans_cluster"]
            st.write(f"College **{college}** => Cluster {clust}")

# highlight function
def highlight_selected(val):
    if val in selected_colleges:
        return "background-color: #087099; color: white;"
    return ""

all_clusters = df_colleges_final["kmeans_cluster"].unique()
clusters_dict = {}
max_len = 0

for c in all_clusters:
    col_list = df_colleges_final.loc[df_colleges_final["kmeans_cluster"] == c, "college"].tolist()
    clusters_dict[c] = col_list
    max_len = max(max_len, len(col_list))

clusters_data = {}
for c in all_clusters:
    padded = clusters_dict[c] + [""]*(max_len - len(clusters_dict[c]))
    clusters_data[c] = padded

clusters_df = pd.DataFrame(clusters_data)
clusters_df = clusters_df.rename(columns=lambda x: f"Cluster {x}")

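# Styler.applymap was renamed Styler.map in pandas 2.1 (applymap is deprecated); support both
styler = clusters_df.style
clusters_df_styled = (styler.map if hasattr(styler, "map") else styler.applymap)(highlight_selected)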

st.write("### Colleges in Each Cluster")
st.dataframe(clusters_df_styled, height=500)

st.write("")
st.markdown("---")
st.write("")

st.markdown("# Your Personalised Likely/Match/Reach College Selection ✅")

# 1) Gather all colleges from df_colleges_final
all_colleges = sorted(df_colleges_final["college"].unique())

# 2) First multi-select: Likely Colleges
likely_selected = st.multiselect(
    "Likely Colleges",
    options=all_colleges,
    help="Pick your Likely colleges here (they won't appear in the other lists)."
)

# Exclude Likely picks from next
remaining_after_likely = [c for c in all_colleges if c not in likely_selected]

# 3) Second multi-select: Match Colleges
match_selected = st.multiselect(
    "Match Colleges",
    options=remaining_after_likely,
    help="Pick your Match colleges here."
)

# Exclude Match picks from next
remaining_after_match = [c for c in remaining_after_likely if c not in match_selected]

# 4) Third multi-select: Reach Colleges
reach_selected = st.multiselect(
    "Reach Colleges",
    options=remaining_after_match,
    help="Pick your Reach colleges here."
)

# 5) Turn them into a single DataFrame
# Because each list may have different lengths, we pad them so each column is the same length.
max_len = max(len(likely_selected), len(match_selected), len(reach_selected))

padded_likely = likely_selected + [""]*(max_len - len(likely_selected))
padded_match = match_selected + [""]*(max_len - len(match_selected))
padded_reach = reach_selected + [""]*(max_len - len(reach_selected))

df_manual_picks = pd.DataFrame({
    "Likely": padded_likely,
    "Match": padded_match,
    "Reach": padded_reach
})

st.write("### Your Manually Selected Colleges")
st.dataframe(df_manual_picks)

# 6) Provide a download button for df_manual_picks
csv_data = df_manual_picks.to_csv(index=False)

st.download_button(
    label="Download CSV",
    data=csv_data,
    file_name="likely_match_reach_selections.csv",
    mime="text/csv"
)
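
`recommend_colleges_for_student` one-hot encodes a scenario in which every row shares the same ethnicity, then aligns the result to `training_columns`. The small pandas sketch below shows that alignment step. The column names and the `training_columns` list are invented for illustration and assume the model was trained with `drop_first=True`, so the baseline categories have no column. The sketch also shows why encoding the scenario itself with `drop_first=True` is risky: when a column holds a single value, that value's dummy is the one that gets dropped.

```python
import pandas as pd

# Hypothetical training layout: the baselines ("College A", "Asian") were
# dropped at training time, so they have no dummy column.
training_columns = [
    "gpa",
    "college_College B",
    "college_College C",
    "ethnicity_Black",
    "ethnicity_Hispanic",
]

# Scenario: one student (fixed ethnicity) evaluated against every college.
scenario = pd.DataFrame({
    "college": ["College A", "College B", "College C"],
    "gpa": [88.0, 88.0, 88.0],
    "ethnicity": ["Hispanic", "Hispanic", "Hispanic"],
})


def encode_and_align(df, drop_first):
    encoded = pd.get_dummies(df, columns=["college", "ethnicity"], drop_first=drop_first)
    # reindex adds missing training columns as 0, drops extras, and
    # returns the columns in the exact order the model expects.
    return encoded.reindex(columns=training_columns, fill_value=0)


lossy = encode_and_align(scenario, drop_first=True)   # Hispanic dummy lost
safe = encode_and_align(scenario, drop_first=False)   # Hispanic dummy kept
```

With `drop_first=False`, the dummy for the baseline college is simply discarded by `reindex` because it is not in `training_columns`. The same code is therefore correct whether training used `drop_first=True` or not.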
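
The Likely/Match/Reach rule is a pair of cut-offs on the predicted probability. A probability at or above `threshold_likely` is Likely, one below `threshold_reach` is Reach, and anything in between is Match. The NumPy sketch below applies the rule to a whole array at once. `classify` is a hypothetical helper, and rejecting crossed thresholds is an added assumption, since the app's sliders allow them to cross.

```python
import numpy as np


def classify(probs, threshold_likely=0.7, threshold_reach=0.3):
    """Label each acceptance probability as Likely, Match or Reach."""
    if threshold_reach > threshold_likely:
        raise ValueError("threshold_reach must not exceed threshold_likely")
    probs = np.asarray(probs, dtype=float)
    # np.select uses the first condition that matches, mirroring the app's if/elif order.
    return np.select(
        [probs >= threshold_likely, probs < threshold_reach],
        ["Likely", "Reach"],
        default="Match",
    )


labels = classify([0.95, 0.7, 0.5, 0.3, 0.1])
```

The boundaries follow the app: exactly 0.7 counts as Likely and exactly 0.3 counts as Match, because Reach is a strict "below" test.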
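
Several chart helpers in `app.py`, such as the acceptance-rate bars and the donut rings, decide for each individual value whether a number is a fraction or a percentage. That per-value test can misread small genuine percentages: a cluster mean of 0.8 % would be drawn as 80 %. The sketch below decides the scale once per column instead. `to_percent` is a hypothetical helper. Treating a column whose maximum is at most 1 as fractions is still a heuristic: it assumes no percentage column stays entirely below 1 %.

```python
import pandas as pd


def to_percent(df, columns):
    """Return a copy of `df` with `columns` on a 0-100 scale.

    Each column is judged as a whole. If its largest value is <= 1, it is
    assumed to hold fractions and is multiplied by 100. Otherwise it is
    assumed to be a percentage already. Results are clamped to [0, 100].
    """
    out = df.copy()
    for col in columns:
        if col not in out.columns:
            continue
        values = out[col].astype(float)
        if values.max() <= 1.0:
            values = values * 100.0
        out[col] = values.clip(lower=0.0, upper=100.0)
    return out


# Made-up cluster summary: one column in percent, one in fractions.
summary = pd.DataFrame(
    {
        "Mean % American Indian/Alaskan Native": [0.8, 2.5],
        "Mean Acceptance Rate": [0.45, 0.72],
    },
    index=[0, 1],
)
scaled = to_percent(summary, list(summary.columns))
```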
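
Both the cluster membership table and the manual-selection table pad lists of unequal length with empty strings before building a DataFrame. pandas can handle the padding itself. If each column is built from a `pd.Series`, pandas aligns the columns on the index and fills the gaps with NaN, and `fillna("")` then turns those gaps into blanks. The sketch below uses a hypothetical helper `ragged_to_frame` and made-up college names. It builds the per-cluster lists with `groupby` instead of one boolean filter per cluster.

```python
import pandas as pd


def ragged_to_frame(columns):
    """Build a DataFrame from a dict of lists with possibly different lengths.

    Shorter columns are padded with empty strings to the longest list's length.
    """
    frame = pd.DataFrame(
        {name: pd.Series(values, dtype=object) for name, values in columns.items()}
    )
    return frame.fillna("")


# Manual Likely/Match/Reach picks of different lengths.
picks = ragged_to_frame({
    "Likely": ["College A", "College B"],
    "Match": ["College C"],
    "Reach": [],
})

# One list of colleges per cluster, keyed as "Cluster <n>".
colleges = pd.DataFrame({
    "college": ["College A", "College B", "College C"],
    "kmeans_cluster": [0, 1, 0],
})
by_cluster = {
    f"Cluster {c}": names.tolist()
    for c, names in colleges.groupby("kmeans_cluster")["college"]
}
clusters_table = ragged_to_frame(by_cluster)
```

`picks.to_csv(index=False)` then gives the same CSV layout as the download button.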