-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbar_plot.py
158 lines (124 loc) · 5.49 KB
/
bar_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/python3
# -*- coding: utf-8 -*-
__author__ = "kirintw and Billy Su"
__license__ = "GPL-2.0"
import time
import pyqtgraph as pg
from PyQt5 import QtCore, QtGui, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colorbar import ColorbarBase
from mne.viz import topomap
from channel_loca_dict import channel_dict_2D
from ws_pb import WS_PB
import time
from dialogs import Big_Bar_Plot
import gc
class Bar_Plot(QtWidgets.QWidget):
    """Widget showing per-channel band-power bar charts laid out on a scalp map.

    Data comes from a ``WS_PB`` websocket client stored in ``self.ws_data``.
    Judging by how ``draw`` consumes it, that client exposes parallel lists
    ``ticks``, ``power_data``, ``z_all_data`` and ``z_each_data`` that are
    consumed front-first, plus ``ch_label`` (the streamed channel names).
    NOTE(review): that contract lives in ws_pb -- confirm there.

    Figure layout (built in ``init_ui``):
        axes[0]     -- full-figure axes holding the head outline
        axes[1..N]  -- one small bar chart per channel in ``ch_label``
        axes[-1]    -- colour-bar legend for the frequency bands

    Clicking a channel's small chart opens a ``Big_Bar_Plot`` for it
    (at most one per channel).
    """

    def __init__(self, url=None):
        """Connect to the data server, build the UI and start the redraw timer.

        Args:
            url: websocket URL of the data server; defaults to
                ``"ws://localhost:7777"``.
        """
        super().__init__()
        if url is None:
            url = "ws://localhost:7777"
        self.ws_data = WS_PB(url=url, plot_name="PB_bar")
        self.setWindowTitle("Bar Plot")
        self.timer_interval = 0.5  # redraw period, in seconds
        # Wait until the client has received the channel labels.  Sleep between
        # polls instead of spinning a CPU core at 100 %.
        # NOTE(review): assumes WS_PB fills ch_label from another thread.
        while self.ws_data.ch_label is None:
            time.sleep(0.01)
        self.num_color = 8  # number of frequency bands / bar colours
        # One optional enlarged-plot window per channel (None = not open).
        self.big_plots = [None] * len(self.ws_data.ch_label)
        self.init_tick = None  # first tick seen; playback starts here
        self.cur_tick = None   # current playback position
        self.init_ui()
        self.setup_signal_handler()
        self.show()
        self.value = 0

    def init_ui(self):
        """Create the matplotlib figure, place the per-channel axes and embed it."""
        # plt.get_cmap works across matplotlib versions; mpl.cm.get_cmap was
        # deprecated in 3.7 and removed in 3.9.
        self.cmap = plt.get_cmap('Dark2')
        self.norm = mpl.colors.Normalize(vmin=0, vmax=self.num_color)
        # Bar colours never change, so compute them once rather than per redraw.
        self.colors = [self.cmap(self.norm(i)) for i in range(self.num_color)]
        # Head-outline axes + one axes per channel + colour-bar axes.
        self.fig, self.axes = plt.subplots(len(self.ws_data.ch_label) + 2, 1)
        self.colorbar = ColorbarBase(self.axes[-1], cmap=self.cmap, norm=self.norm,
                                     ticks=[i + 0.5 for i in range(self.num_color)])
        self.colorbar.set_ticklabels(['Delta (1-3 Hz)',
                                      'Theta (4-7 Hz)',
                                      'Low Alpha (8-10 Hz)',
                                      'High Alpha (11-12 Hz)',
                                      'Low Beta (13-15 Hz)',
                                      'Mid Beta (16-19 Hz)',
                                      'High Beta (20-35 Hz)',
                                      'Gamma (36-50 Hz)'])
        self.canvas = FigureCanvas(self.fig)
        self.pos = list(channel_dict_2D.values())      # X, Y of every known electrode
        self.ch_names_ = list(channel_dict_2D.keys())  # matching electrode names
        # Private mne.viz helpers: normalise the positions to the head outline
        # and draw that outline.  NOTE(review): private API, may change between
        # mne versions.
        self.pos, self.outlines = topomap._check_outlines(self.pos, 'head')
        topomap._draw_outlines(self.axes[0], self.outlines)
        # Index into channel_dict_2D for each channel the server streams.
        self.plt_idx = [self.ch_names_.index(name) for name in self.ws_data.ch_label]
        for ax, idx in zip(self.axes[1:-1], self.plt_idx):
            # Scale toward the centre and offset so a 0.06 x 0.06 chart sits
            # at the electrode's location in figure coordinates.
            x, y = self.pos[idx] * 5 / 6 + [0.47, 0.47]
            ax.set_position([x, y, 0.06, 0.06])
            ax.axis("off")
            ax.grid(True, axis='y')
        self.axes[0].set_position([0, 0, 1, 1])
        self.axes[-1].set_position([0.80, 0.85, 0.03, 0.13])
        self.resize(1000, 800)
        hlayout = QtWidgets.QHBoxLayout(self)
        hlayout.addWidget(self.canvas)

    def setup_signal_handler(self):
        """Connect mouse clicks on the figure and start the periodic redraw timer."""
        # Keep the connection id so the callback can be disconnected later.
        self.click_cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.timer = QtCore.QTimer(self)
        # QTimer.setInterval takes an int number of milliseconds; passing a
        # float raises TypeError on recent PyQt5 / Python combinations.
        self.timer.setInterval(int(self.timer_interval * 1000))
        self.timer.timeout.connect(self.draw)
        self.timer.start()

    def onclick(self, event):
        """Open an enlarged bar plot for the channel whose small chart was clicked.

        Args:
            event: matplotlib mouse event; ``event.inaxes`` is the axes under
                the cursor (or None).
        """
        # Only axes[1:-1] correspond to channels.  axes[-1] is the colour bar,
        # whose index would run past the end of ch_label.
        for idx, ax in enumerate(self.axes[1:-1]):
            if ax is event.inaxes:
                label = self.ws_data.ch_label[idx]
                print(label)
                if self.big_plots[idx] is None:
                    self.big_plots[idx] = Big_Bar_Plot(self, label)
                break

    def big_plot_closed(self, plot):
        """Forget an enlarged plot window after it has been closed.

        Args:
            plot: the ``Big_Bar_Plot`` instance that was closed.
        """
        for idx, p in enumerate(self.big_plots):
            if p is plot:
                self.big_plots[idx] = None

    def draw(self):
        """Timer callback: advance playback one step and redraw the bar charts.

        Each call moves ``cur_tick`` forward one timer period and drops queued
        frames that are already behind it.  It then renders the oldest remaining
        ``power_data`` frame and forwards it to any open enlarged plots.
        """
        if self.init_tick is None and self.ws_data.ticks:
            # Start playback at the first tick received.
            self.init_tick = self.ws_data.ticks[0]
            self.cur_tick = self.init_tick
        self._drop_outdated_frames()
        if self.ws_data.power_data:
            self._draw_bars(self.ws_data.power_data.pop(0))
        # Keep the parallel queues aligned with power_data; this widget does
        # not use the popped values themselves.
        if self.ws_data.ticks:
            self.ws_data.ticks.pop(0)
        if self.ws_data.z_all_data:
            self.ws_data.z_all_data.pop(0)
        if self.ws_data.z_each_data:
            self.ws_data.z_each_data.pop(0)
        gc.collect()

    def _drop_outdated_frames(self):
        """Advance ``cur_tick`` one step and discard queued frames behind it.

        Frames whose tick is at most half a step past the new ``cur_tick`` are
        counted.  All but the last of them are removed from every queue, so the
        next frame drawn is the one nearest the current playback position.
        """
        if self.cur_tick is None:  # no data received yet
            return
        # Step derived from the timer period (500 for the default 0.5 s, as
        # originally hard-coded).  NOTE(review): assumes ticks are in
        # milliseconds -- confirm in ws_pb.
        step = int(self.timer_interval * 1000)
        self.cur_tick += step
        limit = self.cur_tick + step / 2
        count = sum(1 for tick in self.ws_data.ticks if tick <= limit)
        for _ in range(count - 1):
            self.ws_data.ticks.pop(0)
            self.ws_data.power_data.pop(0)
            self.ws_data.z_all_data.pop(0)
            self.ws_data.z_each_data.pop(0)

    def _draw_bars(self, ch_data):
        """Redraw every channel's small bar chart from one frame of band powers.

        Args:
            ch_data: one entry per channel (in ``ch_label`` order), each holding
                ``num_color`` band-power values.  NOTE(review): shape inferred
                from how it is plotted; confirm against ws_pb.
        """
        bands = list(range(self.num_color))
        for idx, data in enumerate(ch_data):
            ax = self.axes[idx + 1]
            ax.cla()
            ax.bar(bands, data, color=self.colors)
            for side in ("bottom", "top", "right", "left"):
                ax.spines[side].set_visible(False)
            ax.yaxis.grid(True, c='r')
            ax.get_xaxis().set_visible(False)
            if self.big_plots[idx] is not None:
                self.big_plots[idx].draw(data, self.colors)
        self.fig.canvas.draw()
if __name__ == "__main__":
    import sys

    # Under PyQt5, QApplication lives in QtWidgets, not QtGui.
    app = QtWidgets.QApplication([])
    plot = Bar_Plot()  # keep a reference so the window is not garbage-collected
    # Propagate the event loop's exit status to the shell.
    sys.exit(app.exec_())