Skip to content

Commit

Permalink
upload the code
Browse files Browse the repository at this point in the history
  • Loading branch information
YingtongDou committed Aug 21, 2020
1 parent 4c172a9 commit 63a8555
Show file tree
Hide file tree
Showing 13 changed files with 1,188 additions and 2 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

*.xml
.idea/$CACHE_FILE$
*.iml
.idea/dictionaries
*.mat
backup.py
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 81 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,81 @@
# CARE-GNN
Code CIKM 2020 paper Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters
# Nash-Detect

PyTorch implementation for CIKM 2020 paper **Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters**.
[Yingtong Dou](http://ytongdou.com/), [Zhiwei Liu](https://sites.google.com/view/zhiwei-jim), [Li Sun](https://www.researchgate.net/profile/Li_Sun118), Yutong Deng, [Hao Peng](https://penghao-buaa.github.io/)[Philip S. Yu](https://www.cs.uic.edu/PSYu/).
\[[Paper](https://arxiv.org/pdf/2008.08692.pdf)\]\[[Toolbox](https://github.com/safe-graph/DGFraud)\]

## Overview

<p align="center">
<br>
<a href="https://github.com/YingtongDou/CARE-GNN">
<img src="https://github.com/YingtongDou/Nash-Detect/blob/master/model.png" width="600"/>
</a>
<br>
<p>

**CA**mouflage-**RE**sistant **G**raph **N**eural **N**etwork **(CARE-GNN)** is an GNN-based fraud detector based on multi-relation graph equipped with three modules that enhance its performance against camouflaged fraudsters.

Three enhancement modules are:

- **A label-aware similarity measure** which measures the similarity scores between a center node and its neighboring nodes;
- **A similarity-aware neighbor selector** which leverages top-p sampling and reinforcement learning to select the optimal amount of neighbors under each relation;
- **A relation-aware neighbor aggregator** which directly aggeragate information from different relations using the optimal neighbor selection thresholds as weights.

CARE-GNN has following advantages:

- **Adaptability.** CARE-GNN adaptively selects best neighbors
for aggregation given arbitrary multi-relation graph;
- **High-efficiency.** CARE-GNN has a high computational efficiency without attention and deep reinforcement learning;
- **Flexibility.** Many other neural modules and external knowledge can be plugged into the CARE-GNN;

We have integrated more than **eight** GNN-based fraud detectors as a TensorFlow [toolbox](https://github.com/safe-graph/DGFraud)

## Setup

You can download the project and install required packages using following commands:

```bash
git clone https://github.com/YingtongDou/CARE-GNN.git
cd CARE-GNN
pip3 install -r requirements.txt
```

To run the code, you need to have at least **Python 3.6** or later version.

## Running

1. In CARE-GNN directory, run `unzip /data/Amazon.zip` and `unzip /data/YelpChi.zip` to unzip the datasets;
2. Run `python data_process.py` to generate adjacency lists used by CARE-GNN;
3. Run `python -m train.py` to run CARE-GNN with default settings.

For the other dataset and parameter settings, please refer to the argparser in `train.py`. Our model supports both CPU and GPU mode

## Running on your datasets

To run CARE-GNN on your datasets, you need to prepare following data:

- Multiple-single relation graphs with same nodes where each graph is stored in `scipy.sparse` matrix format, you can use `sparse_to_adjlist()` in `utils.py` to tranfer the sparse matrix into adjacency lists used by CARE-GNN;
- A numpy array with node labels, currently, CARE-GNN only supoorts binary classification;
- A node feature matrix stored in `scipy.sparse` matrix format.

### Repo Structure
The repository is organized as follows:
- `data/`: dataset files;
- `data_process.py`: transfer sparse matrix to adjacency lists;
- `graphsage.py`: model code for vanilla [GraphSAGE](https://github.com/williamleif/graphsage-simple/) model;
- `layers.py`: CARE-GNN layers implementations;
- `model.py`: CARE-GNN model implementations;
- `train.py`: training and testing all models;
- `utils.py`: utility functions for data i\o and model evaluation.

## Citation
If you use our code, please cite the paper below:
```bibtex
@inproceedings{dou2020Enhancing,
title={Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters},
author={Dou, Yingtong and Liu, Zhiwei and Sun, Li and Deng, Yutong and Peng, Hao and Yu, Philip S},
booktitle={Proceedings of the 29th ACM International Conference on Information and Knowledge Management (CIKM'20)},
year={2020}
}
```
Binary file added data/Amazon.zip
Binary file not shown.
Binary file added data/YelpChi.zip
Binary file not shown.
33 changes: 33 additions & 0 deletions data_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from utils import sparse_to_adjlist
from scipy.io import loadmat

"""
Read data and save the adjacency matrices to adjacency lists
"""


if __name__ == "__main__":

prefix = 'data/'

yelp = loadmat('data/YelpChi.mat')
net_rur = yelp['net_rur']
net_rtr = yelp['net_rtr']
net_rsr = yelp['net_rsr']
yelp_homo = yelp['homo']

sparse_to_adjlist(net_rur, prefix + 'yelp_rur_adjlists.pickle')
sparse_to_adjlist(net_rtr, prefix + 'yelp_rtr_adjlists.pickle')
sparse_to_adjlist(net_rsr, prefix + 'yelp_rsr_adjlists.pickle')
sparse_to_adjlist(yelp_homo, prefix + 'yelp_homo_adjlists.pickle')

amz = loadmat('data/Amazon.mat')
net_upu = amz['net_upu']
net_usu = amz['net_usu']
net_uvu = amz['net_uvu']
amz_homo = amz['homo']

sparse_to_adjlist(net_upu, prefix + 'amz_upu_adjlists.pickle')
sparse_to_adjlist(net_usu, prefix + 'amz_usu_adjlists.pickle')
sparse_to_adjlist(net_uvu, prefix + 'amz_uvu_adjlists.pickle')
sparse_to_adjlist(amz_homo, prefix + 'amz_homo_adjlists.pickle')
150 changes: 150 additions & 0 deletions graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable
import random


"""
GraphSAGE implementations
Paper: Inductive Representation Learning on Large Graphs
Source: https://github.com/williamleif/graphsage-simple/
"""


class GraphSage(nn.Module):
"""
Vanilla GraphSAGE Model
Code partially from https://github.com/williamleif/graphsage-simple/
"""
def __init__(self, num_classes, enc):
super(GraphSage, self).__init__()
self.enc = enc
self.xent = nn.CrossEntropyLoss()
self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
init.xavier_uniform_(self.weight)

def forward(self, nodes):
embeds = self.enc(nodes)
scores = self.weight.mm(embeds)
return scores.t()

def to_prob(self, nodes):
pos_scores = torch.sigmoid(self.forward(nodes))
return pos_scores

def loss(self, nodes, labels):
scores = self.forward(nodes)
return self.xent(scores, labels.squeeze())


class MeanAggregator(nn.Module):
"""
Aggregates a node's embeddings using mean of neighbors' embeddings
"""

def __init__(self, features, cuda=False, gcn=False):
"""
Initializes the aggregator for a specific graph.
features -- function mapping LongTensor of node ids to FloatTensor of feature values.
cuda -- whether to use GPU
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
"""

super(MeanAggregator, self).__init__()

self.features = features
self.cuda = cuda
self.gcn = gcn

def forward(self, nodes, to_neighs, num_sample=10):
"""
nodes --- list of nodes in a batch
to_neighs --- list of sets, each set is the set of neighbors for node in batch
num_sample --- number of neighbors to sample. No sampling if None.
"""
# Local pointers to functions (speed hack)
_set = set
if not num_sample is None:
_sample = random.sample
samp_neighs = [_set(_sample(to_neigh,
num_sample,
)) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
else:
samp_neighs = to_neighs

if self.gcn:
samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in enumerate(samp_neighs)]
unique_nodes_list = list(set.union(*samp_neighs))
unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
mask[row_indices, column_indices] = 1
if self.cuda:
mask = mask.cuda()
num_neigh = mask.sum(1, keepdim=True)
mask = mask.div(num_neigh)
if self.cuda:
embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
else:
embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
to_feats = mask.mm(embed_matrix)
return to_feats


class Encoder(nn.Module):
"""
Vanilla GraphSAGE Encoder Module
Encodes a node's using 'convolutional' GraphSage approach
"""

def __init__(self, features, feature_dim,
embed_dim, adj_lists, aggregator,
num_sample=10,
base_model=None, gcn=False, cuda=False,
feature_transform=False):
super(Encoder, self).__init__()

self.features = features
self.feat_dim = feature_dim
self.adj_lists = adj_lists
self.aggregator = aggregator
self.num_sample = num_sample
if base_model != None:
self.base_model = base_model

self.gcn = gcn
self.embed_dim = embed_dim
self.cuda = cuda
self.aggregator.cuda = cuda
self.weight = nn.Parameter(
torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim))
init.xavier_uniform_(self.weight)

def forward(self, nodes):
"""
Generates embeddings for a batch of nodes.
nodes -- list of nodes
"""
neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes],
self.num_sample)

if isinstance(nodes, list):
index = torch.LongTensor(nodes).cuda()
else:
index = nodes

if not self.gcn:
if self.cuda:
self_feats = self.features(index)
else:
self_feats = self.features(index)
combined = torch.cat((self_feats, neigh_feats), dim=1)
else:
combined = neigh_feats
combined = F.relu(self.weight.mm(combined.t()))
return combined
Loading

0 comments on commit 63a8555

Please sign in to comment.