ludekcizinsky commited on
Commit
4bccd80
Β·
1 Parent(s): 137a427

Final Demo

Browse files
Files changed (3) hide show
  1. app.py +5 -11
  2. graph.pkl +3 -0
  3. utils.py +518 -0
app.py CHANGED
@@ -1,16 +1,8 @@
1
  import gradio as gr
2
 
3
- # ------------- Global Variables
4
- DAYS = [
5
- "Su, 26.11.2023 ",
6
- "Mo, 27.11.2023",
7
- "Tu, 28.11.2023",
8
- "We, 29.11.2023",
9
- "Th, 30.11.2023",
10
- "Fr, 01.12.2023",
11
- "Sa, 02.12.2023",
12
- ]
13
 
 
14
  TITLE = "Travel Planner" # TODO: Change title as we need
15
 
16
  # ------------- Main App
@@ -34,7 +26,7 @@ with gr.Blocks(
34
 
35
  with gr.Row() as metainfo:
36
  limit = gr.Slider(minimum=1, maximum=5, step=1, label="Limit", interactive=True, value=3)
37
- days = gr.Dropdown(DAYS, label="Day", value=DAYS[-1])
38
  time = gr.Textbox(label="Time", placeholder="hh:mm")
39
 
40
  submit = gr.Button(value="Search", variant="primary")
@@ -53,6 +45,8 @@ with gr.Blocks(
53
  ---
54
  """
55
 
 
 
56
  return {
57
  output: gr.Column(visible=True),
58
  route: md
 
1
  import gradio as gr
2
 
3
+ from utils import get_best_path
 
 
 
 
 
 
 
 
 
4
 
5
+ # ------------- Constants
6
  TITLE = "Travel Planner" # TODO: Change title as we need
7
 
8
  # ------------- Main App
 
26
 
27
  with gr.Row() as metainfo:
28
  limit = gr.Slider(minimum=1, maximum=5, step=1, label="Limit", interactive=True, value=3)
29
+ days = gr.Textbox(label="Day", placeholder="YYYY-MM-DD")
30
  time = gr.Textbox(label="Time", placeholder="hh:mm")
31
 
32
  submit = gr.Button(value="Search", variant="primary")
 
45
  ---
46
  """
47
 
48
+ md = get_best_path(A, B, day, time, limit)
49
+
50
  return {
51
  output: gr.Column(visible=True),
52
  route: md
graph.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d389ba765d0f6688beb576d4cc55ced125252491e5a20f49e124373128386331
3
+ size 11175230
utils.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Some utility functions.
3
+ """
4
+
5
+ import re
6
+ import argparse
7
+ import requests
8
+ import pandas as pd
9
+ import requests
10
+ from geopy.geocoders import Nominatim
11
+ from geopy.distance import distance
12
+ import heapq
13
+ import pickle
14
+
15
+ PR_STATIONS = "./data/pr_stations"
16
+
17
+ # Average speed in m/s
18
+ AVG_SPEED = {
19
+ "foot": 1.4,
20
+ "bike": 4.17,
21
+ "car": 13.89,
22
+ }
23
+
24
+ # Max travel time in seconds
25
+ MAX_TRAVEL_TIME = {
26
+ "foot": 60 * 60 // 4, # 15min
27
+ "bike": 60 * 60 // 2, # 30min
28
+ "car": 60 * 60, # 1h
29
+ }
30
+
31
+ PENALITIES = {
32
+ "foot": 1,
33
+ "bike": 1,
34
+ "car": 1,
35
+ "train": 1,
36
+ }
37
+
38
+
39
+ def get_penalties(sustainability: bool):
40
+ if sustainability:
41
+ PENALITIES["bike"] *= 1.1
42
+ PENALITIES["train"] *= 1.2
43
+ PENALITIES["car"] *= 5
44
+ else:
45
+ PENALITIES["car"] *= 2
46
+
47
+ return PENALITIES
48
+
49
+
50
+ def get_args():
51
+ parser = argparse.ArgumentParser()
52
+
53
+ # Default date and time args
54
+ # current_date = date.today().strftime("%Y-%m-%d")
55
+ # current_time = datetime.now().strftime("%H:%M")
56
+ current_date = "2023-12-01"
57
+ current_time = "12:00"
58
+
59
+ # Help messages
60
+ loc_types = ["address", "station name", "station abbreviation", "coordinate"]
61
+ transportation_types = ["train", "tram", "ship", "bus", "cableway"]
62
+ start_help = f"Start location (Specify either {', '.join(loc_types)})"
63
+ # via_help = f"Locations to pass through (Specify either {', '.join(loc_types)})"
64
+ stop_help = f"Stop location (Specify either {', '.join(loc_types)})"
65
+ date_help = "Date of departure (Format: YYYY-MM-DD). Default: Today"
66
+ time_help = "Time of departure (Format: YYYY-MM-DD). Default: Now"
67
+ transportation_help = f"Modes of transportation (Specify from {', '.join(transportation_types)}). Default: All"
68
+ outage_simulation = f"Simulate outage of a station"
69
+
70
+ # Specify line arguments
71
+ parser.add_argument("--start", type=str, required=True, help=start_help)
72
+ # parser.add_argument("--via", type=list[str], help=via_help)
73
+ parser.add_argument("--end", type=str, required=True, help=stop_help)
74
+ parser.add_argument("--date", type=str, default=current_date, help=date_help)
75
+ parser.add_argument("--time", type=str, default=current_time, help=time_help)
76
+ parser.add_argument(
77
+ "--limit", type=int, default=3, help="Number of journeys to return"
78
+ )
79
+ parser.add_argument(
80
+ "--transportations",
81
+ type,
82
+ choices=transportation_types,
83
+ default=["train"],
84
+ help=transportation_help,
85
+ )
86
+ parser.add_argument("--exact-travel-time", action="store_true", help=time_help)
87
+ parser.add_argument(
88
+ "--change-penalty", type=int, default=300, help="Change penalty"
89
+ )
90
+ parser.add_argument(
91
+ "--sustainability",
92
+ action="store_true",
93
+ help="Sustainability of journey",
94
+ )
95
+ parser.add_argument("--outage", action="store_true", help=outage_simulation)
96
+
97
+ return parser.parse_args()
98
+
99
+
100
+ def get_location(G, address: str) -> str:
101
+ """
102
+ Converts an address to coordinates (latitude, longitude)
103
+
104
+ Address can be:
105
+ - Full name
106
+ - Coordinates (Format: "latitude, longitude")
107
+
108
+ Args:
109
+ address (str): Address to convert to coordinates.
110
+
111
+ Returns:
112
+ str: Coordinates of the address (Format: "latitude, longitude").
113
+ """
114
+ pattern = re.compile(r"^\s*-?\d+\.\d+\s*,\s*-?\d+\.\d+\s*$")
115
+
116
+ # Check if it is already a coordinate
117
+ if bool(pattern.match(address)):
118
+ return address
119
+
120
+ # Check if already stationo
121
+ if address in G.nodes():
122
+ return G.nodes[address]["pos"]
123
+
124
+ # Use geopy to convert coordinates to address
125
+ geolocator = Nominatim(user_agent="sbb_project")
126
+ location = geolocator.geocode(
127
+ address, country_codes="ch", language="de", exactly_one=True
128
+ )
129
+
130
+ return "{}, {}".format(location.latitude, location.longitude)
131
+
132
+
133
+ def get_distance(start, end):
134
+ """
135
+ Calculate distance between two coordinates in meters.
136
+ Coordinates must be in the format "latitude, longitude".
137
+
138
+ Args:
139
+ start (str): Start coordinate.
140
+ end (str): End coordinate.
141
+
142
+ Returns:
143
+ int: Distance between the two coordinates in meters.
144
+ """
145
+ start_converted = tuple(map(float, start.split(", ")))
146
+ end_converted = tuple(map(float, end.split(", ")))
147
+ return distance(start_converted, end_converted).meters
148
+
149
+
150
+ def get_exact_travel_time(start, end, method="car"):
151
+ """
152
+ Calculate travel time between two coordinates in seconds.
153
+ """
154
+ endpoint = "http://www.mapquestapi.com/directions/v2/route"
155
+ if method not in ["foot", "bike", "car"]:
156
+ raise ValueError("Method must be either foot, bike or car")
157
+
158
+ names = {
159
+ "foot": "pedestrian",
160
+ "bike": "bicycle",
161
+ "car": "fastest",
162
+ }
163
+ # Search params
164
+ params = {
165
+ "key": "HnDX3JAuALRTge28jbZVWO1L538fJbZE",
166
+ "from": start,
167
+ "to": end,
168
+ "unit": "k", # Use km instead of miles
169
+ "narrativeType": "none", # Just some other parameters to omit information we don't care about
170
+ "sideOfStreetDisplay": False,
171
+ "routeType": names[method],
172
+ }
173
+
174
+ # Do GET request and read JSON
175
+ response = requests.get(endpoint, params=params)
176
+ data = response.json()
177
+ seconds = data["route"]["time"]
178
+
179
+ return seconds
180
+
181
+
182
+ def get_approx_travel_time(dist, method="car"):
183
+ """
184
+ Calculate the approximate travel time between two coordinates in seconds.
185
+
186
+ Args:
187
+ dist (int): Distance between the two coordinates in meters.
188
+ method (str, optional): Method of transportation. Defaults to "car".
189
+ """
190
+ return dist / AVG_SPEED[method]
191
+
192
+
193
+ def dijkstra(G, start, end, start_time, change_penalty=300, mode_penalties=PENALITIES):
194
+ # Initialize distances and time dictionary with all distances set to infinity
195
+ assert start in G.nodes(), f"Start node {start} not in graph"
196
+ assert end in G.nodes(), f"End node {end} not in graph"
197
+
198
+ distances = {
199
+ node: {
200
+ "distance": float("infinity"),
201
+ "time": start_time,
202
+ }
203
+ for node in list(G.nodes())
204
+ }
205
+
206
+ # Initialise dictionary to keep track of the edges that lead to the node with the smallest distance
207
+ edges_to = {node: None for node in list(G.nodes())}
208
+
209
+ # Set the distance from the start node to itself to 0
210
+ distances[start]["distance"] = 0
211
+ edges_to[start] = ("", "Start", {"type": "foot", "journey_id": None, "duration": 0})
212
+
213
+ # Priority queue to keep track of nodes with their current distances
214
+ priority_queue = [(0, start)]
215
+
216
+ while priority_queue:
217
+ # Get the node with the smallest distance from the priority queue
218
+ current_distance, current_node = heapq.heappop(priority_queue)
219
+
220
+ if current_node == end:
221
+ return distances[end], edges_to
222
+
223
+ # Check if the current distance is smaller than the stored distance
224
+ if current_distance > distances[current_node]["distance"]:
225
+ continue
226
+
227
+ for _, neighbor, attributes in G.out_edges(current_node, data=True):
228
+ no_duration_avail = not (type(attributes["duration"]) in [int, float])
229
+ if no_duration_avail:
230
+ continue
231
+
232
+ # Initialise distance to neighbor
233
+ distance = attributes["duration"]
234
+
235
+ # Add penalty
236
+ distance *= mode_penalties[attributes["type"]]
237
+
238
+ # Compute wait time for train (if next edge is a train)
239
+ wait = 0 # number of seconds you have to wait for the train
240
+ next_is_train = attributes["type"] == "train"
241
+ prev_is_train = edges_to[current_node][-1]["type"] == "train"
242
+ changed_trip = False
243
+ if next_is_train:
244
+ wait = (
245
+ attributes["departure"] - distances[current_node]["time"]
246
+ ).total_seconds()
247
+ train_departed = (
248
+ distances[current_node]["time"] > attributes["departure"]
249
+ )
250
+
251
+ if train_departed:
252
+ continue
253
+
254
+ if prev_is_train:
255
+ changed_trip = (
256
+ edges_to[current_node][-1]["journey_id"]
257
+ != attributes["journey_id"]
258
+ )
259
+
260
+ # Penalise changing and waiting times
261
+ changed_mode = edges_to[current_node][-1]["type"] != attributes["type"]
262
+
263
+ # Add max of waiting time or changing penalty to current dist
264
+ if changed_mode or changed_trip:
265
+ distance += max(wait, change_penalty)
266
+
267
+ # Overall dist (prev dist + dist to neighbor)
268
+ distance = current_distance + distance
269
+
270
+ # If the new distance is smaller, update the distance and add to the priority queue
271
+ if distance < distances[neighbor]["distance"]:
272
+ # Update distance and time of arrival for neighbor
273
+ time_of_arrival = distances[current_node]["time"] + pd.Timedelta(
274
+ seconds=distance
275
+ )
276
+ distances[neighbor]["distance"] = distance
277
+ distances[neighbor]["time"] = time_of_arrival
278
+
279
+ # Add edge that leads to neighbor
280
+ edges_to[neighbor] = (current_node, neighbor, attributes)
281
+
282
+ # Update priority queue
283
+ heapq.heappush(priority_queue, (distance, neighbor))
284
+
285
+ print(f"Probably there is no path between {start} and {end}")
286
+
287
+ return distances, edges_to
288
+
289
+
290
+ def reconstruct_edges(edges_to, start: str, end: str):
291
+ """
292
+ Given shortest path from start to end, reconstruct edges.
293
+ """
294
+ path = [end]
295
+ edges = [edges_to[end]]
296
+ while path[-1] != start:
297
+ path.append(edges_to[path[-1]][0])
298
+ edges.append(edges_to[path[-1]])
299
+
300
+ return edges
301
+
302
+
303
+ def postprocess_path(edges):
304
+ """Merge two edges if they have same transport type.
305
+
306
+ Edge Structure: (
307
+ start,
308
+ end,
309
+ {
310
+ type,
311
+ duration,
312
+ departure (only for train),
313
+ arrival (only for train),
314
+ journey_id (only for train),
315
+ trip_name (only for train),
316
+ )
317
+
318
+ Args:
319
+ edges (list[tuple]): list of travel edges.
320
+
321
+ Returns:
322
+ list[tuple]: post-processed list of travel edges.
323
+ """
324
+ traversed = []
325
+ prev = None
326
+
327
+ for edge in edges[::-1]:
328
+ if prev is None:
329
+ prev = edge
330
+ traversed.append(edge)
331
+ continue
332
+
333
+ # Need to check the transport type and, if it is train,
334
+ # the journey id
335
+ if edge[2]["type"] == prev[2]["type"] and (
336
+ edge[2]["type"] != "train" or edge[2]["journey_id"] == prev[2]["journey_id"]
337
+ ):
338
+ prev = traversed.pop()
339
+
340
+ # Merge the two edges
341
+ new_edge = (
342
+ prev[0],
343
+ edge[1],
344
+ {
345
+ "type": prev[2]["type"],
346
+ "duration": prev[2]["duration"] + edge[2]["duration"],
347
+ },
348
+ )
349
+
350
+ if edge[2]["type"] == "train":
351
+ new_edge[2]["departure"] = prev[2]["departure"]
352
+ new_edge[2]["arrival"] = edge[2]["arrival"]
353
+ new_edge[2]["journey_id"] = edge[2]["journey_id"]
354
+ new_edge[2]["trip_name"] = edge[2]["trip_name"]
355
+ else:
356
+ new_edge = edge
357
+
358
+ prev = new_edge
359
+ traversed.append(new_edge)
360
+
361
+ return traversed
362
+
363
+
364
+ def pretty_time_delta(seconds):
365
+ seconds = int(seconds)
366
+ days, seconds = divmod(seconds, 86400)
367
+ hours, seconds = divmod(seconds, 3600)
368
+ minutes, seconds = divmod(seconds, 60)
369
+ if days > 0:
370
+ return "%dd%dh%dm%ds" % (days, hours, minutes, seconds)
371
+ elif hours > 0:
372
+ return "%dh%dm%ds" % (hours, minutes, seconds)
373
+ elif minutes > 0:
374
+ return "%dm%ds" % (minutes, seconds)
375
+ else:
376
+ return "%ds" % (seconds,)
377
+
378
+
379
+ def pretty_print(edges, args):
380
+ """Prints the edges in a nice format.
381
+
382
+ Args:
383
+ edges (list[tuple]): list of travel edges.
384
+ """
385
+ print(f"# Your Journey from {args.start} to {args.end}")
386
+
387
+ print(f"\nDate/ Time: {args.date} at {args.time}")
388
+ print(f"Sustainability: {args.sustainability}\n")
389
+
390
+ print("### Travel Information\n---\n")
391
+
392
+ for i, (src, dst, attr) in enumerate(edges):
393
+ if src == "Start":
394
+ src = args.start
395
+ if dst == "End":
396
+ dst = args.end
397
+
398
+ duration = pretty_time_delta(attr["duration"])
399
+
400
+ print(f"{i+1}. Go by {attr['type']} from {src} to {dst} for {duration}")
401
+
402
+ """
403
+ Remove all the trains from one station to another to simulate an outage.
404
+ """
405
+
406
+
407
+ def remove_all_trains(G, from_station, to_station):
408
+ edges_to_remove = []
409
+
410
+ for edge in G.out_edges(from_station, data=True):
411
+ if edge[2]["type"] == "train" and edge[1] == to_station:
412
+ edges_to_remove.append(edge[:2])
413
+
414
+ print(
415
+ f"Removing all the edges from {from_station} to {to_station} : {len(edges_to_remove)}"
416
+ )
417
+ for edge in edges_to_remove:
418
+ G.remove_edge(*edge)
419
+
420
+
421
+ def get_final_path_md(edges, start, end, date, time, sustainability):
422
+
423
+ md = f"## Your Journey from {start} to {end}\n\n"
424
+ md += f"πŸ“… Date/ Time: {date} at {time}\n"
425
+ md += "### Travel Information\n"
426
+
427
+ for i, (src, dst, attr) in enumerate(edges):
428
+ if src == "Start":
429
+ src = start
430
+ if dst == "End":
431
+ dst = end
432
+
433
+ duration = pretty_time_delta(attr["duration"])
434
+
435
+ travel_type = attr['type']
436
+ emoji = "πŸš‰" if travel_type == "train" else "πŸš—" if travel_type == "car" else "🚢" if travel_type == "foot" else "πŸš€"
437
+
438
+ md += f"{i+1}. {emoji} Go by {travel_type} from {src} to {dst} for {duration}\n\n"
439
+
440
+ return md
441
+
442
+ def get_best_path(start, end, date, time, limit, exact_travel_time=False, outage=False, sustainability=False, change_penalty=300):
443
+
444
+ # Load graph
445
+ with open("graph.pkl", "rb") as f:
446
+ G = pickle.load(f)
447
+
448
+ if outage:
449
+ # Remove all the edges from 2 stations to simulate an outage
450
+ remove_all_trains(G, from_station="St-Maurice", to_station="Martigny")
451
+
452
+ # Convert start and destination to location (lon, lat)
453
+ start_loc = get_location(G, start)
454
+ end_loc = get_location(G, end)
455
+
456
+ # Add both locations to graph
457
+ G.add_node("Start", pos=start_loc)
458
+ G.add_node("End", pos=end_loc)
459
+
460
+ # Find k-closest stations from start and end
461
+ dists_from_start = []
462
+ dists_to_end = []
463
+ for station, attr in G.nodes(data=True):
464
+ try:
465
+ station_pos = attr["pos"]
466
+ dists_from_start.append(
467
+ (station, get_distance(start_loc, station_pos))
468
+ )
469
+ dists_to_end.append((station, get_distance(end_loc, station_pos)))
470
+ except Exception as e:
471
+ continue
472
+
473
+ # Sort the distances in place
474
+ start_k_closest = sorted(dists_from_start, key=lambda x: x[1])[: limit]
475
+ end_k_closest = sorted(dists_to_end, key=lambda x: x[1])[: limit]
476
+
477
+ # Compute travel time from start to k closest stations
478
+ for mode in ["foot", "bike", "car"]:
479
+ for station, dist in start_k_closest:
480
+ if exact_travel_time:
481
+ travel_time = get_exact_travel_time(
482
+ start_loc, station, method=mode
483
+ )
484
+ G.add_edge("Start", station, duration=travel_time, type=mode)
485
+ else:
486
+ travel_time = get_approx_travel_time(dist, method=mode)
487
+ G.add_edge("Start", station, duration=travel_time, type=mode)
488
+
489
+ for station, dist in end_k_closest:
490
+ if exact_travel_time:
491
+ travel_time = get_exact_travel_time(dist, method=mode)
492
+ G.add_edge(station, "End", duration=travel_time, type=mode)
493
+ else:
494
+ travel_time = get_approx_travel_time(dist, method=mode)
495
+ G.add_edge(station, "End", duration=travel_time, type=mode)
496
+
497
+ # Run Dijkstra on graph
498
+ start_time = pd.to_datetime(f"{date} {time}")
499
+ mode_penalties = get_penalties(sustainability)
500
+ dists, edges_to = dijkstra(
501
+ G,
502
+ "Start",
503
+ "End",
504
+ start_time=start_time,
505
+ change_penalty=change_penalty,
506
+ mode_penalties=mode_penalties,
507
+ )
508
+
509
+ # Reconstruct path
510
+ edges = reconstruct_edges(edges_to, "Start", "End")
511
+
512
+ # Postprocess path
513
+ path = postprocess_path(edges[:-1])
514
+
515
+ # Print journey
516
+ md = get_final_path_md(path, start, end, date, time, sustainability)
517
+
518
+ return md