import numpy as np
from matplotlib import pyplot as plt
import math
import copy
import os
import seaborn as sns
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap, Normalize
[docs]
def plot_qwak(
        x_value_matrix,
        y_value_matrix,
        x_label=None,
        y_label=None,
        plot_title=None,
        legend_labels=None,
        legend_title=None,
        legend_ncol=3,
        legend_loc='best',
        save_path=None,
        font_size=16,
        figsize=(12, 8),
        color_list=None,
        line_style_list=None,
        use_loglog=False,
        use_cbar=False,
        cbar_label=None,
        cbar_ticks_generator=None,
        cbar_num_ticks=3,
        cbar_tick_labels=('a', 'b', 'c'),
        x_num_ticks=None,
        y_num_ticks=None,
        x_round_val=3,
        y_round_val=3,
        v_line_values=None,
        v_line_style='--',
        v_line_list_index=None,
        v_line_colors=None,
        x_lim=None,
        use_grid=False,
        **kwargs):
    """
    Plot one or more lines from a matrix of x values and a matrix of y values.

    All options are keyword arguments, so a dictionary of options can be
    passed with ``plot_qwak(x, y, **options)``. The options ``marker_list``,
    ``marker_interval`` and ``cbar_ticks`` (an alias of
    ``cbar_ticks_generator``) have no named parameter and are only read from
    ``**kwargs``.

    Parameters
    ----------
    x_value_matrix : array_like
        A matrix of x values. Each row is a set of x values for a single line.
        A flat sequence of numbers is treated as a single line. Rows may have
        different lengths.
    y_value_matrix : array_like
        A matrix of y values. Each row is a set of y values for a single line.
        A flat sequence of numbers is treated as a single line. Rows may have
        different lengths.
    x_label : str, optional
        The label for the x-axis. The default is None.
    y_label : str, optional
        The label for the y-axis. The default is None.
    plot_title : str, optional
        The title of the plot. The default is None.
    legend_labels : list, optional
        One legend label per line. The legend is only drawn when this is
        given. The default is None.
    legend_title : str, optional
        The title of the legend. The default is None.
    legend_ncol : int, optional
        The number of columns in the legend. The default is 3.
    legend_loc : str, optional
        The location of the legend. The default is 'best'.
    save_path : str, optional
        If given, the figure is saved to this path before being shown.
        The default is None.
    font_size : int, optional
        Base font size. The legend uses it directly. Tick labels and the
        colorbar use ``font_size + 2``, axis labels ``font_size + 4`` and the
        title ``font_size + 6``. The default is 16.
    figsize : tuple, optional
        The size of the figure. The default is (12, 8).
    color_list : list, optional
        One colour per line. It is also the list of colours that makes up the
        colorbar's colormap. The default is None.
    line_style_list : list, optional
        One line style per line. The default is None.
    use_loglog : bool, optional
        Whether to use log scales on both axes. The default is False.
    use_cbar : bool, optional
        Whether to add a colorbar built from ``color_list``. It requires
        ``color_list``. The default is False.
    cbar_label : str, optional
        The label for the colorbar. The default is None.
    cbar_ticks_generator : sequence, optional
        Tick values for the colorbar. They are used as given; the value is not
        called. When None, ``cbar_num_ticks`` evenly spaced values between the
        global minimum and maximum y value are used. The default is None.
    cbar_num_ticks : int, optional
        The number of colorbar ticks when ``cbar_ticks_generator`` is None.
        The default is 3.
    cbar_tick_labels : sequence of str, optional
        Labels for the colorbar ticks. Their count must match the number of
        ticks. The default is ('a', 'b', 'c').
    x_num_ticks : int, optional
        If given, place this many evenly spaced x ticks between the global
        minimum and maximum x value. The count is capped at the length of the
        first row. The default is None, which keeps matplotlib's ticks.
    y_num_ticks : int, optional
        Same as ``x_num_ticks``, for the y-axis. The default is None.
    x_round_val : int, optional
        The number of decimal places to round the x-axis ticks to.
        The default is 3.
    y_round_val : int, optional
        The number of decimal places to round the y-axis ticks to.
        The default is 3.
    v_line_values : list of (float, float), optional
        ``(x, height)`` pairs. Each vertical line is drawn at ``x``, from the
        bottom of the axes up to the data value ``height``. The default is
        None.
    v_line_style : str, optional
        The style of the vertical lines. The default is '--'.
    v_line_list_index : int, optional
        If given, a single vertical line, ``v_line_values[0]`` coloured
        ``v_line_colors[0]``, is drawn for the line at this index. If None,
        ``v_line_values[i]`` is drawn for line ``i`` in that line's colour.
        The default is None.
    v_line_colors : list, optional
        Colours used when ``v_line_list_index`` is given. The default is None,
        which draws in black.
    x_lim : tuple, optional
        The limits of the x-axis. The default is None.
    use_grid : bool, optional
        Whether to use a grid. The default is False.
    marker_list : list, optional
        (``**kwargs`` only) One marker per line. The default is None.
    marker_interval : int, optional
        (``**kwargs`` only) Passed as ``markevery`` to every line.
        The default is None.

    Raises
    ------
    ValueError
        If ``v_line_list_index`` is given without ``v_line_values``, or if
        ``use_cbar`` is True but ``color_list`` is None.
    """
    # Named parameters always bind to their own argument, so only the options
    # without a named parameter can actually turn up in **kwargs.
    cbar_ticks_generator = kwargs.get('cbar_ticks', cbar_ticks_generator)
    marker_list = kwargs.get('marker_list', None)
    marker_interval = kwargs.get('marker_interval', None)

    # A flat sequence of scalars describes a single line: wrap it as one row.
    if np.ndim(x_value_matrix[0]) == 0:
        x_value_matrix = [x_value_matrix]
    if np.ndim(y_value_matrix[0]) == 0:
        y_value_matrix = [y_value_matrix]

    if v_line_list_index is not None and v_line_values is None:
        raise ValueError(
            "v_line_list_index is provided, but v_line_values is None.")
    if use_cbar and color_list is None:
        raise ValueError(
            "use_cbar is True, but color_list is None; the colorbar is "
            "built from color_list.")

    fig, ax = plt.subplots(figsize=figsize)
    if use_loglog:
        ax.loglog()
    if use_grid:
        ax.grid(True)

    # (x, height, colour) of every vertical line. They are drawn only after
    # all data lines exist, so each height is measured against the final
    # y-limits rather than the limits at the moment its data line was added.
    pending_v_lines = []
    for index, (xvalues, yvalues) in enumerate(
            zip(x_value_matrix, y_value_matrix)):
        color = None if color_list is None else color_list[index]
        line_style = (None if line_style_list is None
                      else line_style_list[index])
        label = None if legend_labels is None else legend_labels[index]
        marker = None if marker_list is None else marker_list[index]
        ax.plot(xvalues, yvalues, label=label, color=color,
                linestyle=line_style, marker=marker,
                markevery=marker_interval)

        if v_line_values is None:
            continue
        if v_line_list_index is None:
            # One vertical line per data line, drawn in that line's colour.
            pending_v_lines.append(
                (v_line_values[index][0], v_line_values[index][1], color))
        elif index == v_line_list_index:
            # NOTE(review): only the first entry of v_line_values is drawn in
            # this mode (unchanged historical behaviour); confirm whether
            # every entry was meant to be drawn on this line.
            v_color = 'k' if v_line_colors is None else v_line_colors[0]
            pending_v_lines.append(
                (v_line_values[0][0], v_line_values[0][1], v_color))

    # set the axis labels and title
    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=font_size + 4)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=font_size + 4)
    if plot_title is not None:
        ax.set_title(plot_title, fontsize=font_size + 6)

    # set the legend
    if legend_labels is not None:
        legend = ax.legend(
            loc=legend_loc,
            ncol=legend_ncol,
            fontsize=font_size)
        if legend_title is not None:
            legend.set_title(legend_title, prop={'size': font_size + 2})

    ax.tick_params(axis='both', labelsize=font_size + 2)

    min_x, max_x = _global_extent(x_value_matrix)
    min_y, max_y = _global_extent(y_value_matrix)

    # evenly spaced, rounded tick labels between the global extremes
    if x_num_ticks is not None:
        num_x_ticks = min(x_num_ticks, len(x_value_matrix[0]))
        x_ticks = np.linspace(min_x, max_x, num_x_ticks)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(np.round(x_ticks, x_round_val),
                           fontsize=font_size + 2)
    if y_num_ticks is not None:
        num_y_ticks = min(y_num_ticks, len(y_value_matrix[0]))
        y_ticks = np.linspace(min_y, max_y, num_y_ticks)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(np.round(y_ticks, y_round_val),
                           fontsize=font_size + 2)

    # add colorbar
    if use_cbar:
        if cbar_ticks_generator is None:
            cbar_ticks = np.linspace(min_y, max_y, cbar_num_ticks)
        else:
            cbar_ticks = cbar_ticks_generator
        cmap = LinearSegmentedColormap.from_list(
            'custom_cmap', color_list, N=len(color_list))
        norm = Normalize(vmin=cbar_ticks[0], vmax=cbar_ticks[-1])
        cbar = plt.colorbar(
            plt.cm.ScalarMappable(norm=norm, cmap=cmap),
            ax=ax,
            ticks=cbar_ticks)
        if cbar_label is not None:
            cbar.set_label(cbar_label, fontsize=font_size + 2)
        cbar.ax.tick_params(labelsize=font_size + 2)
        if cbar_tick_labels is not None:
            cbar.ax.set_yticklabels(cbar_tick_labels, fontsize=font_size + 2)

    if x_lim is not None:
        ax.set_xlim(x_lim)

    # axvline takes ymin/ymax as fractions of the axes height, so convert each
    # data-space height using the final limits and the axis scale.
    if pending_v_lines:
        y_low, y_high = ax.get_ylim()
        log_y = ax.get_yscale() == 'log'
        for x_pos, height, v_color in pending_v_lines:
            ax.axvline(
                x=x_pos,
                ymin=0,
                ymax=_y_to_axes_fraction(height, y_low, y_high, log_y),
                color=v_color,
                linestyle=v_line_style,
                linewidth=1.5)

    # save (optionally) and show the plot
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()


def _global_extent(rows):
    """Return ``(min, max)`` over every value in ``rows``; rows may differ in length."""
    return (np.min([np.min(row) for row in rows]),
            np.max([np.max(row) for row in rows]))


def _y_to_axes_fraction(value, low, high, log_scale):
    """Return where ``value`` lies between ``low`` (0.0) and ``high`` (1.0).

    On a log axis the interpolation is done in log10 space, matching how the
    axis is drawn.
    """
    if log_scale:
        value, low, high = np.log10(value), np.log10(low), np.log10(high)
    return float((value - low) / (high - low))
[docs]
def plot_qwak_heatmap(
        p_values,
        t_values,
        prob_values,
        x_num_ticks=5,
        y_num_ticks=5,
        x_round_val=3,
        y_round_val=3,
        filepath=None,
        N=None,
        xlabel=None,
        ylabel=None,
        cbar_label=None,
        font_size=12,
        figsize=(8, 6),
        cmap='coolwarm',
        x_vline_value=None,
        y_hline_value=None,
        title_font_size=16,
        xlabel_font_size=14,
        ylabel_font_size=14,
        tick_font_size=12,
        x_tick_scale=100,
        **kwargs):
    """
    Plot a heatmap of probability values over a grid of (p, t) values.

    ``p_values``, ``t_values`` and ``prob_values`` are flattened row by row.
    Element ``[i][j]`` of the three inputs forms one (p, t, probability)
    sample. The samples are pivoted into a grid whose columns are the sorted
    unique p values and whose rows are the sorted unique t values, with the
    smallest t drawn at the bottom. The colour scale is fixed to [0, 1].

    Parameters
    ----------
    p_values : list of lists
        List of lists of p values.
    t_values : list of lists
        List of lists of t values, with the same structure as ``p_values``.
    prob_values : list of lists
        List of lists of probability values, with the same structure as
        ``p_values``.
    x_num_ticks : int, optional
        Number of x-axis ticks, capped at the number of distinct p values.
        The default is 5.
    y_num_ticks : int, optional
        Number of y-axis ticks, capped at the number of distinct t values.
        The default is 5.
    x_round_val : int, optional
        Number of decimal places to round x-axis tick labels. The default is 3.
    y_round_val : int, optional
        Number of decimal places to round y-axis tick labels. The default is 3.
    filepath : str, optional
        If given, the figure is saved to this path before being shown.
        The default is None.
    N : int, optional
        If given, the plot is titled ``'N=<N>'``. The default is None.
    xlabel : str, optional
        Label for the x-axis. The default is None.
    ylabel : str, optional
        Label for the y-axis. The default is None.
    cbar_label : str, optional
        Label for the colorbar. The default is None.
    font_size : int, optional
        The colorbar label is drawn at ``font_size + 2``. The default is 12.
    figsize : tuple, optional
        Size of the plot. The default is (8, 6).
    cmap : str, optional
        Colormap for the plot. The default is 'coolwarm'.
    x_vline_value : float, optional
        Unscaled p value at which to draw a vertical line. The default is None.
    y_hline_value : float, optional
        t value at which to draw a horizontal line. The default is None.
    title_font_size : int, optional
        Font size for the title. The default is 16.
    xlabel_font_size : int, optional
        Font size for the x-axis label. The default is 14.
    ylabel_font_size : int, optional
        Font size for the y-axis label. The default is 14.
    tick_font_size : int, optional
        Font size for the tick labels and colorbar ticks. The default is 12.
    x_tick_scale : float, optional
        Factor applied to p values when writing the x tick labels. Tick
        positions and ``x_vline_value`` use unscaled p. The default is 100.
    **kwargs
        Accepted for backward compatibility; unrecognised keys are ignored.

    Raises
    ------
    ValueError
        Raised by ``pandas.DataFrame.pivot`` if a (p, t) pair occurs more
        than once.

    Notes
    -----
    Tick labels are evenly spaced between the smallest and largest value.
    ``x_vline_value`` and ``y_hline_value`` are placed by linear interpolation
    between the first and last cell centres. Both are therefore exact only
    for evenly spaced p and t grids.
    """
    flat_p = [item for sublist in p_values for item in sublist]
    flat_t = [item for sublist in t_values for item in sublist]
    flat_prob = [item for sublist in prob_values for item in sublist]
    df = pd.DataFrame({'p': flat_p, 't': flat_t, 'prob': flat_prob})
    pivot = df.pivot(index='t', columns='p', values='prob')
    # Use the real grid size: the length of the first input row need not
    # equal the number of distinct p (or t) values.
    n_t_cells, n_p_cells = pivot.shape

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(pivot, cmap=cmap, vmin=0, vmax=1,
                cbar_kws={'label': cbar_label}, linewidths=0, ax=ax)
    # put the smallest t at the bottom
    ax.invert_yaxis()
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=tick_font_size)
    cbar.ax.set_ylabel(cbar_label, fontsize=font_size + 2)

    # evenly spaced tick values, positioned over the matching cell centres
    p_min, p_max = min(flat_p), max(flat_p)
    t_min, t_max = min(flat_t), max(flat_t)
    p_ticks = np.linspace(p_min, p_max, min(x_num_ticks, n_p_cells))
    t_ticks = np.linspace(t_min, t_max, min(y_num_ticks, n_t_cells))
    ax.set_xticks(_heatmap_axis_position(p_ticks, p_min, p_max, n_p_cells))
    ax.set_yticks(_heatmap_axis_position(t_ticks, t_min, t_max, n_t_cells))
    ax.set_xticklabels(np.round(x_tick_scale * p_ticks, x_round_val),
                       rotation=0)
    ax.set_yticklabels(np.round(t_ticks, y_round_val))

    if x_vline_value is not None:
        x_pos = float(_heatmap_axis_position(
            x_vline_value, p_min, p_max, n_p_cells))
        ax.axvline(x=x_pos, linestyle=':', marker='D', markersize=5,
                   linewidth=1, color='white')
    if y_hline_value is not None:
        y_pos = float(_heatmap_axis_position(
            y_hline_value, t_min, t_max, n_t_cells))
        ax.axhline(y=y_pos, linestyle=':', marker='D', markersize=5,
                   linewidth=1, color='white')

    # add title to the plot if N is provided
    if N is not None:
        ax.set_title(f'N={N}', fontsize=title_font_size)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=xlabel_font_size)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=ylabel_font_size)

    if filepath is not None:
        plt.savefig(filepath, bbox_inches='tight')
    plt.show()


def _heatmap_axis_position(values, v_min, v_max, n_cells):
    """Map data values to heatmap axis coordinates, where cell ``i`` is centred at ``i + 0.5``.

    ``v_min`` maps to the centre of the first cell and ``v_max`` to the centre
    of the last one. A degenerate range (a single cell) maps to that cell's
    centre.
    """
    values = np.asarray(values, dtype=float)
    span = v_max - v_min
    if span == 0:
        return np.full(values.shape, 0.5 * n_cells)
    return 0.5 + (values - v_min) / span * (n_cells - 1)
[docs]
def searchProbStepsPlotting():
    """Placeholder with no implementation yet.

    Takes no arguments, performs no work and returns None.
    """
    return None