Skip to content

Commit be7835d

Browse files
committed
solve save model bug and add gcn
1 parent 391a39a commit be7835d

File tree

13 files changed

+754
-3
lines changed

13 files changed

+754
-3
lines changed

.gitignore

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,120 @@
1-
21
data
32
.idea
43
tmp
4+
5+
# Byte-compiled / optimized / DLL files
6+
__pycache__/
7+
*.py[cod]
8+
*$py.class
9+
10+
# C extensions
11+
*.so
12+
13+
# Distribution / packaging
14+
.Python
15+
build/
16+
develop-eggs/
17+
dist/
18+
downloads/
19+
eggs/
20+
.eggs/
21+
lib/
22+
lib64/
23+
parts/
24+
sdist/
25+
var/
26+
wheels/
27+
pip-wheel-metadata/
28+
share/python-wheels/
29+
*.egg-info/
30+
.installed.cfg
31+
*.egg
32+
MANIFEST
33+
34+
# PyInstaller
35+
# Usually these files are written by a python script from a template
36+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
37+
*.manifest
38+
*.spec
39+
40+
# Installer logs
41+
pip-log.txt
42+
pip-delete-this-directory.txt
43+
44+
# Unit test / coverage reports
45+
htmlcov/
46+
.tox/
47+
.nox/
48+
.coverage
49+
.coverage.*
50+
.cache
51+
nosetests.xml
52+
coverage.xml
53+
*.cover
54+
.hypothesis/
55+
.pytest_cache/
56+
57+
# Translations
58+
*.mo
59+
*.pot
60+
61+
# Django stuff:
62+
*.log
63+
local_settings.py
64+
db.sqlite3
65+
66+
# Flask stuff:
67+
instance/
68+
.webassets-cache
69+
70+
# Scrapy stuff:
71+
.scrapy
72+
73+
# Sphinx documentation
74+
docs/_build/
75+
76+
# PyBuilder
77+
target/
78+
79+
# Jupyter Notebook
80+
.ipynb_checkpoints
81+
82+
# IPython
83+
profile_default/
84+
ipython_config.py
85+
86+
# pyenv
87+
.python-version
88+
89+
# celery beat schedule file
90+
celerybeat-schedule
91+
92+
# SageMath parsed files
93+
*.sage.py
94+
95+
# Environments
96+
.env
97+
.venv
98+
env/
99+
venv/
100+
ENV/
101+
env.bak/
102+
venv.bak/
103+
104+
# Spyder project settings
105+
.spyderproject
106+
.spyproject
107+
108+
# Rope project settings
109+
.ropeproject
110+
111+
# mkdocs documentation
112+
/site
113+
114+
# mypy
115+
.mypy_cache/
116+
.dmypy.json
117+
dmypy.json
118+
119+
# Pyre type checker
120+
.pyre/

lesson28-GCN/README.MD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Graph Convolution Network for TF2
2+
3+
GCN implementation for paper: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907.pdf)
4+
5+
# Benchmark
6+
7+
| dataset | Citeseea | Cora | Pubmed | NELL |
8+
|---------------|----------|------|--------|------|
9+
| GCN(official) | 70.3 | 81.5 | 79.0 | 66.0 |
10+
| This repo. | | 81.8 | 78.9 | |
11+
12+
# HOWTO
13+
```
14+
python train.py
15+
```
16+
17+
# Screenshot
18+
19+
![](res/screen.png)

lesson28-GCN/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import argparse
2+
3+
args = argparse.ArgumentParser()
4+
args.add_argument('--dataset', default='cora')
5+
args.add_argument('--model', default='gcn')
6+
args.add_argument('--learning_rate', default=0.01)
7+
args.add_argument('--epochs', default=200)
8+
args.add_argument('--hidden1', default=16)
9+
args.add_argument('--dropout', default=0.5)
10+
args.add_argument('--weight_decay', default=5e-4)
11+
args.add_argument('--early_stopping', default=10)
12+
args.add_argument('--max_degree', default=3)
13+
14+
15+
args = args.parse_args()
16+
print(args)

lesson28-GCN/inits.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
4+
5+
def uniform(shape, scale=0.05, name=None):
6+
"""Uniform init."""
7+
initial = tf.random.uniform(shape, minval=-scale, maxval=scale, dtype=tf.float32)
8+
return tf.Variable(initial, name=name)
9+
10+
11+
def glorot(shape, name=None):
12+
"""Glorot & Bengio (AISTATS 2010) init."""
13+
init_range = np.sqrt(6.0/(shape[0]+shape[1]))
14+
initial = tf.random.uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32)
15+
return tf.Variable(initial, name=name)
16+
17+
18+
def zeros(shape, name=None):
19+
"""All zeros."""
20+
initial = tf.zeros(shape, dtype=tf.float32)
21+
return tf.Variable(initial, name=name)
22+
23+
24+
def ones(shape, name=None):
25+
"""All ones."""
26+
initial = tf.ones(shape, dtype=tf.float32)
27+
return tf.Variable(initial, name=name)

lesson28-GCN/layers.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from inits import *
2+
import tensorflow as tf
3+
from tensorflow import keras
4+
from tensorflow.keras import layers
5+
from config import args
6+
7+
8+
9+
10+
# global unique layer ID dictionary for layer name assignment
11+
_LAYER_UIDS = {}
12+
13+
14+
def get_layer_uid(layer_name=''):
15+
"""Helper function, assigns unique layer IDs."""
16+
if layer_name not in _LAYER_UIDS:
17+
_LAYER_UIDS[layer_name] = 1
18+
return 1
19+
else:
20+
_LAYER_UIDS[layer_name] += 1
21+
return _LAYER_UIDS[layer_name]
22+
23+
24+
def sparse_dropout(x, rate, noise_shape):
25+
"""
26+
Dropout for sparse tensors.
27+
"""
28+
random_tensor = 1 - rate
29+
random_tensor += tf.random.uniform(noise_shape)
30+
dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
31+
pre_out = tf.sparse.retain(x, dropout_mask)
32+
return pre_out * (1./(1 - rate))
33+
34+
35+
def dot(x, y, sparse=False):
36+
"""
37+
Wrapper for tf.matmul (sparse vs dense).
38+
"""
39+
if sparse:
40+
res = tf.sparse.sparse_dense_matmul(x, y)
41+
else:
42+
res = tf.matmul(x, y)
43+
return res
44+
45+
46+
47+
48+
class Dense(layers.Layer):
49+
"""Dense layer."""
50+
def __init__(self, input_dim, output_dim, placeholders, dropout=0., sparse_inputs=False,
51+
act=tf.nn.relu, bias=False, featureless=False, **kwargs):
52+
super(Dense, self).__init__(**kwargs)
53+
54+
if dropout:
55+
self.dropout = placeholders['dropout']
56+
else:
57+
self.dropout = 0.
58+
59+
self.act = act
60+
self.sparse_inputs = sparse_inputs
61+
self.featureless = featureless
62+
self.bias = bias
63+
64+
# helper variable for sparse dropout
65+
self.num_features_nonzero = placeholders['num_features_nonzero']
66+
67+
with tf.variable_scope(self.name + '_vars'):
68+
self.vars['weights'] = glorot([input_dim, output_dim],
69+
name='weights')
70+
if self.bias:
71+
self.vars['bias'] = zeros([output_dim], name='bias')
72+
73+
if self.logging:
74+
self._log_vars()
75+
76+
def _call(self, inputs):
77+
x = inputs
78+
79+
# dropout
80+
if self.sparse_inputs:
81+
x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero)
82+
else:
83+
x = tf.nn.dropout(x, 1-self.dropout)
84+
85+
# transform
86+
output = dot(x, self.vars['weights'], sparse=self.sparse_inputs)
87+
88+
# bias
89+
if self.bias:
90+
output += self.vars['bias']
91+
92+
return self.act(output)
93+
94+
95+
class GraphConvolution(layers.Layer):
96+
"""
97+
Graph convolution layer.
98+
"""
99+
def __init__(self, input_dim, output_dim, num_features_nonzero,
100+
dropout=0.,
101+
is_sparse_inputs=False,
102+
activation=tf.nn.relu,
103+
bias=False,
104+
featureless=False, **kwargs):
105+
super(GraphConvolution, self).__init__(**kwargs)
106+
107+
self.dropout = dropout
108+
self.activation = activation
109+
self.is_sparse_inputs = is_sparse_inputs
110+
self.featureless = featureless
111+
self.bias = bias
112+
self.num_features_nonzero = num_features_nonzero
113+
114+
self.weights_ = []
115+
for i in range(1):
116+
w = self.add_variable('weight' + str(i), [input_dim, output_dim])
117+
self.weights_.append(w)
118+
if self.bias:
119+
self.bias = self.add_variable('bias', [output_dim])
120+
121+
122+
# for p in self.trainable_variables:
123+
# print(p.name, p.shape)
124+
125+
126+
127+
def call(self, inputs, training=None):
128+
x, support_ = inputs
129+
130+
# dropout
131+
if training is not False and self.is_sparse_inputs:
132+
x = sparse_dropout(x, self.dropout, self.num_features_nonzero)
133+
elif training is not False:
134+
x = tf.nn.dropout(x, self.dropout)
135+
136+
137+
# convolve
138+
supports = list()
139+
for i in range(len(support_)):
140+
if not self.featureless: # if it has features x
141+
pre_sup = dot(x, self.weights_[i], sparse=self.is_sparse_inputs)
142+
else:
143+
pre_sup = self.weights_[i]
144+
145+
support = dot(support_[i], pre_sup, sparse=True)
146+
supports.append(support)
147+
148+
output = tf.add_n(supports)
149+
150+
# bias
151+
if self.bias:
152+
output += self.bias
153+
154+
return self.activation(output)

lesson28-GCN/metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import tensorflow as tf
2+
3+
4+
def masked_softmax_cross_entropy(preds, labels, mask):
5+
"""
6+
Softmax cross-entropy loss with masking.
7+
"""
8+
loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
9+
mask = tf.cast(mask, dtype=tf.float32)
10+
mask /= tf.reduce_mean(mask)
11+
loss *= mask
12+
return tf.reduce_mean(loss)
13+
14+
15+
def masked_accuracy(preds, labels, mask):
16+
"""
17+
Accuracy with masking.
18+
"""
19+
correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1))
20+
accuracy_all = tf.cast(correct_prediction, tf.float32)
21+
mask = tf.cast(mask, dtype=tf.float32)
22+
mask /= tf.reduce_mean(mask)
23+
accuracy_all *= mask
24+
return tf.reduce_mean(accuracy_all)

0 commit comments

Comments
 (0)