-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathday14.py
80 lines (60 loc) · 2.26 KB
/
day14.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re
import itertools
class Decoder:
    """Apply a bitmask string such as ``"X1001X"`` to integers.

    The mask is read right to left, so the character ``ix`` positions from
    the end governs bit ``ix`` (value ``1 << ix``):

    * ``'0'`` clears that bit in ``mask_0`` (applied with ``&``),
    * ``'1'`` sets that bit in ``mask_1`` (applied with ``|``),
    * ``'X'`` records ``ix`` in ``mask_x_positions``.

    Any other character is ignored.
    """

    def __init__(self, mask_string=None, bits=36):
        """Build the masks from *mask_string*.

        Args:
            mask_string: Mask characters, most significant bit first.
                ``None`` or ``""`` yields masks that change nothing, except
                that applying ``mask_0`` truncates to *bits* bits.
            bits: Word width in bits (default 36). ``mask_0`` starts as
                *bits* one-bits.
        """
        self.bits = bits
        # All-ones word; also used to keep results inside the word width.
        self.full_mask = (1 << bits) - 1
        self.mask_0 = self.full_mask
        self.mask_1 = 0
        self.mask_x_positions = []
        if mask_string:
            for ix, char in enumerate(reversed(mask_string)):
                bit = 1 << ix
                if char == '0':
                    # Clearing with &= ~bit (rather than subtracting) can never
                    # drive the mask negative, even if ix exceeds the width.
                    self.mask_0 &= ~bit
                elif char == '1':
                    self.mask_1 |= bit
                elif char == 'X':
                    self.mask_x_positions.append(ix)

    def decode(self, value: int, overwrite_with_0: bool, overwrite_with_1: bool, x_is_float: bool) -> list[int]:
        """Apply the mask to *value* and return the resulting integer(s).

        Args:
            value: Integer to mask.
            overwrite_with_0: Clear the bits marked ``'0'``. This also drops
                any bits above the word width.
            overwrite_with_1: Set the bits marked ``'1'``.
            x_is_float: Treat ``'X'`` bits as floating and return every 0/1
                combination in those positions. The result is then also
                truncated to the word width. When false, ``'X'`` bits are
                left unchanged.

        Returns:
            ``[result]`` when *x_is_float* is false. Otherwise there are
            ``2 ** len(mask_x_positions)`` results, in the order
            ``itertools.product`` enumerates the X bits: the first recorded
            (lowest) X position varies slowest.
        """
        decoded = value
        if overwrite_with_0:
            decoded &= self.mask_0
        if overwrite_with_1:
            decoded |= self.mask_1
        if not x_is_float:
            return [decoded]
        x_bits = [1 << ix for ix in self.mask_x_positions]
        # Zero every X bit (and anything above the word width) so that each
        # combination can simply be added on top.
        decoded &= self.full_mask & ~sum(x_bits)
        return [
            decoded + sum(bit for bit, chosen in zip(x_bits, choice) if chosen)
            for choice in itertools.product((0, 1), repeat=len(x_bits))
        ]
def extract_mask(data: str) -> str:
    """Return everything after the last space in *data*.

    For a line like ``"mask = X1001X"`` this is the mask string. If *data*
    contains no space, the whole string is returned.
    """
    _head, _sep, mask = data.rpartition(' ')
    return mask
def extract_address_and_value(data: str) -> tuple[int, int]:
    """Parse a memory-write line such as ``"mem[8] = 11"``.

    Args:
        data: The line to parse.

    Returns:
        ``(address, value)``. The address is the integer inside the first
        ``[...]`` pair; the value is the integer after the last space.

    Raises:
        ValueError: If the line has no ``[...]`` part, or either field is
            not an integer.
    """
    match = re.search(r'\[(.*?)\]', data)
    if match is None:
        # Raise a descriptive ValueError instead of AttributeError on None.
        raise ValueError(f"no [address] found in line: {data!r}")
    address = int(match.group(1))
    value = int(data.split(' ')[-1])
    return address, value
def run(lines: list[str], part: int) -> int:
    """Execute a mask/mem program and return the sum of memory values.

    Lines starting with ``mask`` (e.g. ``"mask = X1001X"``) replace the
    active ``Decoder``. Every other non-blank line is parsed as a memory
    write of the form ``mem[address] = value``.

    * Part 1 decodes the *value*: ``'0'`` bits are cleared, ``'1'`` bits
      set and ``'X'`` bits left unchanged. The result is stored at *address*.
    * Part 2 decodes the *address*: ``'1'`` bits are set and ``'X'`` bits
      float. The raw value is stored at every resulting address.

    Args:
        lines: Program lines. Surrounding whitespace is ignored and blank
            lines are skipped.
        part: 1 or 2.

    Returns:
        Sum of all values in memory after the last line.

    Raises:
        ValueError: If *part* is not 1 or 2.
    """
    if part not in (1, 2):
        raise ValueError(f"part must be 1 or 2, got {part!r}")
    memory = {}
    decoder = Decoder()  # used until the first mask line is seen
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue  # a blank line would otherwise fail to parse as a write
        if line.startswith('mask'):
            decoder = Decoder(extract_mask(line))
            continue
        address, value = extract_address_and_value(line)
        if part == 1:
            memory[address] = decoder.decode(
                value, overwrite_with_0=True, overwrite_with_1=True, x_is_float=False
            )[0]
        else:
            for floating_address in decoder.decode(
                address, overwrite_with_0=False, overwrite_with_1=True, x_is_float=True
            ):
                memory[floating_address] = value
    return sum(memory.values())
# Script entry: read the program from day14-data.txt and print both answers.
# splitlines() already returns a fresh list, so no copy is needed; the
# explicit encoding avoids depending on the platform's locale default.
with open('day14-data.txt', encoding='utf-8') as f:
    inputs = f.read().splitlines()
print(run(inputs, part=1))
print(run(inputs, part=2))