-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph_structure.py
More file actions
39 lines (31 loc) · 1.49 KB
/
graph_structure.py
File metadata and controls
39 lines (31 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch_geometric.data import HeteroData, Data
from feature_extraction import tokenize_text, encode_tokens
import torch
import torch_geometric.transforms as T
def generate_graph_from_text(text_sample, label, tokenizer, llm, context_window=3):
    """Turn a text sample into a PyG ``Data`` graph over its tokens.

    The text is tokenized with ``tokenize_text`` and embedded with
    ``encode_tokens``. The first and last token positions are discarded
    (presumably special tokens such as [CLS]/[SEP] -- TODO confirm); every
    remaining token becomes one node whose feature vector is its row of
    ``last_hidden_state``.

    Args:
        text_sample: Text passed to ``tokenize_text``.
        label: Value stored as a one-element ``LongTensor`` in ``y``.
        tokenizer: Forwarded to ``tokenize_text`` as ``tokenizer``.
        llm: Forwarded to ``encode_tokens`` as ``bert_model``.
        context_window: Each token is linked to the following
            ``context_window - 1`` tokens (fewer near the end of the text).

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``y`` and
        ``hyperedge_index``, after ``RemoveDuplicatedEdges`` and
        ``ToUndirected`` have been applied.
    """
    encoding = tokenize_text(text_sample, tokenizer=tokenizer)
    model_output = encode_tokens(encoding, bert_model=llm)

    # Only the first sequence of the batch is used, minus its two end positions.
    token_ids = encoding['input_ids'][0].tolist()[1:-1]
    features = model_output.last_hidden_state[0, 1:-1, :]
    num_tokens = len(token_ids)

    # Sliding-window edges: token -> each later token inside the window.
    sources, targets = [], []
    for src in range(num_tokens):
        for dst in range(src + 1, src + context_window):
            if dst >= num_tokens:
                # dst only grows, so nothing further can be in bounds.
                break
            sources.append(src)
            targets.append(dst)

    # One hyperedge per token id that occurs at least twice; it groups all
    # positions of that id. Row 0 holds node positions, row 1 hyperedge ids.
    member_nodes, hyperedge_ids = [], []
    id_tensor = torch.LongTensor(token_ids)
    distinct_ids, occurrences = id_tensor.unique(return_counts=True)
    repeated_ids = distinct_ids[occurrences >= 2].tolist()
    for hyperedge_id, token_id in enumerate(repeated_ids):
        # At least two matches are guaranteed, so squeeze() leaves a 1-D tensor.
        positions = torch.nonzero(id_tensor == token_id).squeeze().tolist()
        member_nodes.extend(positions)
        hyperedge_ids.extend([hyperedge_id] * len(positions))

    graph = Data(
        x=features,
        edge_index=torch.LongTensor([sources, targets]),
        y=torch.LongTensor([label]),
        hyperedge_index=torch.LongTensor([member_nodes, hyperedge_ids]),
    )

    # NOTE(review): both transforms receive the full Data object (x, y and
    # hyperedge_index included) -- confirm they leave non-edge attributes
    # untouched for every graph size.
    for transform in (T.RemoveDuplicatedEdges(), T.ToUndirected()):
        graph = transform(graph)
    return graph