# sage.py
import argparse
import csv
import datetime
import glob
import json
import os
import re
import sys
import requests
from ag_generation import make_attack_graphs
from episode_sequence_generation import aggregate_into_episodes, host_episode_sequences, break_into_subbehaviors
from model_learning import generate_traces, flexfringe, load_model, encode_sequences
from plotting import plot_alert_filtering, plot_histogram, plot_state_groups
from signatures.attack_stages import MicroAttackStage
from signatures.mappings import micro_inv
from signatures.alert_signatures import usual_mapping, unknown_mapping, ccdc_combined, attack_stage_mapping
IANA_CSV_FILE = "https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.csv"
IANA_NUM_RETRIES = 5
SAVE_AG = True
CPTC_BAD_IP = '169.254.169.254'
def _get_attack_stage_mapping(signature):
"""
Infers the attack stage based on the alert signature.
@param signature: the signature of the alert
@return: the inferred attack stage
"""
result = MicroAttackStage.NON_MALICIOUS
if signature in usual_mapping.keys():
result = usual_mapping[signature]
elif signature in unknown_mapping.keys():
result = unknown_mapping[signature]
elif signature in ccdc_combined.keys():
result = ccdc_combined[signature]
else:
for k, v in attack_stage_mapping.items():
if signature in v:
result = k
break
return micro_inv[str(result)]
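# Illustrative example (the signature string below is hypothetical, not taken from the mapping files):
#   _get_attack_stage_mapping('ET SCAN Suspicious inbound to mySQL port 3306')
# first checks the usual/unknown/ccdc dictionaries, then scans attack_stage_mapping for a signature
# list containing it, and falls back to MicroAttackStage.NON_MALICIOUS if nothing matches.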
# Step 0: Download the IANA port-service mapping
def load_iana_mapping():
"""
    Downloads the IANA port-service mapping. In case of a failed request, retries up to IANA_NUM_RETRIES times.
@return: a dictionary that maps a port to the corresponding service based on the IANA mapping
"""
# Perform the first request and in case of a failure retry the specified number of times
for attempt in range(IANA_NUM_RETRIES + 1):
response = requests.get(IANA_CSV_FILE)
if response.ok:
content = response.content.decode("utf-8")
break
elif attempt < IANA_NUM_RETRIES:
print('Could not download IANA ports. Retrying...')
else:
raise RuntimeError('Cannot download IANA ports')
table = csv.reader(content.splitlines())
# Drop headers (service name, port, protocol, description, ...)
next(table)
# Note that ports might have holes
ports = {}
for row in table:
# Drop missing port number, Unassigned and Reserved ports
if row[1] and 'Unassigned' not in row[3]: # and 'Reserved' not in row[3]:
# Split range in single ports
if '-' in row[1]:
low_port, high_port = map(int, row[1].split('-'))
else:
low_port = high_port = int(row[1])
for port in range(low_port, high_port + 1):
ports[port] = {
"name": row[0] if row[0] else "unknown",
"description": row[3] if row[3] else "---",
}
return ports
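# The returned mapping associates each assigned port with a small record, roughly of the form
# (example values are illustrative and may not match the current IANA registry exactly):
#   ports[80]  -> {"name": "http", "description": "World Wide Web HTTP"}
#   ports[443] -> {"name": "https", "description": "http protocol over TLS/SSL"}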
def _readfile(fname):
"""
Reads the file with the alerts.
@param fname: the name of the file with the alerts
@return: the unparsed alerts
"""
with open(fname, 'r') as f:
unparsed_data = json.load(f)
        unparsed_data = unparsed_data[::-1]  # Reverse the alert order (input files are assumed to be newest-first)
return unparsed_data
# Step 1.1: Parse the input alerts
def _parse(unparsed_data):
"""
Parses the alerts and converts them into the specific format:
(diff_dt, src_ip, src_port, dst_ip, dst_port, sig, cat, host, dt, mcat).
@param unparsed_data: the unparsed alerts
@return: parsed alerts, sorted by the start time
"""
parsed_data = []
prev = -1
for d in unparsed_data:
if 'result' in d and '_raw' in d['result']:
raw = json.loads(d['result']['_raw'])
elif '_raw' in d:
raw = json.loads(d['_raw'])
else:
raw = d
if raw['event_type'] != 'alert':
continue
if 'host' in raw:
host = raw['host']
elif 'host' in d:
host = d['host'][3:]
else:
host = 'dummy'
dt = datetime.datetime.strptime(raw['timestamp'], '%Y-%m-%dT%H:%M:%S.%f%z') # 2018-11-03T23:16:09.148520+0000
diff_dt = 0.0 if prev == -1 else round((dt - prev).total_seconds(), 2)
prev = dt
sig = raw['alert']['signature']
cat = raw['alert']['category']
src_ip = raw['src_ip']
src_port = None if 'src_port' not in raw.keys() else raw['src_port']
dst_ip = raw['dest_ip']
dst_port = None if 'dest_port' not in raw.keys() else raw['dest_port']
# Filter out mistaken alerts / uninteresting alerts
if (dataset_name == 'cptc' and CPTC_BAD_IP in (src_ip, dst_ip)) or cat == 'Not Suspicious Traffic':
continue
mcat = _get_attack_stage_mapping(sig)
parsed_data.append((diff_dt, src_ip, src_port, dst_ip, dst_port, sig, cat, host, dt, mcat))
print('Reading # alerts: ', len(parsed_data))
parsed_data = sorted(parsed_data, key=lambda al: al[8]) # Sort alerts into ascending order
return parsed_data
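# Each parsed alert is a 10-tuple:
#   (diff_dt, src_ip, src_port, dst_ip, dst_port, sig, cat, host, dt, mcat)
# For example (all values fabricated for illustration):
#   (0.35, '10.0.254.30', 51234, '10.0.0.22', 80, 'ET SCAN Possible Nmap Scan',
#    'Attempted Information Leak', 'dummy', datetime.datetime(2018, 11, 3, 23, 16, 9, ...), <mcat code>)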
# Step 1.2: Remove duplicate alerts (defined by the alert_filtering_window parameter)
def _remove_duplicates(unfiltered_alerts, plot=False, gap=1.0):
"""
    Removes duplicate alerts, i.e. consecutive alerts with identical attributes that occur within `gap` seconds
    of each other, keeping only the first occurrence, as defined in the paper.
@param unfiltered_alerts: the parsed alerts that have not yet been filtered
@param plot: whether to plot the alert frequencies per Micro Attack Stage before and after filtering
@param gap: the filtering gap, i.e. alert filtering window (parameter `t` in the paper, default 1.0 sec)
@return: the alerts without duplicates
"""
filtered_alerts = [unfiltered_alerts[x] for x in range(1, len(unfiltered_alerts))
if unfiltered_alerts[x][9] != MicroAttackStage.NON_MALICIOUS.value # Skip non-malicious alerts
and not (unfiltered_alerts[x][0] <= gap # Diff from previous alert is less than gap sec
and unfiltered_alerts[x][1] == unfiltered_alerts[x - 1][1] # Same srcIP
and unfiltered_alerts[x][3] == unfiltered_alerts[x - 1][3] # Same destIP
                            and unfiltered_alerts[x][5] == unfiltered_alerts[x - 1][5]  # Same signature
and unfiltered_alerts[x][2] == unfiltered_alerts[x - 1][2] # Same srcPort
and unfiltered_alerts[x][4] == unfiltered_alerts[x - 1][4])] # Same destPort
if plot:
plot_alert_filtering(unfiltered_alerts, filtered_alerts)
print('Filtered # alerts (remaining):', len(filtered_alerts))
return filtered_alerts
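# Worked example of the filter (values fabricated): with gap=1.0, two alerts with identical
# src/dst IPs, ports and signature arriving 0.4 s apart are duplicates, so only the first one is kept;
# the same pair of alerts arriving 2.0 s apart is kept in full.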
# Step 1: Read the input alerts
def load_data(path_to_alerts, filtering_window, start, end):
"""
Reads the input alerts, parses them, removes duplicates, and groups them per attacker team.
@param path_to_alerts: the path to the directory with the alerts
@param filtering_window: filtering window (aka gap, aka t, default: 1.0)
@param start: the start hour to limit alerts based on the user preferences (default: 0)
@param end: the end hour to limit alerts based on the user preferences (default: 100)
@return: parsed and filtered alerts grouped by team, team labels and the first timestamp for each team
"""
_team_alerts = []
_team_labels = []
    _team_start_times = []  # Record each team's first alert timestamp to recover the real elapsed time (when the user filters with a (start, end) range)
files = glob.glob(path_to_alerts + "/*.json")
print('About to read json files...')
if len(files) < 1:
print('No alert files found.')
sys.exit()
for f in files:
name = os.path.basename(f)[:-5]
print(name)
_team_labels.append(name)
parsed_alerts = _parse(_readfile(f))
parsed_alerts = _remove_duplicates(parsed_alerts, gap=filtering_window)
        # Limiting alerts by time works better than limiting by volume, because each team operates on a different scale:
        # 50% of one team's alerts may end at a very different time than another team's.
end_time_limit = 3600 * end # Which hour to end at?
start_time_limit = 3600 * start # Which hour to start from?
first_ts = parsed_alerts[0][8]
_team_start_times.append(first_ts)
filtered_alerts = [x for x in parsed_alerts if (((x[8] - first_ts).total_seconds() <= end_time_limit)
and ((x[8] - first_ts).total_seconds() >= start_time_limit))]
_team_alerts.append(filtered_alerts)
return _team_alerts, _team_labels, _team_start_times
def group_alerts_per_team(alerts, port_mapping):
"""
Reorganises the alerts per team, for each attacker and victim pair.
@param alerts: the parsed and filtered alerts, grouped by team
@param port_mapping: the IANA port-service mapping
@return: alerts grouped by team and by (src_ip, dst_ip)
"""
_team_data = dict()
for tid, team in enumerate(alerts):
host_alerts = dict() # (attacker, victim) -> alerts
for alert in team:
# Alert format: (diff_dt, src_ip, src_port, dst_ip, dst_port, sig, cat, host, ts, mcat)
src_ip, dst_ip, signature, ts, mcat = alert[1], alert[3], alert[5], alert[8], alert[9]
            dst_port = alert[4] if alert[4] is not None else 65000  # 65000 is used as a placeholder for a missing destination port
            # Say 'unknown' if the port cannot be resolved
            if dst_port not in port_mapping.keys() or port_mapping[dst_port]['name'] == 'unknown':
dst_port = 'unknown'
else:
dst_port = port_mapping[dst_port]['name']
# For the CPTC dataset, attacker IPs (src_ip) start with '10.0.254', but this prefix might also be in dst_ip
# TODO: for the future, we might want to address internal paths
if dataset_name == 'cptc' and not src_ip.startswith('10.0.254') and not dst_ip.startswith('10.0.254'):
continue
# Swap src_ip and dst_ip, so that the prefix '10.0.254' is in src_ip
if dataset_name == 'cptc' and dst_ip.startswith('10.0.254'):
src_ip, dst_ip = dst_ip, src_ip
if (src_ip, dst_ip) not in host_alerts.keys() and (dst_ip, src_ip) not in host_alerts.keys():
host_alerts[(src_ip, dst_ip)] = []
if (src_ip, dst_ip) in host_alerts.keys(): # TODO: remove the redundant host names
host_alerts[(src_ip, dst_ip)].append((dst_ip, mcat, ts, dst_port, signature))
else:
host_alerts[(dst_ip, src_ip)].append((src_ip, mcat, ts, dst_port, signature))
_team_data[tid] = host_alerts.items()
return _team_data
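# The returned structure maps each team id to an items() view over (attacker_ip, victim_ip) pairs,
# e.g. (fabricated values):
#   _team_data[0] ~ dict_items([(('10.0.254.30', '10.0.0.22'),
#                                [('10.0.0.22', <mcat>, <timestamp>, 'http', 'ET SCAN Possible Nmap Scan'), ...])])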
# ----- MAIN ------
parser = argparse.ArgumentParser(description='SAGE: Intrusion Alert-Driven Attack Graph Extractor.')
parser.add_argument('path_to_json_files', type=str, help='Directory containing intrusion alerts in JSON format. sample-input.json provides an example of the accepted file format')
parser.add_argument('experiment_name', type=str, help='Custom name for all artefacts')
parser.add_argument('-t', type=float, required=False, default=1.0, help='Time window in which duplicate alerts are discarded (default: 1.0 sec)')
parser.add_argument('-w', type=int, required=False, default=150, help='Aggregate alerts occurring within this window into one episode (default: 150 sec)')
parser.add_argument('--timerange', type=int, nargs=2, required=False, default=[0, 100], metavar=('STARTRANGE', 'ENDRANGE'), help='Only parse alerts between the specified start and end hours, relative to the start of the alert capture (default: (0, 100))')
parser.add_argument('--dataset', required=False, type=str, choices=['cptc', 'other'], default='other', help='The name of the dataset with the alerts (default: other)')
parser.add_argument('--keep-files', action='store_true', help='Do not delete the dot files after the program ends')
args = parser.parse_args()
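# Example invocation (illustrative; the alert directory and experiment name are placeholders):
#   python sage.py alerts/ my-experiment -t 1.0 -w 150 --timerange 0 100 --dataset cptc --keep-files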
path_to_json_files = args.path_to_json_files
experiment_name = args.experiment_name
alert_filtering_window = args.t
alert_aggr_window = args.w
start_hour, end_hour = args.timerange
dataset_name = args.dataset
delete_files = not args.keep_files
path_to_ini = "FlexFringe/ini/spdfa-config.ini"
path_to_traces = experiment_name + '.txt'
ag_directory = experiment_name + 'AGs'
print('------ Downloading the IANA port-service mapping ------')
port_services = load_iana_mapping()
print('------ Reading alerts ------')
team_alerts, team_labels, team_start_times = load_data(path_to_json_files, alert_filtering_window, start_hour, end_hour)
plot_histogram(team_alerts, team_labels, experiment_name)
team_data = group_alerts_per_team(team_alerts, port_services)
print('------ Converting to episodes ------')
team_episodes, _ = aggregate_into_episodes(team_data, team_start_times, step=alert_aggr_window)
print('\n------ Converting to episode sequences ------')
host_data = host_episode_sequences(team_episodes)
print('------ Breaking into sub-sequences and generating traces ------')
episode_subsequences = break_into_subbehaviors(host_data)
episode_traces = generate_traces(episode_subsequences, path_to_traces)
print('------ Learning S-PDFA ------')
flexfringe(path_to_traces, ini=path_to_ini, symbol_count="2", state_count="4")
os.system("dot -Tpng " + path_to_traces + ".ff.final.dot -o " + path_to_traces + ".png")
print('------ !! Special: Fixing syntax error in main model and sink files ------')
print('--- Sinks')
with open(path_to_traces + ".ff.finalsinks.json", 'r') as file:
filedata = file.read()
stripped = re.sub(r'\s+', '', filedata)
extra_commas = re.search(r'(}(,+)]}$)', stripped)
if extra_commas is not None:
comma_count = (extra_commas.group(0)).count(',')
print(extra_commas.group(0), comma_count)
filedata = ''.join(filedata.rsplit(',', comma_count))
with open(path_to_traces + ".ff.finalsinks.json", 'w') as file:
file.write(filedata)
print('--- Main')
with open(path_to_traces + ".ff.final.json", 'r') as file:
filedata = file.read()
stripped = re.sub(r'\s+', '', filedata)
extra_commas = re.search(r'(}(,+)]}$)', stripped)
if extra_commas is not None:
comma_count = (extra_commas.group(0)).count(',')
print(extra_commas.group(0), comma_count)
filedata = ''.join(filedata.rsplit(',', comma_count))
with open(path_to_traces + ".ff.final.json", 'w') as file:
file.write(filedata)
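# Illustration of the fix above (fabricated snippet): a model file ending in '...},,]}' has two stray
# trailing commas; the pattern counts them and the same number of commas is stripped from the end of
# the file, yielding '...}]}' so that the JSON can be parsed.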
print('------ Loading and traversing S-PDFA ------')
main_model = load_model(path_to_traces + ".ff.final.json")
sinks_model = load_model(path_to_traces + ".ff.finalsinks.json")
print('------ Encoding traces into state sequences ------')
state_sequences, severe_sinks = encode_sequences(main_model, sinks_model, episode_subsequences)
# print('------ Clustering state groups ------')
# state_groups = plot_state_groups(state_sequences, path_to_traces)
print('------ Making alert-driven AGs ------')
make_attack_graphs(state_sequences, severe_sinks, path_to_traces, ag_directory, SAVE_AG)
if delete_files:
print('Deleting extra files')
os.system("rm " + path_to_traces + ".ff.final.dot")
os.system("rm " + path_to_traces + ".ff.final.json")
os.system("rm " + path_to_traces + ".ff.finalsinks.json")
os.system("rm " + path_to_traces + ".ff.finalsinks.dot")
os.system("rm " + path_to_traces + ".ff.init.dot")
os.system("rm " + path_to_traces + ".ff.init.json")
os.system("rm " + path_to_traces + ".ff.initsinks.dot")
os.system("rm " + path_to_traces + ".ff.initsinks.json")
# os.system("rm " + "spdfa-clustered-" + path_to_traces + "-dfa.dot") # Comment out if this file is created
os.system("rm " + ag_directory + "/*.dot")
print('\n------- FIN -------')
# ----- END MAIN ------