# Source code for nemf.plot

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import networkx as nx
import yaml
import nemf
from copy import deepcopy


# plotting routines

def draw_interaction_graph(model):
    """ Takes the model configuration and draws a labeled
        directional-multi-graph to illustrate the interactions.

        Parameters
        ----------
        model : nemf model object; this function reads its `compartment`,
            `interactions` and `configuration` attributes.

        Returns
        -------
        fig : the matplotlib figure created here (via `plt.figure()`).
    """
    nodes = list(model.compartment)
    interactions = list(model.interactions)

    # work on a shallow copy so the model's own configuration is untouched
    config = model.configuration.copy()
    # NOTE(review): the body of this `if` is missing from SOURCE, so the
    # function does not parse as shown. Presumably it drops 'idx_sinks' from
    # the copied config before dumping it -- confirm against upstream.
    if 'idx_sinks' in config:

    # turns dict into yaml style string
    ## quick and dirty reformatting of the lists and tuples
    for item in config:
        if ((type(config[item]) == list) or (type(config[item]) == tuple)):
            config[item] = str(config[item])
    comment = yaml.dump(config, default_flow_style=False, line_break=True)
    comment = comment.replace('!!python/tuple','')
    # NOTE(review): `comment` is only consumed by the commented-out legend
    # at the end of this function, so it currently has no visible effect.

    # --- GRAPH ---
    # fetch list of edges and their labels
    edges = []; labels = []
    for path in interactions:
        # NOTE(review): the body of this inner loop is missing from SOURCE.
        # `edges` and `labels` are indexed in parallel below, so it presumably
        # appends one edge and one label per entry -- confirm upstream.
        for edge in model.interactions[path]:

    # setting up edge labels dict
    edge_labels = {}
    for ii,edge in enumerate(edges):
        # checks if key is already present
        if (edge in edge_labels):
            # and if so appends it
            edge_labels[edge] += '\n + {}'.format(labels[ii].replace('_','\n'))
            # or creates it
            # NOTE(review): an `else:` appears to be missing above this line.
            # As written, a new key is never created and an existing one is
            # overwritten right after being appended to.
            edge_labels[edge] = labels[ii].replace('_','\n')

    # initialise graph
    # NOTE(review): no nodes or edges are added to G in SOURCE (`nodes` and
    # `edges` are otherwise unused), so the layout/drawing below operate on an
    # empty graph -- presumably add_nodes_from/add_edges_from calls were lost.
    G = nx.MultiDiGraph()

    # define node positions
    pos = nx.circular_layout(G)

    # --- PLOTTING ---
    fig = plt.figure()
    ax = plt.subplot(111)
    # draws nodes
    # NOTE(review): the closing parenthesis of this nx.draw(...) call is
    # missing from SOURCE.
    nx.draw(G, pos,node_size=2000, node_color='pink',
           labels={node:node for node in G.nodes()},
    # draws edges
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
        label_pos=0.35, font_size=10,font_color='tab:red',rotate=False)

    # # adds configuration
    # plt.legend(title=comment,loc='center left', bbox_to_anchor=(1., 0.5))
    # # positions legends
    # ## Shrink current axis by 20%
    # box = ax.get_position()
    # ## Put a legend to the right of the current axis
    # ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    return fig

def interaction_graph(model):
    """Show a graph/network of all compartments and their interactions.

    Builds the figure with ``draw_interaction_graph`` and tightens its layout.

    Parameters
    ----------
    model : nemf model object passed straight to ``draw_interaction_graph``.
    """
    fig = draw_interaction_graph(model)
    fig.tight_layout()

def draw_cost(ax, cost):
    """Plot the cost-function history of an optimization run onto ``ax``.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    cost : sequence of cost values, one per iteration step.
    """
    ax.title.set_text('cost')
    ax.set_ylabel('Cost function (a.u.)')
    ax.set_xlabel('Iteration Step')
    # draw on the axes we were given, not on pyplot's implicit "current" axes
    ax.plot(cost)


def draw_predictions(ax, predictions, model):
    """Plot the fit-model predictions per iteration step onto ``ax``.

    Legend labels depend on ``model.configuration['fit_model']``:
    compartment names for 'direct_fit_model', a single net-flux label for
    'net_flux_fit_model', and no labels otherwise.
    """
    fit_model = model.configuration['fit_model']
    if fit_model == 'direct_fit_model':
        labels = list(model.compartment)
    elif fit_model == 'net_flux_fit_model':
        labels = ['net incoming/outgoing flux']
    else:
        labels = []

    ax.title.set_text('fit model prediction')
    handles = ax.plot(predictions)
    ax.set_ylabel('Model predictions (a.u.)')
    ax.set_xlabel('Iteration Step')
    ax.legend(handles, labels)


def draw_parameters(ax, parameters, model, colors=None):
    """Plot the optimized-parameter history onto ``ax``.

    All NaN entries of the 2-D ``parameters`` array are dropped and the
    remainder is reshaped to ``(-1, n_columns)``. This presumes NaNs only
    occur as whole rows -- TODO confirm with the optimizer's logging.
    A single remaining row is shown as one scatter point per parameter;
    otherwise each parameter column is drawn as a line over iterations.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    parameters : 2-D numpy array, one column per parameter.
    model : nemf model; legend labels come from ``model.to_grad_method()[2]``.
    colors : accepted for interface compatibility; currently unused.

    Returns
    -------
    ax
    """
    labels = model.to_grad_method()[2]
    n_cols = np.shape(parameters)[1]
    filled = parameters[~np.isnan(parameters)]
    parameters = np.reshape(filled, (filled.shape[0] // n_cols, n_cols))

    ax.title.set_text('parameters')
    if parameters.shape[0] == 1:
        # only one parameter set: one point per parameter at x = 1
        handles = [ax.scatter(1.0, parameters[:, ii]) for ii in range(n_cols)]
    else:
        handles = [ax.plot(parameters[:, ii])[0] for ii in range(n_cols)]
    ax.set_ylabel('optimized parameters (a.u.)')
    ax.set_xlabel('Iteration Step')
    # legend built from the artists drawn here only, not from every artist
    # that may already live on the axes
    ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.0, 1.0))
    return ax


def draw_model_output(ax, model, colors=None):
    """Plot every compartment of ``model.log['time_series']`` onto ``ax``.

    ``model.log['time_series']`` is indexed as a 2-D array whose first column
    is time and whose remaining columns are the compartments, in the order of
    ``model.compartment`` (used for the legend labels).

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    model : object with ``log['time_series']`` and ``compartment``.
    colors : optional sequence of colours, one per compartment. Defaults to a
        seaborn 'husl' palette. Numpy arrays are accepted.

    Returns
    -------
    ax
    """
    # `is None` rather than `== None`: the latter compares element-wise for
    # numpy arrays and then fails in the boolean context
    if colors is None:
        colors = sns.color_palette('husl', len(model.compartment))
    time_series = model.log['time_series']
    t = time_series[:, 0]
    y_t = time_series[:, 1:]
    labels = list(model.compartment)

    handles = [ax.plot(t, y_t[:, ii], color=colors[ii])[0]
               for ii in range(y_t.shape[1])]
    ax.set_ylabel('Model output (a.u.)')
    ax.set_xlabel('Time (a.u.)')
    # only this function's own lines are labelled, so reference data drawn
    # earlier on the same axes cannot steal the compartment labels
    ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1.0, 1.0))
    return ax


def draw_ref_data(ax, model, colors):
    """Overlay ``model.reference_data`` onto ``ax``.

    Column 0 of the reference data is time and the remaining columns are
    observations. Each observation column takes the colour of the
    compartment with the matching index in ``model.prep_ref_data()[1]``.
    Fewer than 20 samples are drawn as scatter points; otherwise they are
    drawn as dashed lines.

    Returns
    -------
    ax
    """
    t_ref = model.reference_data[:, 0]
    y_ref = model.reference_data[:, 1:]
    idx_refed_compart = model.prep_ref_data()[1]
    colors_ref = [color for ii, color in enumerate(colors)
                  if ii in idx_refed_compart]

    sparse = y_ref.shape[0] < 20
    for ii in range(y_ref.shape[1]):
        if sparse:
            ax.scatter(t_ref, y_ref[:, ii], color=colors_ref[ii])
        else:
            ax.plot(t_ref, y_ref[:, ii], ls='--', linewidth=2,
                    color=colors_ref[ii])
    return ax


def draw_optimization_overview(model):
    """Summarise an optimization run stored in ``model.log`` in one figure.

    Layout: if ``model.log['cost']`` holds a real (non-NaN) first value, the
    cost and parameter panels share the top row; otherwise the top row
    shows only the parameters. The bottom row always shows the optimized
    model, re-run with ``nemf.forward_model`` together with the reference
    data. A reference time stamp of ``inf`` selects the steady-state view,
    where references are drawn as horizontal lines.

    Returns
    -------
    fig : the matplotlib figure.
    """
    fig = plt.figure()
    fig.suptitle('Results of optimization run')
    non_steady_state = model.reference_data[0, 0] != np.inf
    has_cost = not np.isnan(model.log['cost'][0])

    if has_cost:
        draw_cost(plt.subplot(221), model.log['cost'])
        draw_parameters(plt.subplot(222), model.log['parameters'], model)
    else:
        draw_parameters(plt.subplot(211), model.log['parameters'], model)
    ax = plt.subplot(212)
    ax.title.set_text('optimized model')

    if non_steady_state:
        t_ref = model.reference_data[:, 0]
        colors = sns.color_palette('husl', len(model.compartment))
        # the forward run holds the freshly computed log['time_series'];
        # `model` itself is not relied upon to carry one after optimization
        fitted = nemf.forward_model(model, t_eval=t_ref)
        # model first, so its legend only covers the model lines; reference
        # data then reuses the same per-compartment colours
        draw_model_output(ax, fitted, colors=colors)
        draw_ref_data(ax, model, colors)
    else:  # steady-state
        t_max = model.configuration['max_time_evo']
        if has_cost:
            t_ref = np.linspace(0, t_max, 1000)
        else:
            t_ref = np.arange(0, t_max, model.configuration['dt_time_evo'])
        y_ref = model.reference_data[:, 1:]
        ax.hlines(y_ref, t_ref[0], t_ref[-1], ls='--')
        fitted = nemf.forward_model(model, t_eval=t_ref)
        draw_model_output(ax, fitted)
    return fig


def draw_output_summary(model):
    """Pick a visualisation based on what is stored in ``model.log``.

    If ``model.log`` already contains a 'time_series', no optimization
    happened (cost/predictions/parameters are 0-dim), so only the model
    output is plotted. Otherwise ``draw_optimization_overview`` is used.

    Returns
    -------
    fig : the matplotlib figure.
    """
    if 'time_series' not in model.log:
        return draw_optimization_overview(model)

    fig = plt.figure()
    ax = draw_model_output(plt.subplot(1, 1, 1), model)
    ax.title.set_text('Model Output')
    return fig

def output_summary(model):
    """Show the results stored in ``model``.

    Delegates the choice of visualisation to ``draw_output_summary`` and
    tightens the layout of the figure it returns.

    Parameters
    ----------
    model : nemf model object passed straight to ``draw_output_summary``.
    """
    fig = draw_output_summary(model)
    fig.tight_layout()

def initial_guess(model):
    """Plot the model behaviour for its initial parameters against the reference.

    The model is run with ``nemf.forward_model`` on the reference time stamps
    (column 0 of ``model.reference_data``). The result is drawn together with
    the reference observations (remaining columns) as dashed lines.

    Returns
    -------
    ax : the matplotlib axes holding the plot.
    """
    ax = plt.subplot(111)
    t_ref = model.reference_data[:, 0]
    y_ref = model.reference_data[:, 1:]
    initial_model = deepcopy(nemf.forward_model(model, t_eval=t_ref))
    ax = draw_model_output(ax, initial_model)
    # reference drawn after the model output so the legend created inside
    # draw_model_output labels the model lines, not the reference lines
    ax.plot(t_ref, y_ref, ls='--', linewidth=2)
    ax.title.set_text('Initial model guess')
    return ax


def coupling_matrix(d2_weights, ODE_coeff_weights, names,
                    filename='coupling_matrices.svg'):
    """Draw the d2 and ODE_coeff coupling matrices side by side and save them.

    Both matrices are shown as 'PiYG' heat maps on a fixed colour range
    [-1, 1]. Ticks are placed at 0..len(names)-1 on both axes, so each
    matrix is expected to be len(names) x len(names).

    Parameters
    ----------
    d2_weights, ODE_coeff_weights : 2-D array-likes of coupling weights.
    names : tick labels, one per row/column.
    filename : path the figure is saved to (default keeps the old
        hard-coded 'coupling_matrices.svg').
    """
    plt.figure(figsize=(12, 6))
    _draw_coupling_panel(plt.subplot(121), d2_weights, names,
                         "d2 coupling matrix")
    _draw_coupling_panel(plt.subplot(122), ODE_coeff_weights, names,
                         "ODE_coeff coupling matrix")
    plt.savefig(filename)


def _draw_coupling_panel(ax, weights, names, title):
    """Draw one coupling matrix as a labelled heat map on ``ax``."""
    # pyplot helpers below act on the current axes, so make `ax` current
    plt.sca(ax)
    ax.set_title(title)
    plt.imshow(weights, cmap='PiYG', vmin=-1, vmax=1)
    ticks = np.arange(len(names))
    plt.xticks(ticks, names, rotation=30)
    plt.yticks(ticks, names)
    ax.xaxis.tick_top()