Skip to content

Commit 7ce1a9a

Browse files
committed
Add mixed-edge-graph example
Signed-off-by: Adam Li <[email protected]>
1 parent e0f9533 commit 7ce1a9a

File tree

4 files changed

+196
-0
lines changed

4 files changed

+196
-0
lines changed

docs/api.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,22 @@ and corresponding causal graphs in pywhy-graphs.
6363
graph_to_arr
6464
clearn_arr_to_graph
6565

66+
NetworkX Experimental Functionality
67+
===================================
68+
Currently, NetworkX does not support mixed-edge graphs, which are crucial
69+
for representing causality with latent confounders and selection bias. The
70+
following represent functionality that we intend to PR eventually into
71+
networkx. They are included in pywhy-graphs as a temporary bridge. We
72+
welcome feedback.
73+
74+
.. currentmodule:: pywhy_graphs.networkx
75+
.. autosummary::
76+
:toctree: generated/
77+
78+
MixedEdgeGraph
79+
bidirected_to_unobserved_confounder
80+
m_separated
81+
6682
Visualization of causal graphs
6783
==============================
6884
Visualization of causal graphs is different compared to networkx because causal graphs

examples/mixededge/README.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Examples
2+
--------
3+
4+
Examples demonstrating how to use a MixedEdgeGraph (Note this is a WIP API with the intention of adding into networkx).
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
====================================================
3+
MixedEdgeGraph - Graph with different types of edges
4+
====================================================
5+
6+
A ``MixedEdgeGraph`` is a graph comprised of a tuple, :math:`G = (V, E)`.
7+
The difference compared to the other networkx graphs are the edges, E.
8+
``E`` is comprised of a set of mixed edges defined by the user. This
9+
allows arbitrary representation of graphs with different types of edges.
10+
The ``MixedEdgeGraph`` class represents each type of edge using an internal
11+
graph that is one of ``nx.Graph`` or ``nx.DiGraph`` classes. Each internal graph
12+
represents one type of edge.
13+
14+
Semantically a ``MixedEdgeGraph`` with just one type of edge, is just a normal
15+
``nx.Graph` or ``nx.DiGraph`` and should be converted to its appropriate
16+
networkx class.
17+
18+
For example, causal graphs typically have two types of edges:
19+
20+
- ``->`` directed edges representing causal relations
21+
- ``<->`` bidirected edges representing the presence of an unobserved
22+
confounder.
23+
24+
This would type of mixed-edge graph with two internal graphs: a ``nx.DiGraph``
25+
to represent the directed edges, and a ``nx.Graph`` to represent the bidirected
26+
edges.
27+
"""
28+
29+
import matplotlib.pyplot as plt
30+
import networkx as nx
31+
32+
import pywhy_graphs as pg
33+
34+
# %%
35+
# Construct a MixedEdgeGraph
36+
# --------------------------
37+
# Using the ``MixedEdgeGraph``, we can represent a causal graph
38+
# with two different kinds of edges. To create the graph, we
39+
# use networkx ``nx.DiGraph`` class to represent directed edges,
40+
# and ``nx.Graph`` class to represent edges without directions (i.e.
41+
# bidirected edges). The edge types are then specified, so the mixed edge
42+
# graph object knows which graphs are associated with which types of edges.
43+
44+
directed_G = nx.DiGraph(
45+
[
46+
("X", "Y"),
47+
("Z", "X"),
48+
]
49+
)
50+
bidirected_G = nx.Graph(
51+
[
52+
("X", "Y"),
53+
]
54+
)
55+
directed_G.add_nodes_from(bidirected_G.nodes)
56+
bidirected_G.add_nodes_from(directed_G.nodes)
57+
G = pg.networkx.MixedEdgeGraph(
58+
graphs=[directed_G, bidirected_G],
59+
edge_types=["directed", "bidirected"],
60+
name="IV Graph",
61+
)
62+
63+
# Compute the multipartite_layout using the "layer" node attribute
64+
pos = nx.spring_layout(G)
65+
66+
# we can then visualize the mixed-edge graph
67+
fig, ax = plt.subplots()
68+
nx.draw_networkx(G.get_graphs(edge_type="directed"), pos=pos, ax=ax)
69+
nx.draw_networkx(G.get_graphs(edge_type="bidirected"), pos=pos, ax=ax)
70+
ax.set_title("Instrumental Variable Mixed Edge Causal Graph")
71+
fig.tight_layout()
72+
plt.show(block=False)
73+
74+
# %%
75+
# Mixed Edge Graph Properties
76+
# ---------------------------
77+
78+
print(G.name)
79+
80+
# G is directed since there are directed edges
81+
print(f"{G} is directed: {G.is_directed()} because there are directed edges.")
82+
83+
# MixedEdgeGraphs are not multigraphs
84+
print(G.is_multigraph())
85+
86+
# the different edge types present in the graph
87+
print(G.edge_types)
88+
89+
# the internal networkx graphs representing each edge type
90+
print(G.get_graphs())
91+
92+
# we can specifically get the networkx graph representation
93+
# of any edge, e.g. the bidirected edges
94+
bidirected_edges = G.get_graphs("bidirected")
95+
96+
# %%
97+
# Mixed Edge Graph Operations on Nodes
98+
# ------------------------------------
99+
100+
# Nodes: Similar to ``nx.Graph`` and ``nx.DiGraph``, the nodes of the graph
101+
# can be queried via the same API. By default nodes are stored
102+
# inside every internal graph.
103+
nodes = G.nodes
104+
assert G.order() == len(G)
105+
assert len(G) == G.number_of_nodes()
106+
print(f"{G} has nodes: {nodes}")
107+
108+
# If we add a node, then we can query if the new node is there
109+
print(f"Graph has node A: {G.has_node('A')}")
110+
G.add_node("A")
111+
print(f"Now graph has node A: {G.has_node('A')}")
112+
113+
# Now, we can remove the node
114+
G.remove_node("A")
115+
print(f"Graph has node A: {G.has_node('A')}")
116+
117+
# %%
118+
# Mixed Edge Graph Operations on Edges
119+
# ------------------------------------
120+
# Mixed edge graphs are just like normal networkx graph classes,
121+
# except that they store an internal networkx graph per edge type.
122+
# As a result, each edge now corresponds to an 'edge_type', which
123+
# typically must be specified in edge operations for mixed edge graphs.
124+
125+
# Edges: We can query specific edges by type
126+
print(f"The graph has directed edges: {G.edges()['directed']}")
127+
128+
# Note these edges correspond to the edges of the internal networkx
129+
# DiGraph that represents the directed edges
130+
print(G.get_graphs("directed").edges())
131+
132+
# When querying, adding, or removing an edge, you must specify
133+
# the edge type as well.
134+
# Here, we can add a new Z <-> Y bidirected edge.
135+
assert G.has_edge("X", "Y", edge_type="directed")
136+
G.add_edge("Z", "Y", edge_type="bidirected")
137+
assert not G.has_edge("Z", "Y", edge_type="directed")
138+
139+
# Now, we can remove the Z <-> Y bidirected edge.
140+
G.remove_edge("Z", "Y", edge_type="bidirected")
141+
assert not G.has_edge("Z", "Y", edge_type="bidirected")
142+
143+
# %%
144+
# Mixed Edge Graph Key Differences
145+
# --------------------------------
146+
# Mixed edge graphs implement the standard networkx API, but the
147+
# ``adj``, ``edges``, and ``degree`` are functions instead of
148+
# class properties. Moreover, one can specify the edge type.
149+
150+
# Neighbors: Compared to its uni-edge networkx counterparts, a mixed-edge
151+
# graph has many edge types. We define neighbors as any node with a connection.
152+
# This is similar to `nx.Graph` where neighbors are any adjacent neighbors.
153+
assert "Z" in G.neighbors("X")
154+
155+
# Similar to the networkx API, the ``adj`` provides a way to iterate
156+
# through the nodes and edges, but now over different edge types.
157+
for edge_type, adj in G.adj.items():
158+
print(edge_type)
159+
print(adj)
160+
161+
# If you only want the adjacencies of the directed edges, you can
162+
# query the returned dictionary of adjacencies.
163+
print(G.adj["directed"])
164+
165+
# Similar to the networkx API, the ``edges`` provides a way to iterate
166+
# through the edges, but now over different edge types.
167+
for edge_type, edges in G.edges().items():
168+
print(edge_type)
169+
print(edges)
170+
171+
# Similar to the networkx API, the ``edges`` provides a way to iterate
172+
# through the edges, but now over different edge types.
173+
for node, degrees in G.degree().items():
174+
print(f"{node} with degree: {degrees}")

pywhy_graphs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
from .classes import ADMG, CPDAG, PAG
33
from .algorithms import * # noqa: F403
44
from .array import export
5+
6+
from . import networkx

0 commit comments

Comments
 (0)