-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcollect_data.py
62 lines (47 loc) · 2.44 KB
/
collect_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import logging
import sys
from typing import List
from core.src.utils.df_utils import read_df
from core.src.model.api.platform_objects import Platform
from data_collection.src.hyperskill.hyperskill_client import HyperskillClient
from data_collection.src.stepik.stepik_client import StepikClient
from data_collection.src.utils.csv_utils import save_objects_to_csv
# Maps each supported platform to the client class used to request its objects.
# The class is looked up by platform and instantiated without arguments in `main`.
platform_client = {
    Platform.HYPERSKILL: HyperskillClient,
    Platform.STEPIK: StepikClient,
}
def configure_parser() -> argparse.ArgumentParser:
    """
    Build the command line parser of the data collection script.

    Positional arguments: `platform` (restricted to `Platform.values()`), `object` and `output_path`.
    Optional arguments: `--ids`, `--ids-from-file-path`, `--ids-from-column`, `--count` and `--port`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('platform', type=str, help='platform to collect data from', choices=Platform.values())
    # The closing parenthesis was missing in the original help text.
    parser.add_argument('object', type=str,
                        help='objects to request from platform (can be defaults like `step` or custom like `java`)')
    parser.add_argument('output_path', type=str, help='path to directory where to save the results')

    # Ids can be listed explicitly or read from a column of a csv file.
    parser.add_argument('--ids', '-i', nargs='*', type=int, default=None, help='ids of requested objects')
    parser.add_argument('--ids-from-file-path', '-f', type=str, default=None, help='csv file to get ids from')
    parser.add_argument('--ids-from-column', '-c', type=str, default=None, help='column in csv file to get ids from')

    parser.add_argument('--count', '-cnt', type=int, default=None, help='count of requested objects')
    # NOTE(review): `port` is parsed here but is not passed anywhere by `main` in this file --
    # confirm where it is consumed.
    parser.add_argument('--port', '-p', type=int, default=8000, help='port to run authorization server at')

    return parser
def get_object_ids_from_file(csv_file_path: str, column_name: str) -> List[int]:
    """
    Get unique ids from a csv file column.

    The method is useful when extra information is required for some subset of objects
    which are already used in an existing dataset (e.g. a dataset of solutions).

    :param csv_file_path: path to the csv file, passed to `read_df`
    :param column_name: name of the column to take the ids from
    :return: distinct values of the column, in order of first appearance
    """
    # NOTE(review): assumes `read_df` returns a pandas DataFrame -- confirm.
    # `Series.unique()` returns an array which has no `.values` attribute, so the duplicates are
    # dropped on the Series itself; `tolist()` additionally converts numpy scalars to Python values.
    return read_df(csv_file_path)[column_name].drop_duplicates().tolist()
# Module-level logging setup: request DEBUG as the root logger level.
logging.basicConfig(level=logging.DEBUG)
def main():
    """
    Request objects from the selected platform and save them to csv.

    Ids of the objects are taken from `--ids` if it is given, otherwise from the csv column
    described by `--ids-from-file-path` and `--ids-from-column`. If none of these options is
    given, `None` is passed to the client as ids.
    """
    parser = configure_parser()
    args = parser.parse_args(sys.argv[1:])

    file_options_given = (args.ids_from_file_path is not None, args.ids_from_column is not None)
    if args.ids is None and any(file_options_given) and not all(file_options_given):
        # A half-specified file source used to be silently ignored, so the request was made
        # with no ids at all. Fail before any client is created (exits with the usage message).
        parser.error('--ids-from-file-path and --ids-from-column must be provided together')

    platform = Platform(args.platform)
    client = platform_client[platform]()

    if args.ids is not None:
        # Explicitly listed ids take precedence over ids from file.
        ids = args.ids
    elif all(file_options_given):
        ids = get_object_ids_from_file(args.ids_from_file_path, args.ids_from_column)
    else:
        ids = None

    objects = client.get_objects(args.object, ids, args.count)
    save_objects_to_csv(args.output_path, objects, args.object)
# Start the collection only when the file is executed as a script.
if __name__ == '__main__':
    main()