Skip to content

Commit 8a0de60

Browse files
Attempting to cluster correlation outputs using OPTICS
1 parent 71139d1 commit 8a0de60

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

python/play.py

+72-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import colorsys
1212
import json
1313
import numpy as np
14+
from sklearn.cluster import OPTICS as optics
15+
from sklearn.metrics.pairwise import pairwise_distances
16+
from scipy.spatial.distance import cosine as cosine_similarity
1417

1518
from board import Board
1619
from features import Features
@@ -40,6 +43,9 @@
4043
# Hardcoded max board size
4144
pos_len = 19
4245

46+
# Hardcoded correlation feature length size
47+
corr_feature_len = 32
48+
4349
# Model ----------------------------------------------------------------
4450

4551
logging.root.handlers = []
@@ -138,6 +144,8 @@ def get_outputs(gs, rules):
138144
seki = seki_probs[1] - seki_probs[2]
139145
seki2 = torch.sigmoid(seki_logits[3,:,:]).cpu().numpy()
140146
scorebelief = torch.nn.functional.softmax(scorebelief_logits,dim=0).cpu().numpy()
147+
ownership_corr = torch.tanh(out_ownership_corr).cpu().numpy()
148+
futurepos_corr = torch.tanh(out_futurepos_corr).cpu().numpy()
141149

142150
board = gs.board
143151

@@ -271,7 +279,9 @@ def get_outputs(gs, rules):
271279
"seki2": seki2,
272280
"seki_by_loc2": seki_by_loc2,
273281
"scorebelief": scorebelief,
274-
"genmove_result": genmove_result
282+
"genmove_result": genmove_result,
283+
"ownership_corr": ownership_corr,
284+
"futurepos_corr": futurepos_corr
275285
}
276286

277287
def get_input_feature(gs, rules, feature_idx):
@@ -531,6 +541,19 @@ def print_scorebelief(gs,outputs):
531541
return ret
532542

533543

544+
def abs_cosine_metric(x, y):
545+
return 1 - abs(cosine_similarity(x, y))
546+
547+
548+
def corr_distances_by_cosine_metric(corr_input):
549+
clustering = optics(metric=abs_cosine_metric).fit(corr_input)
550+
if max(clustering.labels_ == -1):
551+
raise RuntimeError("No clusters found")
552+
centers = np.vstack([np.average(corr_input[clustering.labels_ == i], axis=0) for i in range(max(clustering.labels_) + 1)])
553+
results = pairwise_distances(corr_input, centers, metric='cosine')
554+
return np.transpose(results)
555+
556+
534557
# Basic parsing --------------------------------------------------------
535558
colstr = 'ABCDEFGHJKLMNOPQRST'
536559
def parse_coord(s,board):
@@ -578,6 +601,7 @@ def str_coord(loc,board):
578601
'scorebelief',
579602
'passalive',
580603
]
604+
581605
known_analyze_commands = [
582606
'gfx/Policy/policy',
583607
'gfx/Policy1/policy1',
@@ -675,6 +699,24 @@ def get_board_matrix_str(matrix, scale, formatstr):
675699
gs.boards.append(gs.board.copy())
676700
ret = str_coord(loc,gs.board)
677701

702+
elif command[0] == "genmoves":
703+
count = 10
704+
if len(command) > 1:
705+
count = int(command[1])
706+
if count < 0 or count > pos_len**2:
707+
count = 10
708+
709+
for i in range(count):
710+
outputs = get_outputs(gs, rules)
711+
loc = outputs["genmove_result"]
712+
pla = gs.board.pla
713+
714+
gs.board.play(pla,loc)
715+
gs.moves.append((pla,loc))
716+
gs.boards.append(gs.board.copy())
717+
ret += str_coord(loc,gs.board)
718+
719+
678720
elif command[0] == "name":
679721
ret = 'KataGo Raw Neural Net Debug/Test Script'
680722
elif command[0] == "version":
@@ -804,6 +846,35 @@ def get_board_matrix_str(matrix, scale, formatstr):
804846
elif command[0] == "futurepos1_raw":
805847
outputs = get_outputs(gs, rules)
806848
ret = get_board_matrix_str(outputs["futurepos"][1], 100.0, "%+7.3f")
849+
850+
elif command[0] == "ownership_corr":
851+
outputs = get_outputs(gs, rules)
852+
corr = np.reshape(outputs["ownership_corr"], (corr_feature_len, features.pos_len ** 2))
853+
corr = np.transpose(corr)
854+
try:
855+
distances = corr_distances_by_cosine_metric(corr)
856+
ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances))
857+
except RuntimeError:
858+
ret = "No clusters found returning raw output instead\n\n"
859+
corr = np.transpose(corr)
860+
for i in range(corr_feature_len):
861+
ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f")
862+
ret += '\n'
863+
864+
elif command[0] == "futurepos_corr":
865+
outputs = get_outputs(gs, rules)
866+
corr = np.reshape(outputs["futurepos_corr"], (corr_feature_len, features.pos_len ** 2))
867+
corr = np.transpose(corr)
868+
try:
869+
distances = corr_distances_by_cosine_metric(corr)
870+
ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances))
871+
except RuntimeError:
872+
ret = "No clusters found returning raw output instead\n\n"
873+
corr = np.transpose(corr)
874+
for i in range(corr_feature_len):
875+
ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f")
876+
ret += '\n'
877+
807878
elif command[0] == "seki_raw":
808879
outputs = get_outputs(gs, rules)
809880
ret = get_board_matrix_str(outputs["seki"], 100.0, "%+7.3f")

0 commit comments

Comments
 (0)