Skip to content

Commit 00a033f

Browse files
author
thigg
committed
updated plotter to use sqlite
1 parent 946d1f1 commit 00a033f

File tree

2 files changed

+49
-117
lines changed

2 files changed

+49
-117
lines changed

plotter/main.py

+48-117
Original file line numberDiff line numberDiff line change
@@ -1,168 +1,99 @@
11
import argparse
2-
import functools
3-
import itertools
4-
import json
52
import logging
6-
import lzma
7-
import os
3+
import sqlite3
84
from collections import defaultdict
9-
from concurrent.futures import ProcessPoolExecutor
105
from datetime import datetime, timedelta
11-
from typing import Dict, List, Tuple
6+
from typing import Dict, List
127

13-
import brotli
148
import matplotlib.dates as mdates
159
import matplotlib.pyplot as plt
1610

1711
parser = argparse.ArgumentParser(description='accumulate fahrpreis data')
18-
parser.add_argument('--accufile',
19-
help='file with latest accu data (/tmp/fahrpreise_akku) usefull for working on plotting')
20-
parser.add_argument('start_station',
21-
help='start station id')
22-
parser.add_argument('end_station', help='end station id')
12+
parser.add_argument('--dbfile', help='path to the sqlitefile with the data', required=True)
13+
parser.add_argument('--start_station', help='start station id', required=True)
14+
parser.add_argument('--end_station', help='end station id', required=True)
2315
parser.add_argument('--plot_timeframe_past', help='oldest travel start date on the plot, days relative to now',
2416
default=60)
2517
parser.add_argument('--plot_timeframe_future', help='newest travel start date on the plot, days relative to now',
2618
default=10)
19+
parser.add_argument('--plot_timeframe_date', help='what now is for timeframe',
20+
default=datetime.now())
2721

2822
args = parser.parse_args()
2923

3024

31-
def handle_file(path: str):
32-
"""
33-
reads a record from the crawler and transofrms it into a list with connections and prices
34-
:param path: the file to read
35-
:return: a list of connections with prices and when the price was queried
36-
"""
37-
try:
38-
result = list()
39-
with open(path, 'rb') as in_file:
40-
decompressor = brotli.Decompressor()
41-
s: str = ''
42-
read_chunk = functools.partial(in_file.read, )
43-
for data in iter(read_chunk, b''):
44-
s += bytes.decode(decompressor.process(data), 'utf-8')
45-
dict = json.loads(s)
46-
queried_at = dict['queried_at']
47-
for day in dict['data']:
48-
for travel in day:
49-
price = travel['price']['amount']
50-
start_station = travel['legs'][0]['origin']['id']
51-
start_time = travel['legs'][0]['departure']
52-
end_station = travel['legs'][-1]['origin']['id']
53-
end_time = travel['legs'][-1]['departure']
54-
dict_key = "$".join([start_station, start_time, end_station, end_time])
55-
result.append((dict_key, {"queried_at": queried_at, "price": price}))
56-
return result
57-
except Exception as e:
58-
print("Could not read file %s, %s, size: %d" % (path, e, os.stat(path).st_size))
59-
return []
60-
25+
def accumulate_sqlite(filename: str, start_station: int, end_station: int, timeframe_start: datetime,
26+
timeframe_end: datetime) -> dict[str, list[tuple[int, int]]]:
27+
conn = sqlite3.connect(filename)
28+
cursor = conn.cursor()
6129

62-
def accumulate_data() -> dict[str, list[dict[str, object]]]:
63-
"""
64-
preprocesses the raw data into the data we need for the plot
65-
:return: a dictionary of all travels (a connection at a specific datetime) with a list of prices and when the price was queried
66-
"""
67-
starttime = datetime.now()
68-
result: dict[str, list[dict[str, object]]] = defaultdict(list)
69-
with os.scandir("/tmp/fahrpreise/") as dirIterator:
70-
with ProcessPoolExecutor() as executor:
71-
resultlist = list(
72-
itertools.chain.from_iterable(
73-
executor.map(handle_file, (str(entry.path) for entry in dirIterator if
74-
entry.name.endswith('.brotli') and entry.is_file()))))
75-
print(f"got resultlist {len(resultlist)}")
76-
for item in resultlist:
77-
if item:
78-
key, value = item
79-
result[key].append(value)
30+
# Create a dictionary to store accumulated prices
31+
prices_dict = defaultdict(lambda: [])
8032

81-
print(f"accumulation took {datetime.now() - starttime}")
82-
return result
33+
# Execute SQL query to fetch data from the table
34+
cursor.execute("SELECT `when`, `price_cents`, `queried_at` FROM fahrpreise "
35+
f"where `from` = {start_station} and `to` = {end_station} "
36+
f"and `queried_at` >= {round(timeframe_start.timestamp() * 1000)} "
37+
f"and `queried_at` <= {round(timeframe_end.timestamp() * 1000)}")
8338

39+
# Fetch all rows and accumulate prices
40+
rows = cursor.fetchall()
41+
for row in rows:
42+
when = row[0]
43+
price = row[1]
44+
queried_at = row[2]
45+
prices_dict[when].append((int(queried_at), int(price)))
8446

85-
isoformatstr = "%Y-%m-%dT%H:%M:%S.%fZ"
47+
conn.close()
48+
return prices_dict
8649

8750

88-
def plot(result: Dict[str, List[Tuple[str, str]]], start_station_filter: str, end_station_filter: str,
89-
starttime_after: datetime, starttime_before):
51+
def plot(result: Dict[str, List[tuple[int, int]]]):
9052
"""
9153
9254
:param result: the data to ingest for the plot
93-
:param start_station_filter: which start station to consider
94-
:param end_station_filter: which end station to consider
95-
:param starttime_after: timeframe to plot lower limit
96-
:param starttime_before: timeframe to plot upper limit
9755
:return: shows the plot
9856
"""
99-
print(f"creating filtered plot for stations: {start_station_filter}, {end_station_filter}")
10057
time_to_departure = []
10158
# booking_date = []
10259
departure_date = []
10360
# y_axis2 = []
104-
travel_price = []
61+
travel_price_euro = []
10562
# z_axis2 = []
106-
recorded_connections: dict[tuple[str, str], int] = defaultdict(lambda: 0)
10763
for i, travelprices in enumerate(result.items()):
10864
try:
109-
keystr = travelprices[0]
110-
keystr_split = keystr.split("$")
111-
start_station = keystr_split[0]
112-
end_station = keystr_split[2]
113-
recorded_connections[(start_station, end_station)] += 1
114-
if start_station != start_station_filter or end_station != end_station_filter:
115-
# print(f"skipping '{start_station}'({type(start_station)}) '{end_station}' because filters '{start_station_filter}'({type(start_station_filter)}) '{end_station_filter}'")
116-
continue
117-
starttime = datetime.strptime(keystr_split[1], isoformatstr)
118-
endtime = datetime.strptime(keystr_split[3], isoformatstr)
65+
when, price_records = travelprices
66+
starttime = datetime.fromtimestamp(int(when) / 1000)
11967
# z_axis2.append((endtime-starttime).total_seconds())
12068
# y_axis2.append(starttime)
121-
for data in travelprices[1]:
122-
if starttime > starttime_after and starttime < starttime_before:
123-
time_to_departure.append(
124-
-(starttime - datetime.strptime(data["queried_at"], isoformatstr)).total_seconds() / (
125-
60 * 60 * 24))
126-
# booking_date.append(datetime.strptime(data["queried_at"], isoformatstr))
127-
departure_date.append(starttime)
128-
travel_price.append(data["price"])
69+
price_record: Dict[str, int]
70+
for queried_at, price in price_records:
71+
queried_at_date = datetime.fromtimestamp(queried_at / 1000)
72+
days_to_departure = (starttime - queried_at_date).total_seconds() / (60 * 60 * 24)
73+
time_to_departure.append(-days_to_departure)
74+
departure_date.append(starttime)
75+
travel_price_euro.append(price / 100)
12976
except:
13077
logging.exception("exception while prepping plot %d %s", i, travelprices)
131-
print("recorded connections: " + str(recorded_connections.items()))
132-
print("resulting datapoints: %d" % len(travel_price))
133-
print("resulting datapoints for filter: %d" % recorded_connections[(start_station_filter, end_station_filter)])
134-
color_map = plt.cm.plasma
135-
plt.rcParams['figure.figsize'] = [30, 50]
78+
print("resulting datapoints: %d" % len(travel_price_euro))
79+
plt.rcParams['figure.figsize'] = [12, 20]
13680
fig, ax = plt.subplots(1)
137-
prices = ax
81+
price_records = ax
13882
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%a %d.%m'))
13983
plt.gca().yaxis.set_major_formatter(mdates.DateFormatter('%a %d.%m'))
14084
# plt.gca().xaxis.set_major_locator(mdates.DayLocator())
14185
plt.gca().yaxis.set_major_locator(mdates.DayLocator())
142-
pcm = prices.scatter(time_to_departure, departure_date, c=travel_price, cmap=color_map, marker=".", s=1, vmax=75)
86+
pcm = price_records.scatter(time_to_departure, departure_date, c=travel_price_euro, cmap=plt.colormaps["plasma"], marker=".", s=1,
87+
vmax=75)
14388
# fig.autofmt_xdate()
14489
# prices.twinx().barh(y_axis2,z_axis2,height=0.1)
145-
fig.colorbar(pcm, label="price", ax=prices)
90+
fig.colorbar(pcm, label="price (Euro)", ax=price_records)
14691
plt.gca().xaxis.set_label_text("tage bis abfahrt / wann ich buche")
14792
plt.gca().yaxis.set_label_text("datum der reise / wann ich fahren möchte")
14893
plt.show()
14994

15095

151-
result = {}
152-
if args.accufile:
153-
print("reading data from tmpfile")
154-
with lzma.open(args.accufile, "rt") as infile:
155-
result = json.load(infile)
156-
print("done reading infile")
157-
else:
158-
starttime = datetime.now()
159-
result = accumulate_data()
160-
startcompresstime = datetime.now()
161-
with lzma.open("/tmp/fahrpreise_akku", "wt", preset=4) as outfile:
162-
print("writing accufile")
163-
json.dump(result, outfile)
164-
print(
165-
f"wrote accufile. whole process took total={datetime.now() - starttime} akku={startcompresstime - starttime} write={datetime.now() - startcompresstime}")
166-
167-
plot(result, args.start_station, args.end_station, datetime.now() - timedelta(days=args.plot_timeframe_past),
168-
datetime.now() + timedelta(days=args.plot_timeframe_future))
96+
timeframe_start: datetime = datetime.now() - timedelta(days=args.plot_timeframe_past)
97+
timeframe_end: datetime = datetime.now() + timedelta(days=args.plot_timeframe_future)
98+
result = accumulate_sqlite(args.dbfile, int(args.start_station), int(args.end_station), timeframe_start, timeframe_end)
99+
plot(result)

plotter/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
matplotlib

0 commit comments

Comments
 (0)