multiple_std_lines_df

multiple_std_lines_df#

Line plot with shaded confidence intervals and configurable label, color, and linestyle mappings, using a pandas DataFrame in long format.


📥 Arguments#

Name

Type

Required

Description

data_df

pd.DataFrame

Long-form DataFrame containing one row per point with columns for the run label, x values, y values, and standard error.

label_key

str

Column name in data_df that identifies each run/series.

x_key

str

Column name for x-axis values. Default: ‘round_num’.

y_key

str

Column name for y-axis values. Default: ‘unique_scores’.

yerr_key

str

Column name for standard error. Default: ‘std_error’.

figsize

tuple

Figure size. Default: (8, 5).

legend_loc

str

Legend location. Default: ‘upper left’.

label_map

Dict[str, str]

Mapping of raw labels to display names.

color_map

Dict[str, str]

Mapping of raw labels to line colors.

style_map

Dict[str, str]

Mapping of raw labels to line styles.

xlim

Tuple[float, float]

X-axis limits.

ylim

Tuple[float, float]

Y-axis limits.

xlabel

str

X-axis label.

ylabel

str

Y-axis label.

x_formatter

Callable

Formatter for x-axis ticks.

y_formatter

Callable

Formatter for y-axis ticks.

save

str

Base filename to save PNG and PDF.


📦 Example Output#

Click to show example code
import numpy as np
from matplotlib import pyplot as plt
from swizz import plot
import pandas as pd 

rounds = np.linspace(250, 2900, 30)

def fake_curve(seed, offset=0):
    np.random.seed(seed)
    base = np.linspace(-550 + offset, -400 + offset, len(rounds))
    noise = np.random.normal(0, 8, size=len(rounds))
    stderr = np.random.uniform(5, 20, size=len(rounds))
    return base + noise, stderr

# Build a long-form DataFrame
records = []
for name, (seed, offset) in [("forward-method", (0, 0)), ("reverse-method", (1, -40)), ("baseline", (2, -60))]:
    y_vals, stderr_vals = fake_curve(seed, offset)
    for x_val, y_val, err in zip(rounds, y_vals, stderr_vals):
        records.append({
            "method": name,
            "round_num": x_val,
            "unique_scores": y_val,
            "std_error": err,
        })
df = pd.DataFrame.from_records(records)

# Correct call: include plot name as first argument
fig, ax = plot(
    "multiple_std_lines_df",
    df,
    label_key="method",
    label_map={
        "forward-method": "Forward KL",
        "reverse-method": "Reverse KL",
        "baseline": "No Training",
    },
    color_map={
        "forward-method": "#CC79A7",
        "reverse-method": "#0072B2",
        "baseline": "#009E73",
    },
    style_map={
        "forward-method": "solid",
        "reverse-method": "dashed",
        "baseline": "dotted",
    },
    xlabel="Round Number",
    ylabel="Number of Unique Scores",
    xlim=(250, 2900),
    ylim=(-650, -355),
    x_formatter=lambda x, _: f"{x * 10:.0f}",
    y_formatter=lambda y, _: f"{y / 100:.1f}",
    save="ablation_df"
)
plt.show()

multiple_std_lines_df