Skip to content

Commit 0731442

Browse files
authored
Merge pull request #1 from eriknw/second_commit
First PR
2 parents 59aff1f + 6e56e9f commit 0731442

19 files changed

+993
-2
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ coverage.xml
5050
.hypothesis/
5151
.pytest_cache/
5252
cover/
53+
.ruff_cache/
5354

5455
# Translations
5556
*.mo
@@ -82,6 +83,9 @@ target/
8283
profile_default/
8384
ipython_config.py
8485

86+
# Vim's swap files
87+
*.sw[op]
88+
8589
# pyenv
8690
# For a library or package, you might want to ignore these files since the code is
8791
# intended to run in multiple environments; otherwise, check them in:

.pre-commit-config.yaml

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# https://pre-commit.com/
2+
#
3+
# Before first use: `pre-commit install`
4+
# To run manually: `pre-commit run --all-files`
5+
# To update: `pre-commit autoupdate`
6+
fail_fast: false
7+
default_language_version:
8+
python: python3
9+
repos:
10+
- repo: https://github.com/pre-commit/pre-commit-hooks
11+
rev: v4.6.0
12+
hooks:
13+
- id: check-added-large-files
14+
- id: check-case-conflict
15+
- id: check-merge-conflict
16+
- id: check-symlinks
17+
- id: check-ast
18+
- id: check-toml
19+
- id: check-yaml
20+
- id: debug-statements
21+
- id: end-of-file-fixer
22+
exclude_types: [svg]
23+
- id: mixed-line-ending
24+
- id: trailing-whitespace
25+
- repo: https://github.com/asottile/pyupgrade
26+
rev: v3.16.0
27+
hooks:
28+
- id: pyupgrade
29+
args: [--py310-plus]
30+
- repo: https://github.com/psf/black
31+
rev: 24.4.2
32+
hooks:
33+
- id: black
34+
- repo: https://github.com/astral-sh/ruff-pre-commit
35+
rev: v0.5.1
36+
hooks:
37+
- id: ruff
38+
args:
39+
- --fix
40+
# - id: ruff-format # Prefer black for now
41+
- repo: https://github.com/abravalheri/validate-pyproject
42+
rev: v0.18
43+
hooks:
44+
- id: validate-pyproject
45+
name: Validate pyproject.toml
46+
- repo: https://github.com/pre-commit/pre-commit-hooks
47+
rev: v4.6.0
48+
hooks:
49+
- id: no-commit-to-branch # no commit directly to main

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
BSD 3-Clause License
22

3-
Copyright (c) 2024, NetworkX
3+
Copyright (c) 2024, NetworkX Developers, NVIDIA CORPORATION, and nx-pandas contributors
44

55
Redistribution and use in source and binary forms, with or without
66
modification, are permitted provided that the following conditions are met:

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# nx-pandas
1+
# nx-pandas

nx_pandas/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import importlib.metadata
2+
3+
# This package *must* be installed even for local development,
4+
# so checking version like this lets us be strict and informative.
5+
try:
6+
__version__ = importlib.metadata.version("nx-pandas")
7+
except Exception as exc:
8+
raise AttributeError(
9+
"`nx_pandas.__version__` not available. This may mean "
10+
"nx-pandas was incorrectly installed or not installed at all. "
11+
"For local development, you may want to do an editable install via "
12+
"`python -m pip install -e path/to/nx-pandas`"
13+
) from exc
14+
del importlib

nx_pandas/_patch.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import pandas as pd
2+
3+
4+
# https://pandas.pydata.org/docs/development/extending.html#registering-custom-accessors
5+
@pd.api.extensions.register_dataframe_accessor("nx")
6+
class NxAccessor:
7+
def __init__(self, pandas_obj):
8+
self._df = pandas_obj
9+
self.is_directed = True
10+
self.is_multigraph = False
11+
self._source = "source" if "source" in pandas_obj.columns else None
12+
self._target = "target" if "target" in pandas_obj.columns else None
13+
self._edge_key = "edge_key" if "edge_key" in pandas_obj.columns else None
14+
self.node_df = None
15+
self.graph = {} # `df.nx.graph` instead of `df.graph`
16+
self._cache = None
17+
18+
@property
19+
def source(self):
20+
if self._source is not None and self._source not in self._df.columns:
21+
# Should we raise here to ensure consistency or let users break themselves?
22+
raise KeyError(
23+
f"DataFrame does not have column {self._source!r}. "
24+
"`df.nx.source` must be set to an existing column name "
25+
"for the DataFrame to be used as a networkx graph."
26+
)
27+
return self._source
28+
29+
@source.setter
30+
def source(self, val):
31+
if val is not None and val not in self._df.columns:
32+
raise KeyError(
33+
f"DataFrame does not have column {val!r}. "
34+
"`df.nx.source` must be set to an existing column name "
35+
"for the DataFrame to be used as a networkx graph."
36+
)
37+
self._source = val
38+
39+
@property
40+
def target(self):
41+
if self._target is not None and self._target not in self._df.columns:
42+
raise KeyError(
43+
f"DataFrame does not have column {self._target!r}. "
44+
"`df.nx.target` must be set to an existing column name "
45+
"for the DataFrame to be used as a networkx graph."
46+
)
47+
return self._target
48+
49+
@target.setter
50+
def target(self, val):
51+
if val is not None and val not in self._df.columns:
52+
raise KeyError(
53+
f"DataFrame does not have column {val!r}. "
54+
"`df.nx.target` must be set to an existing column name "
55+
"for the DataFrame to be used as a networkx graph."
56+
)
57+
self._target = val
58+
59+
@property
60+
def edge_key(self):
61+
if not self.is_multigraph:
62+
raise AttributeError("'edge_key' attribute only exists for multigraphs")
63+
if self._edge_key is not None and self._edge_key not in self._df.columns:
64+
raise KeyError(
65+
f"DataFrame does not have column {self._edge_key!r}. "
66+
"`df.nx.edge_key` must be set to an existing column name or None "
67+
"for the DataFrame to be used as a networkx multi-graph."
68+
)
69+
return self._edge_key
70+
71+
@edge_key.setter
72+
def edge_key(self, val):
73+
if not self.is_multigraph:
74+
raise AttributeError("'edge_key' attribute only exists for multigraphs")
75+
if val is not None and val not in self._df.columns:
76+
raise KeyError(
77+
f"DataFrame does not have column {val!r}. "
78+
"`df.nx.edge_key` must be set to an existing column name or None "
79+
"for the DataFrame to be used as a networkx multi-graph."
80+
)
81+
self._edge_key = val
82+
83+
@property
84+
def cache_enabled(self):
85+
return self._cache is not None
86+
87+
@cache_enabled.setter
88+
def cache_enabled(self, val):
89+
if not val:
90+
# Wipe out the cache when disabling the cache
91+
self._cache = None
92+
elif self._cache is None:
93+
# Enable cache if necessary
94+
self._cache = {}
95+
96+
def __dir__(self):
97+
attrs = super().__dir__()
98+
if not self.is_multigraph:
99+
attrs.remove("edge_key")
100+
return attrs
101+
102+
def set_properties(
103+
self,
104+
*,
105+
source=None,
106+
target=None,
107+
edge_key=None,
108+
is_directed=None,
109+
is_multigraph=None,
110+
cache_enabled=None,
111+
):
112+
"""Set many graph properties (i.e., ``df.nx`` attributes) at once.
113+
114+
Return the original DataFrame to allow method chaining. For example::
115+
116+
>>> df = pd.read_csv("my_data.csv").nx.set_properties(is_directed=False)
117+
118+
This is a bulk transaction, so either all given attributes will be updated,
119+
or nothing will be set if there was an exception.
120+
"""
121+
prev = {}
122+
cur = {}
123+
if source is not None:
124+
prev["_source"] = self._source
125+
cur["source"] = source
126+
if target is not None:
127+
prev["_target"] = self._target
128+
cur["target"] = target
129+
if is_directed is not None:
130+
prev["is_directed"] = self.is_directed
131+
cur["is_directed"] = is_directed
132+
if is_multigraph is not None:
133+
prev["is_multigraph"] = self.is_multigraph
134+
cur["is_multigraph"] = is_multigraph
135+
if edge_key is not None:
136+
prev["_edge_key"] = self._edge_key
137+
cur["edge_key"] = edge_key
138+
if cache_enabled is not None:
139+
prev["cache_enabled"] = self.cache_enabled
140+
cur["cache_enabled"] = cache_enabled
141+
try:
142+
for attr, val in cur.items():
143+
setattr(self, attr, val)
144+
except Exception:
145+
for attr, val in prev.items():
146+
setattr(self, attr, val)
147+
raise
148+
return self._df
149+
150+
151+
def _attr_raise_if_invalid_graph(df, attr):
152+
try:
153+
df.nx.source
154+
df.nx.target
155+
if df.nx.is_multigraph:
156+
df.nx.edge_key
157+
except KeyError as exc:
158+
raise AttributeError(
159+
f"{type(df).__name__!r} object has no attribute '{attr}'"
160+
) from exc
161+
if df.nx._source is None:
162+
raise AttributeError(
163+
f"{type(df).__name__!r} object has no attribute '{attr}'.\n\n"
164+
"`df.nx.source` (currently None) must be set to an existing "
165+
"column name for the DataFrame to be used as a networkx graph."
166+
)
167+
if df.nx._target is None:
168+
raise AttributeError(
169+
f"{type(df).__name__!r} object has no attribute '{attr}'.\n\n"
170+
"`df.nx.target` (currently None) must be set to an existing "
171+
"column name for the DataFrame to be used as a networkx graph."
172+
)
173+
174+
175+
def __networkx_backend__(self):
176+
# `df.__networkx_backend__` only available if `df` is a valid graph
177+
_attr_raise_if_invalid_graph(self, "__networkx_backend__")
178+
return "pandas"
179+
180+
181+
def __networkx_cache__(self):
182+
# `df.__networkx_cache__` only available if `df` is a valid graph
183+
_attr_raise_if_invalid_graph(self, "__networkx_cache__")
184+
return self.nx._cache
185+
186+
187+
def is_directed(self):
188+
"""Returns True if graph is directed, False otherwise."""
189+
return self.nx.is_directed
190+
191+
192+
def is_directed_property(self):
193+
"""Returns True if graph is directed, False otherwise."""
194+
# `df.is_directed` only available if `df` is a valid graph
195+
_attr_raise_if_invalid_graph(self, "is_directed")
196+
return is_directed.__get__(self)
197+
198+
199+
def is_multigraph(self):
200+
"""Returns True if graph is a multigraph, False otherwise."""
201+
return self.nx.is_multigraph
202+
203+
204+
def is_multigraph_property(self):
205+
"""Returns True if graph is a multigraph, False otherwise."""
206+
# `df.is_multigraph` only available if `df` is a valid graph
207+
_attr_raise_if_invalid_graph(self, "is_multigraph")
208+
return is_multigraph.__get__(self)
209+
210+
211+
pd.DataFrame.__networkx_backend__ = property(__networkx_backend__)
212+
pd.DataFrame.__networkx_cache__ = property(__networkx_cache__)
213+
# Add `is_directed` and `is_multigraph` so `not_implemented_for` decorator works
214+
pd.DataFrame.is_directed = property(is_directed_property)
215+
pd.DataFrame.is_multigraph = property(is_multigraph_property)
216+
217+
218+
def get_info():
219+
# Should we add config for e.g. default source, target, edge_key columns?
220+
# Maybe config to enable/disable cache by default?
221+
return {}

0 commit comments

Comments
 (0)