# Source code for jinete.storers.plots.graph

"""Graph plotting storers module, in which a set of plotting storers whose resulting artifact is a graph."""

from __future__ import (
    annotations,
)

from typing import (
    TYPE_CHECKING,
)

import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns

from ..abc import (
    Storer,
)

if TYPE_CHECKING:
    from typing import (
        Dict,
        Any,
        Tuple,
    )
    from pathlib import Path
    from ...models import Position


class GraphPlotStorer(Storer):
    """Generate a directed graph representation of the solution."""

    def __init__(self, file_path: Path = None, *args, **kwargs):
        """Construct a new object instance.

        :param file_path: The file path in which to store the problem solution. When it is ``None`` the plot is
            displayed with ``matplotlib.pyplot.show`` instead of being written to a file.
        :param args: Additional positional arguments.
        :param kwargs: Additional named arguments.
        """
        super().__init__(*args, **kwargs)
        self.file_path = file_path

    def _generate_nodes(self, edges: Dict[Tuple[Position, Position], Dict[str, Any]]) -> Dict[Position, Dict[str, Any]]:
        """Build the node metadata, keyed by position.

        Trip origins are labelled ``+<identifier>`` and trip destinations ``-<identifier>``. Any position that only
        appears as an edge endpoint gets an empty label.

        :param edges: The edge metadata, keyed by ``(origin, destination)`` position pairs.
        :return: A mapping from each position to its node metadata.
        """
        nodes: Dict[Position, Dict[str, Any]] = {}
        for trip in self._trips:
            nodes[trip.origin_position] = {
                "label": f"+{trip.identifier}",
            }
            nodes[trip.destination_position] = {
                "label": f"-{trip.identifier}",
            }

        # Edge endpoints that are not trip positions still need a node entry; never overwrite a trip label.
        for position_pair in edges:
            for position in position_pair:
                nodes.setdefault(position, {"label": ""})
        return nodes

    def _generate_edges(self) -> Dict[Tuple[Position, Position], Dict[str, Any]]:
        """Build the edge metadata, keyed by ``(origin, destination)`` position pairs.

        Every route is assigned one color of a HUSL palette, and each pair of consecutive stops of the route becomes
        an edge with that color and an empty label.

        :return: A mapping from each position pair to its edge metadata.
        """
        edges: Dict[Tuple[Position, Position], Dict[str, Any]] = {}
        palette = sns.husl_palette(len(self._routes))
        for route, color in zip(self._routes, palette):
            for first, second in zip(route.stops[:-1], route.stops[1:]):
                edges[(first.position, second.position)] = {
                    "color": color,
                    "label": "",
                }
        return edges

    def _generate_graph(self) -> nx.DiGraph:
        """Assemble the directed graph from the generated edges and nodes.

        :return: A directed graph whose edges and nodes carry the generated metadata as attributes.
        """
        graph = nx.DiGraph()

        edges = self._generate_edges()
        graph.add_edges_from(edges)
        for position_pair, metadata in edges.items():
            graph.edges[position_pair].update(metadata)

        nodes = self._generate_nodes(edges)
        graph.add_nodes_from(nodes)
        for position, metadata in nodes.items():
            graph.nodes[position].update(metadata)

        return graph

    def _show_graph(self, graph: nx.Graph) -> None:
        """Draw the graph and either save it to ``file_path`` or display it.

        The drawing is made on a dedicated 300 dpi figure, so that neither the global ``matplotlib`` configuration
        nor a previously existing figure is modified, and consecutive calls do not draw on top of each other.

        :param graph: The graph to be drawn. Nodes must expose a ``coordinates`` attribute (used as layout position)
            and carry a ``"label"`` attribute; edges must carry ``"color"`` and ``"label"`` attributes.
        """
        # A fresh figure becomes the current one, which is where the networkx drawing functions render.
        figure = plt.figure(dpi=300)

        pos = {node: node.coordinates for node in graph.nodes}
        node_labels = {node: metadata["label"] for node, metadata in graph.nodes.items()}
        edge_color = [metadata["color"] for metadata in graph.edges.values()]
        edge_labels = {edge: metadata["label"] for edge, metadata in graph.edges.items()}

        nx.draw(graph, pos=pos, edge_color=edge_color, node_size=100)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=5, font_color="white")
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)

        if self.file_path is None:
            # The figure is left open on purpose: closing it here would dismiss a non-blocking window.
            plt.show()
            return

        try:
            figure.savefig(str(self.file_path))
        finally:
            # Release the figure once written, otherwise pyplot keeps it registered indefinitely.
            plt.close(figure)

    def store(self) -> None:
        """Perform a storage process."""
        graph = self._generate_graph()
        self._show_graph(graph)