Skip to content

Commit 3b1afc2

Browse files
committed
wip sciserver backend, seems to work, needs more tests
1 parent 0ecc81c commit 3b1afc2

File tree

5 files changed

+66
-13
lines changed

5 files changed

+66
-13
lines changed

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,47 @@ For other options, see the help message.
3939
python -m src.run_algo --help
4040
```
4141

42+
### Implementing a runtime
43+
44+
The algorithm works generall in three phases:
45+
46+
1. Get the prior data for that year and any number of prior years.
47+
2. Run the label propagation algorithm
48+
3. Update the posterior data for that year
49+
50+
A backend then needs to implement steps 1 and 3.
51+
52+
You need to implement the following functions
53+
54+
```python
55+
MaybeSparseMatrix = Union[np.ndarray, sp.spmatrix]
56+
57+
get_data(
58+
year: int,
59+
logger: logging.Logger
60+
) -> Tuple[MaybeSparseMatrix, np.ndarray, np.ndarray]:
61+
```
62+
63+
This function accepts a year and a logger and returns a tuple of the following:
64+
- The adjacency matrix
65+
- The auids
66+
- The prior for the auids
67+
68+
The second function you need to implement is
69+
70+
```python
71+
def update_posterior(
72+
auids: np.ndarray,
73+
posterior_y_value: np.ndarray,
74+
year: int,
75+
logger: logging.Logger,
76+
) -> None:
77+
```
78+
79+
This function accepts the auids, the posterior_y_value, and the year and
80+
updates the posterior values for that year. It's important to note that
81+
if you parse the graph in pieces of disconnnected sets, this will update
82+
the same file multiple times.
4283

4384
### TODO
4485

src/backend/elsevier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23+
import logging
2324
from typing import Tuple
2425

2526
import numpy as np
2627
import scipy.sparse as sparse
2728

2829

29-
def get_data(year: int) -> Tuple[sparse.csr_matrix, np.ndarray, np.ndarray]:
30+
def get_data(year: int, logger: logging.Logger) -> Tuple[sparse.csr_matrix, np.ndarray, np.ndarray]:
3031
raise NotImplementedError("Not implemented in the Elsevier backend.")
3132

3233

src/backend/sciserver.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,20 @@
2626
import itertools
2727
import logging
2828
import os
29+
import warnings
2930
from typing import Callable, Iterable, Iterator, List, Tuple, Union
3031

3132
import numpy as np
3233
import pandas as pd
3334
import scipy.sparse as sparse
35+
try:
36+
from pandarallel import pandarallel
37+
pandarallel.initialize(progress_bar=True)
38+
parallel_apply = True
39+
except ImportError:
40+
warnings.warn("pandarallel not installed, parallel processing will not be available.")
41+
parallel_apply = False
42+
3443

3544
import src.utils.log_time as log_time
3645

@@ -59,12 +68,8 @@ def default_combine_posterior_prior_y_func(arrs: List[np.ndarray]) -> np.ndarray
5968
if not all(arr.shape[0] == length for arr in arrs):
6069
raise ValueError("All arrays must be same length.")
6170

62-
print("arr shapes:", [arr.shape for arr in arrs])
63-
6471
outs = np.nanmean(np.stack(arrs, axis=1), axis=1)
6572

66-
print("out shape:", outs.shape)
67-
6873
return np.nanmean(np.stack(arrs, axis=1), axis=1)
6974

7075

@@ -197,9 +202,14 @@ def calculate_prior_y_from_eids(
197202

198203
selected_eids = auid_eids[auids]
199204

200-
y = selected_eids.apply(
201-
lambda eids: agg_score_func(eid_score[eids])
202-
).astype(eid_score.dtype)
205+
if len(selected_eids) > MIN_ARR_SIZE_FOR_CACHE and parallel_apply:
206+
y = selected_eids.parallel_apply(
207+
lambda eids: agg_score_func(eid_score[eids])
208+
).astype(eid_score.dtype)
209+
else:
210+
y = selected_eids.apply(
211+
lambda eids: agg_score_func(eid_score[eids])
212+
).astype(eid_score.dtype)
203213

204214
return y
205215

@@ -384,6 +394,7 @@ def update_posterior(
384394
auids: np.ndarray,
385395
posterior_y_values: np.ndarray,
386396
year: int,
397+
logger: logging.Logger,
387398
) -> None:
388399

389400
posterior_path = POSTIEOR_DATA_PATH.format(year=year)

src/run_algo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,10 @@ def run_algo_year(
7373
for i, (algo, (A, auids, prior_y)) in enumerate(
7474
zip(algo_instances, get_data_func(year, logger)), start=1
7575
):
76-
print(i, A.shape, auids.shape, prior_y.shape)
7776
with log_time.LogTime(f"Fitting data for {year}, ajd matrix {i}", logger):
7877
posterior_y = algo.fit_predict_graph(A, prior_y)
7978
with log_time.LogTime(f"Updating posterior for {year}", logger):
80-
posterior_update_func(auids, posterior_y, year)
79+
posterior_update_func(auids, posterior_y, year, logger)
8180

8281

8382
def main(args: Dict[str, Any]):
@@ -98,7 +97,7 @@ def main(args: Dict[str, Any]):
9897
get_data_func = functools.partial(
9998
sciserver.get_data,
10099
prior_y_aggregate_eid_score_func=np.mean,
101-
combine_posterior_prior_y_func=functools.partial(np.mean, axis=1),
100+
combine_posterior_prior_y_func=sciserver.default_combine_posterior_prior_y_func,
102101
operate_on_subgraphs_separately=args.get("parse_subgraphs_separately"),
103102
)
104103
posterior_update_func = sciserver.update_posterior

tests/test_sciserver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
"""Testing for the SciServer backend."""
2424

25+
import logging
2526
import os
2627

2728
import numpy as np
@@ -252,7 +253,7 @@ def test_update_posterior_nothing_exists():
252253
os.makedirs("./tmp", exist_ok=True)
253254
ss.POSTIEOR_DATA_PATH = "./tmp/posterior_{year}.parquet"
254255

255-
ss.update_posterior(auids, posterior_y, year)
256+
ss.update_posterior(auids, posterior_y, year, logging.getLogger("test"))
256257

257258
df = pd.read_parquet(ss.POSTIEOR_DATA_PATH.format(year=2020))
258259

@@ -282,7 +283,7 @@ def test_update_posterior_something_exists():
282283
auids = [4, 5, 6]
283284
posterior_y = np.array([0.6, 0.5, 0.4])
284285

285-
ss.update_posterior(auids, posterior_y, year)
286+
ss.update_posterior(auids, posterior_y, year, logging.getLogger("test"))
286287

287288
df = pd.read_parquet(ss.POSTIEOR_DATA_PATH.format(year=2020))
288289

0 commit comments

Comments
 (0)