|
11 | 11 | import colorsys
|
12 | 12 | import json
|
13 | 13 | import numpy as np
|
| 14 | +from sklearn.cluster import OPTICS as optics |
| 15 | +from sklearn.metrics.pairwise import pairwise_distances |
| 16 | +from scipy.spatial.distance import cosine as cosine_similarity |
14 | 17 |
|
15 | 18 | from board import Board
|
16 | 19 | from features import Features
|
|
40 | 43 | # Hardcoded max board size
|
41 | 44 | pos_len = 19
|
42 | 45 |
|
| 46 | +# Hardcoded correlation feature length size |
| 47 | +corr_feature_len = 32 |
| 48 | + |
43 | 49 | # Model ----------------------------------------------------------------
|
44 | 50 |
|
45 | 51 | logging.root.handlers = []
|
@@ -138,6 +144,8 @@ def get_outputs(gs, rules):
|
138 | 144 | seki = seki_probs[1] - seki_probs[2]
|
139 | 145 | seki2 = torch.sigmoid(seki_logits[3,:,:]).cpu().numpy()
|
140 | 146 | scorebelief = torch.nn.functional.softmax(scorebelief_logits,dim=0).cpu().numpy()
|
| 147 | + ownership_corr = torch.tanh(out_ownership_corr).cpu().numpy() |
| 148 | + futurepos_corr = torch.tanh(out_futurepos_corr).cpu().numpy() |
141 | 149 |
|
142 | 150 | board = gs.board
|
143 | 151 |
|
@@ -271,7 +279,9 @@ def get_outputs(gs, rules):
|
271 | 279 | "seki2": seki2,
|
272 | 280 | "seki_by_loc2": seki_by_loc2,
|
273 | 281 | "scorebelief": scorebelief,
|
274 |
| - "genmove_result": genmove_result |
| 282 | + "genmove_result": genmove_result, |
| 283 | + "ownership_corr": ownership_corr, |
| 284 | + "futurepos_corr": futurepos_corr |
275 | 285 | }
|
276 | 286 |
|
277 | 287 | def get_input_feature(gs, rules, feature_idx):
|
@@ -531,6 +541,19 @@ def print_scorebelief(gs,outputs):
|
531 | 541 | return ret
|
532 | 542 |
|
533 | 543 |
|
| 544 | +def abs_cosine_metric(x, y): |
| 545 | + return 1 - abs(cosine_similarity(x, y)) |
| 546 | + |
| 547 | + |
| 548 | +def corr_distances_by_cosine_metric(corr_input): |
| 549 | + clustering = optics(metric=abs_cosine_metric).fit(corr_input) |
| 550 | + if max(clustering.labels_ == -1): |
| 551 | + raise RuntimeError("No clusters found") |
| 552 | + centers = np.vstack([np.average(corr_input[clustering.labels_ == i], axis=0) for i in range(max(clustering.labels_) + 1)]) |
| 553 | + results = pairwise_distances(corr_input, centers, metric='cosine') |
| 554 | + return np.transpose(results) |
| 555 | + |
| 556 | + |
534 | 557 | # Basic parsing --------------------------------------------------------
|
535 | 558 | colstr = 'ABCDEFGHJKLMNOPQRST'
|
536 | 559 | def parse_coord(s,board):
|
@@ -578,6 +601,7 @@ def str_coord(loc,board):
|
578 | 601 | 'scorebelief',
|
579 | 602 | 'passalive',
|
580 | 603 | ]
|
| 604 | + |
581 | 605 | known_analyze_commands = [
|
582 | 606 | 'gfx/Policy/policy',
|
583 | 607 | 'gfx/Policy1/policy1',
|
@@ -675,6 +699,24 @@ def get_board_matrix_str(matrix, scale, formatstr):
|
675 | 699 | gs.boards.append(gs.board.copy())
|
676 | 700 | ret = str_coord(loc,gs.board)
|
677 | 701 |
|
| 702 | + elif command[0] == "genmoves": |
| 703 | + count = 10 |
| 704 | + if len(command) > 1: |
| 705 | + count = int(command[1]) |
| 706 | + if count < 0 or count > pos_len**2: |
| 707 | + count = 10 |
| 708 | + |
| 709 | + for i in range(count): |
| 710 | + outputs = get_outputs(gs, rules) |
| 711 | + loc = outputs["genmove_result"] |
| 712 | + pla = gs.board.pla |
| 713 | + |
| 714 | + gs.board.play(pla,loc) |
| 715 | + gs.moves.append((pla,loc)) |
| 716 | + gs.boards.append(gs.board.copy()) |
| 717 | + ret += str_coord(loc,gs.board) |
| 718 | + |
| 719 | + |
678 | 720 | elif command[0] == "name":
|
679 | 721 | ret = 'KataGo Raw Neural Net Debug/Test Script'
|
680 | 722 | elif command[0] == "version":
|
@@ -804,6 +846,35 @@ def get_board_matrix_str(matrix, scale, formatstr):
|
804 | 846 | elif command[0] == "futurepos1_raw":
|
805 | 847 | outputs = get_outputs(gs, rules)
|
806 | 848 | ret = get_board_matrix_str(outputs["futurepos"][1], 100.0, "%+7.3f")
|
| 849 | + |
| 850 | + elif command[0] == "ownership_corr": |
| 851 | + outputs = get_outputs(gs, rules) |
| 852 | + corr = np.reshape(outputs["ownership_corr"], (corr_feature_len, features.pos_len ** 2)) |
| 853 | + corr = np.transpose(corr) |
| 854 | + try: |
| 855 | + distances = corr_distances_by_cosine_metric(corr) |
| 856 | + ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances)) |
| 857 | + except RuntimeError: |
| 858 | + ret = "No clusters found returning raw output instead\n\n" |
| 859 | + corr = np.transpose(corr) |
| 860 | + for i in range(corr_feature_len): |
| 861 | + ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f") |
| 862 | + ret += '\n' |
| 863 | + |
| 864 | + elif command[0] == "futurepos_corr": |
| 865 | + outputs = get_outputs(gs, rules) |
| 866 | + corr = np.reshape(outputs["futurepos_corr"], (corr_feature_len, features.pos_len ** 2)) |
| 867 | + corr = np.transpose(corr) |
| 868 | + try: |
| 869 | + distances = corr_distances_by_cosine_metric(corr) |
| 870 | + ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances)) |
| 871 | + except RuntimeError: |
| 872 | + ret = "No clusters found returning raw output instead\n\n" |
| 873 | + corr = np.transpose(corr) |
| 874 | + for i in range(corr_feature_len): |
| 875 | + ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f") |
| 876 | + ret += '\n' |
| 877 | + |
807 | 878 | elif command[0] == "seki_raw":
|
808 | 879 | outputs = get_outputs(gs, rules)
|
809 | 880 | ret = get_board_matrix_str(outputs["seki"], 100.0, "%+7.3f")
|
|
0 commit comments