Skip to content

Commit 87ad5a9

Browse files
committed
Merge captum into DashAI app
1 parent c968019 commit 87ad5a9

13 files changed

+153
-61
lines changed

captum/insights/attr_vis/app.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def serve(self, blocking=False, debug=False, port=None):
286286
return self._serve(blocking=blocking, debug=debug, port=port)
287287

288288
def _serve(self, blocking=False, debug=False, port=None):
289-
from .server import start_server
289+
# from .server import start_server
290+
from worker.app.app import start_server
290291

291292
return start_server(self, blocking=blocking, debug=debug, _port=port)
292293

Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
{
22
"files": {
33
"main.css": "./static/css/main.6b86a595.chunk.css",
4-
"main.js": "./static/js/main.f2e6456a.chunk.js",
5-
"main.js.map": "./static/js/main.f2e6456a.chunk.js.map",
4+
"main.js": "./static/js/main.22f2a356.chunk.js",
5+
"main.js.map": "./static/js/main.22f2a356.chunk.js.map",
66
"runtime-main.js": "./static/js/runtime-main.851de3a6.js",
77
"runtime-main.js.map": "./static/js/runtime-main.851de3a6.js.map",
88
"static/js/2.a727744b.chunk.js": "./static/js/2.a727744b.chunk.js",
99
"static/js/2.a727744b.chunk.js.map": "./static/js/2.a727744b.chunk.js.map",
1010
"index.html": "./index.html",
11-
"precache-manifest.6ad951b2d254443991f9c1cef999746b.js": "./precache-manifest.6ad951b2d254443991f9c1cef999746b.js",
11+
"precache-manifest.ef3d74cdb10886749b2b2b0204b52ee2.js": "./precache-manifest.ef3d74cdb10886749b2b2b0204b52ee2.js",
1212
"service-worker.js": "./service-worker.js",
1313
"static/css/main.6b86a595.chunk.css.map": "./static/css/main.6b86a595.chunk.css.map",
1414
"static/js/2.a727744b.chunk.js.LICENSE.txt": "./static/js/2.a727744b.chunk.js.LICENSE.txt"
@@ -17,6 +17,6 @@
1717
"static/js/runtime-main.851de3a6.js",
1818
"static/js/2.a727744b.chunk.js",
1919
"static/css/main.6b86a595.chunk.css",
20-
"static/js/main.f2e6456a.chunk.js"
20+
"static/js/main.22f2a356.chunk.js"
2121
]
2222
}
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/><title>Captum Insights</title><link href="./static/css/main.6b86a595.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,f,l=r[0],i=r[1],a=r[2],c=0,s=[];c<l.length;c++)f=l[c],Object.prototype.hasOwnProperty.call(o,f)&&o[f]&&s.push(o[f][0]),o[f]=0;for(n in i)Object.prototype.hasOwnProperty.call(i,n)&&(e[n]=i[n]);for(p&&p(r);s.length;)s.shift()();return u.push.apply(u,a||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,l=1;l<t.length;l++){var i=t[l];0!==o[i]&&(n=!1)}n&&(u.splice(r--,1),e=f(f.s=t[0]))}return e}var n={},o={1:0},u=[];function f(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,f),t.l=!0,t.exports}f.m=e,f.c=n,f.d=function(e,r,t){f.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},f.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},f.t=function(e,r){if(1&r&&(e=f(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(f.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)f.d(t,n,function(r){return e[r]}.bind(null,n));return t},f.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return f.d(r,"a",r),r},f.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},f.p="./";var l=this.webpackJsonpfrontend=this.webpackJsonpfrontend||[],i=l.push.bind(l);l.push=r,l=l.slice();for(var a=0;a<l.length;a++)r(l[a]);var p=i;t()}([])</script><script src="./static/js/2.a727744b.chunk.js"></script><script src="./static/js/main.f2e6456a.chunk.js"></script></body></html>
1+
<!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/><title>DashAI Insights</title><link href="./static/css/main.6b86a595.chunk.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div><script>!function(e){function r(r){for(var n,f,l=r[0],i=r[1],a=r[2],c=0,s=[];c<l.length;c++)f=l[c],Object.prototype.hasOwnProperty.call(o,f)&&o[f]&&s.push(o[f][0]),o[f]=0;for(n in i)Object.prototype.hasOwnProperty.call(i,n)&&(e[n]=i[n]);for(p&&p(r);s.length;)s.shift()();return u.push.apply(u,a||[]),t()}function t(){for(var e,r=0;r<u.length;r++){for(var t=u[r],n=!0,l=1;l<t.length;l++){var i=t[l];0!==o[i]&&(n=!1)}n&&(u.splice(r--,1),e=f(f.s=t[0]))}return e}var n={},o={1:0},u=[];function f(r){if(n[r])return n[r].exports;var t=n[r]={i:r,l:!1,exports:{}};return e[r].call(t.exports,t,t.exports,f),t.l=!0,t.exports}f.m=e,f.c=n,f.d=function(e,r,t){f.o(e,r)||Object.defineProperty(e,r,{enumerable:!0,get:t})},f.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},f.t=function(e,r){if(1&r&&(e=f(e)),8&r)return e;if(4&r&&"object"==typeof e&&e&&e.__esModule)return e;var t=Object.create(null);if(f.r(t),Object.defineProperty(t,"default",{enumerable:!0,value:e}),2&r&&"string"!=typeof e)for(var n in e)f.d(t,n,function(r){return e[r]}.bind(null,n));return t},f.n=function(e){var r=e&&e.__esModule?function(){return e.default}:function(){return e};return f.d(r,"a",r),r},f.o=function(e,r){return Object.prototype.hasOwnProperty.call(e,r)},f.p="./";var l=this.webpackJsonpfrontend=this.webpackJsonpfrontend||[],i=l.push.bind(l);l.push=r,l=l.slice();for(var a=0;a<l.length;a++)r(l[a]);var p=i;t()}([])</script><script src="./static/js/2.a727744b.chunk.js"></script><script src="./static/js/main.22f2a356.chunk.js"></script></body></html>

captum/insights/attr_vis/frontend/build/precache-manifest.6ad951b2d254443991f9c1cef999746b.js

-26
This file was deleted.

captum/insights/attr_vis/frontend/build/service-worker.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
importScripts("https://storage.googleapis.com/workbox-cdn/releases/4.3.1/workbox-sw.js");
1515

1616
importScripts(
17-
"./precache-manifest.6ad951b2d254443991f9c1cef999746b.js"
17+
"./precache-manifest.ef3d74cdb10886749b2b2b0204b52ee2.js"
1818
);
1919

2020
self.addEventListener('message', (event) => {

captum/insights/attr_vis/frontend/build/static/js/main.f2e6456a.chunk.js

-2
This file was deleted.

captum/insights/attr_vis/frontend/build/static/js/main.f2e6456a.chunk.js.map

-1
This file was deleted.

captum/insights/attr_vis/server.py

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def start_server(
9999
t = threading.Thread(target=run_app, kwargs={"debug": debug})
100100
t.start()
101101
sleep(0.01) # add a short delay to allow server to start up
102+
print(t, blocking, sep='\n')
102103
if blocking:
103104
t.join()
104105

data/response.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@
366366
"lin_ftrs": [
367367
512
368368
],
369-
"ps": 0.7,
369+
"ps": 0.5,
370370
"custom_head": null,
371371
"split_on": null,
372372
"bn_final": false,

data/train.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{
22
"training": {
3-
"type": "fit",
3+
"type": "fit_one_cycle",
44
"fit": {
55
"epochs": 0,
66
"lr": "slice(None, 0.003, None)",
77
"wd": null
88
},
99
"fit_one_cycle": {
10-
"cyc_len": 1,
10+
"cyc_len": 0,
1111
"max_lr": "slice(None, 0.003, None)",
1212
"moms": "(0.95, 0.85)",
1313
"div_factor": 25,

data/verum.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"metric": {
44
"name": "error",
55
"minimize": true,
6-
"num_trials": 1
6+
"num_trials": null
77
},
88
"learning_rate": {
99
"flag": true,

worker/__init__.py

Whitespace-only changes.

worker/app/app.py

+140-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from flask import Flask, render_template, request, jsonify
22
import json
33
import torch
4+
from torch import Tensor
45
import fastai
56
import os
67
from pathlib import Path
@@ -18,9 +19,19 @@
1819
from core.train import DashTrain
1920
from insights.DashInsights import DashInsights
2021
from captum.insights.attr_vis import AttributionVisualizer
21-
from flask_socketio import SocketIO, emit
22+
# from flask_socketio import SocketIO, emit
2223
from flask_cors import CORS
2324

25+
import logging
26+
import socket
27+
import threading
28+
from time import sleep
29+
from typing import Optional
30+
from captum.log import log_usage
31+
32+
visualizer = None
33+
port = None
34+
2435
# import sys
2536
# sys.stdout = open('result.txt', 'w')
2637

@@ -29,7 +40,10 @@
2940
# with open('result.txt', 'w+') as f:
3041
# f.write("worked here")
3142

32-
app = Flask(__name__, template_folder="template")
43+
app = Flask(
44+
__name__, static_folder="../../captum/insights/attr_vis/frontend/build/static",
45+
template_folder="../../captum/insights/attr_vis/frontend/build"
46+
)
3347
app.config['SECRET_KEY'] = 'some-super-secret-key'
3448
app.config['DEFAULT_PARSERS'] = [
3549
'flask.ext.api.parsers.JSONParser',
@@ -62,7 +76,7 @@
6276
@app.route("/", methods=['GET'])
6377
def helper():
6478
print("Started")
65-
return render_template("helper.html")
79+
return render_template("index.html")
6680

6781
@app.route('/gethome', methods=['GET'])
6882
def gethome():
@@ -117,19 +131,19 @@ def train():
117131
json.dump(data, outfile, indent=4)
118132

119133
print('STEP 2 (optional): Optimizing the hyper-parameters.')
120-
try:
121-
import ax
122-
print(ax.__version__)
123-
from verum.DashVerum import DashVerum
124-
step_2 = True
125-
with open('./data/verum.json') as f:
126-
response = json.load(f)
127-
verum = DashVerum(response, data, learn)
128-
learn, lr, num_epochs, moms = verum.veritize()
129-
print('Hyper-parameters optimized; completed step 2.')
134+
# try:
135+
# import ax
136+
# print(ax.__version__)
137+
# from verum.DashVerum import DashVerum
138+
# step_2 = True
139+
# with open('./data/verum.json') as f:
140+
# response = json.load(f)
141+
# verum = DashVerum(response, data, learn)
142+
# learn, lr, num_epochs, moms = verum.veritize()
143+
# print('Hyper-parameters optimized; completed step 2.')
130144

131-
except ImportError:
132-
print('Skipping step 2 as the module `ax` is not installed.')
145+
# except ImportError:
146+
# print('Skipping step 2 as the module `ax` is not installed.')
133147

134148

135149

@@ -148,10 +162,11 @@ def start():
148162
}
149163

150164
global all_processes
151-
152-
process = multiprocessing.Process(target=training_worker, args=())
153-
process.start()
154-
all_processes.append(process)
165+
training_worker()
166+
# process = multiprocessing.Process(target=training_worker, args=())
167+
# process.start()
168+
# process.join()
169+
# all_processes.append(process)
155170
return jsonify(res)
156171

157172
@app.route("/stop", methods=['GET'])
@@ -222,13 +237,117 @@ def training_worker():
222237
print('-' * 50)
223238
# print('Now we need to add production-serving.')
224239
print('COMPLETE')
240+
global all_processes
241+
print(all_processes)
225242

226243

227244
# @socketio.on('connect')
228245
# def talk_to_me():
229246
# print('after connect', {'data':'Lets dance'})
230247

248+
#socketio.run(app, port=5001, debug=True)
249+
250+
251+
# ----------------------------------------
252+
253+
#!/usr/bin/env python3
254+
255+
256+
# from flask import Flask, jsonify, render_template, request
257+
# from torch import Tensor
258+
259+
260+
261+
def namedtuple_to_dict(obj):
262+
if isinstance(obj, Tensor):
263+
return obj.item()
264+
if hasattr(obj, "_asdict"): # detect namedtuple
265+
return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj)))
266+
elif isinstance(obj, str): # iterables - strings
267+
return obj
268+
elif hasattr(obj, "keys"): # iterables - mapping
269+
return dict(
270+
zip(obj.keys(), (namedtuple_to_dict(item) for item in obj.values()))
271+
)
272+
elif hasattr(obj, "__iter__"): # iterables - sequence
273+
return type(obj)((namedtuple_to_dict(item) for item in obj))
274+
else: # non-iterable cannot contain namedtuples
275+
return obj
276+
277+
278+
@app.route("/attribute", methods=["POST"])
279+
def attribute():
280+
# force=True needed for Colab notebooks, which doesn't use the correct
281+
# Content-Type header when forwarding requests through the Colab proxy
282+
r = request.get_json(force=True)
283+
return jsonify(
284+
namedtuple_to_dict(
285+
visualizer._calculate_attribution_from_cache(r["instance"], r["labelIndex"])
286+
)
287+
)
288+
289+
290+
@app.route("/fetch", methods=["POST"])
291+
def fetch():
292+
# force=True needed, see comment for "/attribute" route above
293+
global visualizer
294+
visualizer._update_config(request.get_json(force=True))
295+
visualizer_output = visualizer.visualize()
296+
clean_output = namedtuple_to_dict(visualizer_output)
297+
return jsonify(clean_output)
298+
299+
300+
@app.route("/init")
301+
def init():
302+
visualizer
303+
return jsonify(visualizer.get_insights_config())
304+
305+
306+
@app.route("/")
307+
def index(id=0):
308+
return render_template("index.html")
309+
310+
311+
def get_free_tcp_port():
312+
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
313+
tcp.bind(("", 0))
314+
addr, port = tcp.getsockname()
315+
tcp.close()
316+
return port
317+
318+
319+
def run_app(debug: bool = True):
320+
app.run(port=port, use_reloader=False, debug=debug)
321+
322+
323+
@log_usage()
324+
def start_server(
325+
_viz, blocking: bool = False, debug: bool = False, _port: Optional[int] = None
326+
):
327+
global visualizer
328+
visualizer = _viz
329+
330+
global port
331+
if port is None:
332+
os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message
333+
if not debug:
334+
log = logging.getLogger("werkzeug")
335+
log.disabled = True
336+
app.logger.disabled = True
337+
338+
# port = _port or get_free_tcp_port()
339+
port = 5003
340+
# Start in a new thread to not block notebook execution
341+
t = threading.Thread(target=run_app, kwargs={"debug": debug})
342+
t.start()
343+
sleep(0.01) # add a short delay to allow server to start up
344+
print(t, blocking, sep='\n')
345+
if blocking:
346+
t.join()
347+
348+
print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n")
349+
return port
350+
231351

232352
if __name__ == "__main__":
233-
app.run(debug=True, port=5001)
234-
#socketio.run(app, port=5001, debug=True)
353+
app.run(debug=True, port=5001)

0 commit comments

Comments
 (0)