Py.Cafe

nespinoza/

lightcurve-fit-panel

Fitting WASP-39b JWST NIRSpec/PRISM lightcurves via sliders

Docs | Pricing
  • lightcurves/
  • app.py
  • requirements.txt
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import glob
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import panel as pn
from matplotlib.figure import Figure

pn.extension()

def load_lightcurves(pattern="lightcurves/w_*"):
    """Load every light-curve file matching *pattern* into a dict.

    Each file is parsed with ``np.loadtxt``; its first three columns are read
    as time, flux and flux uncertainty (any further columns are ignored).

    Parameters
    ----------
    pattern : str, optional
        Glob pattern selecting the files to load. Defaults to
        ``"lightcurves/w_*"``.

    Returns
    -------
    dict
        Maps each file's base name to ``{"time", "flux", "error"}`` 1-D
        arrays. Files are visited in sorted path order, so the dict's
        insertion order follows that order. No match gives an empty dict.
    """
    curves = {}
    for path in sorted(glob.glob(pattern)):
        # ndmin=2 keeps each unpacked column 1-D even for a single-row file;
        # without it loadtxt squeezes that case down to three scalars.
        tt, ff, fferr = np.loadtxt(path, unpack=True, usecols=(0, 1, 2), ndmin=2)
        curves[os.path.basename(path)] = {
            "time": tt,
            "flux": ff,
            "error": fferr,
        }
    return curves

def linear_ld_model(z, rp, u):
    """
    Approximate analytic light curve using linear limb darkening.

    Small-planet approximation. While the planet is fully on the stellar
    disc, the lost flux is ``rp**2`` times the surface brightness under the
    planet's centre, I(mu) = 1 - u * (1 - mu) with mu = sqrt(1 - z**2).
    That loss is divided by the disc-integrated brightness (1 - u/3).
    Between z = 1 - rp and z = 1 + rp (ingress/egress) the loss ramps
    linearly in z. That ramp is a crude shape, not the exact overlap area.

    Parameters
    ----------
    z : float or array_like
        Projected star-planet separation, in units of the stellar radius.
    rp : float
        Planet-to-star radius ratio (must be >= 0).
    u : float
        Linear limb-darkening coefficient. Must be < 3 so that the
        disc-integrated brightness (1 - u/3) stays positive.

    Returns
    -------
    numpy.ndarray
        Relative flux (1.0 out of transit), float64, same shape as ``z``.

    Raises
    ------
    ValueError
        If ``rp`` is negative or ``u >= 3``.
    """
    # Force float so integer separations are not truncated in the result.
    z = np.asarray(z, dtype=float)
    if rp < 0:
        raise ValueError(f"rp must be non-negative, got {rp}")

    norm = 1.0 - u / 3.0  # disc-integrated brightness relative to a uniform disc
    if norm <= 0:
        raise ValueError(f"u must be < 3 so that 1 - u/3 > 0, got {u}")
    if rp == 0:
        # No planet: avoid the 0/0 in the ingress ramp below.
        return np.ones_like(z)

    # Surface brightness under the planet centre. Past the limb (z > 1),
    # clamp to mu = 0, i.e. the limb brightness 1 - u.
    mu = np.sqrt(1.0 - np.clip(z, 0.0, 1.0) ** 2)
    intensity = 1.0 - u * (1.0 - mu)

    # Fraction of the full-transit loss: 1 for z <= 1 - rp, 0 for
    # z >= 1 + rp, linear in between.
    covered = np.clip((1.0 + rp - z) / (2.0 * rp), 0.0, 1.0)

    return 1.0 - covered * rp**2 * intensity / norm

def compute_durations(rprs, a_rs, b, period):
    """Return the total (T14) and ingress/egress (T12) transit durations.

    Uses the circular-orbit expressions (Winn 2010, "Transits and
    Occultations", eqs. 14-15), with k = rprs and i = arccos(b / a_rs)::

        T14 = P/pi * arcsin(sqrt((1 + k)**2 - b**2) / (a_rs * sin i))
        T23 = P/pi * arcsin(sqrt((1 - k)**2 - b**2) / (a_rs * sin i))
        T12 = (T14 - T23) / 2

    The shortcut T12 ~ T14 * k / (1 + k) is only the b = 0, small-angle limit
    of this. It underestimates T12 noticeably for non-central transits.

    Parameters
    ----------
    rprs : float or array_like
        Planet-to-star radius ratio k (assumed < 1).
    a_rs : float or array_like
        Semi-major axis in units of the stellar radius.
    b : float or array_like
        Impact parameter in units of the stellar radius.
    period : float or array_like
        Orbital period. The durations are returned in the same units.

    Returns
    -------
    tuple
        ``(T14, T12)``. For a grazing transit (no full-transit phase)
        T23 = 0, so T12 = T14 / 2. For a non-transiting geometry
        (b >= 1 + k) both are 0.
    """
    inc = np.arccos(b / a_rs)
    chord_scale = a_rs * np.sin(inc)

    def _chord_duration(radius_sum):
        # Time spent with centre separation below radius_sum. A negative
        # radicand means the chord never exists, which gives 0 duration.
        # The arcsin argument is capped at 1 for extremely close orbits.
        half_chord = np.sqrt(np.clip(radius_sum**2 - b**2, 0.0, None))
        return (period / np.pi) * np.arcsin(np.minimum(half_chord / chord_scale, 1.0))

    T14 = _chord_duration(1 + rprs)
    T23 = _chord_duration(1 - rprs)
    T12 = 0.5 * (T14 - T23)
    return T14, T12

# Fixed orbital parameters. Times are in days: make_plot labels (t - T0) as
# "Time from mid-transit (days)" and converts it to hours with a factor of 24.
PERIOD = 4.0552841999      # orbital period, days (same units as the data's time column)
T0 = 2459771.3356462922    # reference mid-transit time, same time system as the data
A_RS = 11.3818407172       # semi-major axis in units of the stellar radius
B = 0.4481865001           # impact parameter; inclination = arccos(B / A_RS)

curves = load_lightcurves()
curve_names = list(curves.keys())
if not curve_names:
    # Fail at start-up with a clear message, rather than later inside
    # make_plot when the selector has nothing to offer.
    raise FileNotFoundError("No light curves found matching 'lightcurves/w_*'")

# Widgets
curve_select = pn.widgets.Select(name="Light curve to fit", options=curve_names)
rp_rs_slider = pn.widgets.FloatSlider(name="Rp/Rs", start=0.1, end=0.2, value=0.145, step=0.0005)
intercept_slider = pn.widgets.FloatSlider(name="Intercept (ppm)", start=-1000, end=1000, value=0.0, step=10)
slope_slider = pn.widgets.FloatSlider(name="Slope (ppm/hour)", start=-1000, end=1000, value=0.0, step=10)
# At u = 3 the disc-integrated stellar flux (1 - u/3) is zero and
# linear_ld_model cannot normalise the transit depth, so the slider stops one
# step (0.02) short of 3.
ld_coeff_slider = pn.widgets.FloatSlider(name="Linear Limb-Darkening (u)", start=-3.0, end=2.98, step=0.02, value=0.0)

def _projected_separation(t):
    """Return the sky-projected star-planet separation z(t), in stellar radii.

    Uses the module-level circular-orbit parameters ``PERIOD``, ``T0``,
    ``A_RS`` and ``B``. The projected separation is identical at orbital
    angles theta and theta + pi. Without extra handling, the far side of the
    orbit (around phase 0.5, where z equals ``B`` exactly as at mid-transit)
    would be modelled as a second transit. Those points are therefore set to
    ``np.inf``.
    """
    phase = (t - T0) / PERIOD
    angle = 2 * np.pi * phase
    inc = np.arccos(B / A_RS)
    z = A_RS * np.sqrt(np.sin(angle)**2 + (np.cos(inc))**2 * np.cos(angle)**2)
    return np.where(np.cos(angle) > 0, z, np.inf)


@pn.depends(
    curve=curve_select, rp_rs=rp_rs_slider,
    intercept=intercept_slider, slope=slope_slider,
    u=ld_coeff_slider
)
def make_plot(curve, rp_rs, intercept, slope, u):
    """Draw the selected light curve, the current model and its residuals.

    Panel re-runs this whenever one of the bound widgets changes.

    Parameters
    ----------
    curve : str
        Key into ``curves`` (a light-curve file base name).
    rp_rs : float
        Planet-to-star radius ratio passed to ``linear_ld_model``.
    intercept : float
        Additive flux offset, in ppm.
    slope : float
        Linear flux trend, in ppm per hour, measured from ``T0``.
    u : float
        Linear limb-darkening coefficient passed to ``linear_ld_model``.

    Returns
    -------
    matplotlib.figure.Figure
        Data and model on top; O-C residuals in ppm below.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), so one figure per slider move
    # would otherwise pile up for the lifetime of the server.
    fig = Figure(figsize=(5, 3))
    gs = fig.add_gridspec(4, 1, height_ratios=[3, 0.2, 0.2, 1], hspace=0.25)
    ax_main = fig.add_subplot(gs[:3, 0])
    ax_resid = fig.add_subplot(gs[3, 0], sharex=ax_main)

    data = curves[curve]
    t, f, e = data["time"], data["flux"], data["error"]
    dt = t - T0  # days from the reference mid-transit time

    model_flux = linear_ld_model(_projected_separation(t), rp_rs, u)

    # Linear baseline. The 1e-6 converts ppm to relative flux. The slope is
    # per hour while dt is in days, hence the extra factor of 24.
    baseline = intercept * 1e-6 + slope * dt * (24. / 1e6)
    model_flux = model_flux + baseline

    residuals = f - model_flux
    # Median absolute deviation of the residuals, in ppm, truncated to an int.
    mad = str(int(np.median(np.abs(residuals - np.median(residuals))) * 1e6))

    ax_main.errorbar(dt, f, yerr=e, fmt='.', alpha=0.6, label='Data', zorder=1)
    ax_main.plot(dt, model_flux, '-', label='Model', lw=4, zorder=10)
    ax_main.set_ylabel("Relative Flux")
    # NOTE(review): assumes file names end in "_<wavelength>"; the first five
    # characters of the last "_"-separated field are shown as the wavelength.
    wavelength = curve.split('_')[-1][:5]
    ax_main.set_title(f"Fitting {wavelength} um light curve")
    ax_main.legend()
    ax_main.tick_params(labelbottom=False)
    ax_main.text(0.02, 0.95,
        f"Rp/Rs = {rp_rs:.4f}\nMAD = {mad} ppm\nu = {u:.2f}",
        transform=ax_main.transAxes,
        verticalalignment='top',
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8)
    )
    ax_main.set_ylim(0.96, 1.03)

    ax_resid.axhline(0, color='gray', linestyle='--', zorder=10, lw=4)
    ax_resid.errorbar(dt, residuals * 1e6, yerr=e * 1e6, fmt='.', alpha=0.5, zorder=1)
    ax_resid.set_xlabel("Time from mid-transit (days)")
    ax_resid.set_ylabel("O-C (ppm)")
    ax_resid.set_ylim(-3500, 3500)
    # The two axes share x, so this sets the limits for both panels.
    ax_resid.set_xlim(np.min(dt), np.max(dt))

    fig.tight_layout()
    return fig

# Left-hand control panel: the widgets that make_plot depends on.
controls = pn.Column(
    curve_select,
    rp_rs_slider,
    intercept_slider,
    slope_slider,
    ld_coeff_slider,
)
# Matplotlib pane driven by make_plot; it refreshes when any bound widget changes.
plot_pane = pn.pane.Matplotlib(make_plot, tight=True)

# Controls on the left, live plot on the right, registered as the served app.
layout = pn.Row(controls, plot_pane)
layout.servable()