Skip to content

Commit bca0692

Browse files
authored
HParams: Port runs_data_source fetchHparamsMetadata to oss (#6318)
## Motivation for features / changes As part of the effort to surface hparams data in the time series dashboard we need to start actually fetching it. The logic existed internally but for some reason was not previously available in OSS. ## Technical description of changes I just copied the internal code to the oss implementation and made a few small adjustments to imports ## Screenshots of UI changes None ## Detailed steps to verify changes work correctly (as executed by you) Tests should pass Patch #6317, enable the feature flag, and ensure the hparams appear in the runs table ## Alternate designs / implementations considered Make a more detailed stub?
1 parent 78ceaf7 commit bca0692

File tree

4 files changed

+635
-11
lines changed

4 files changed

+635
-11
lines changed

tensorboard/webapp/runs/data_source/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ tf_ts_library(
4848
"runs_data_source_test.ts",
4949
],
5050
deps = [
51+
":backend_types",
5152
":data_source",
53+
":testing",
5254
"//tensorboard/webapp/angular:expect_angular_core_testing",
5355
"//tensorboard/webapp/webapp_data_source:http_client_testing",
5456
"@npm//@types/jasmine",

tensorboard/webapp/runs/data_source/runs_data_source.ts

Lines changed: 160 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,71 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515
import {Injectable} from '@angular/core';
16-
import {Observable, of} from 'rxjs';
17-
import {map} from 'rxjs/operators';
18-
import {TBHttpClient} from '../../webapp_data_source/tb_http_client';
16+
import {Observable, of, throwError} from 'rxjs';
17+
import {catchError, map, mergeMap} from 'rxjs/operators';
1918
import {
19+
HttpErrorResponse,
20+
TBHttpClient,
21+
} from '../../webapp_data_source/tb_http_client';
22+
import * as backendTypes from './runs_backend_types';
23+
import {
24+
Domain,
25+
DomainType,
2026
HparamsAndMetadata,
27+
HparamSpec,
28+
HparamValue,
29+
MetricSpec,
2130
Run,
2231
RunsDataSource,
32+
RunToHparamsAndMetrics,
2333
} from './runs_data_source_types';
2434

35+
const HPARAMS_HTTP_PATH_PREFIX = 'data/plugin/hparams';
36+
2537
type BackendGetRunsResponse = string[];
2638

2739
function runToRunId(run: string, experimentId: string) {
2840
return `${experimentId}/${run}`;
2941
}
3042

43+
function transformBackendHparamSpec(
44+
hparamInfo: backendTypes.HparamSpec
45+
): HparamSpec {
46+
let domain: Domain;
47+
if (backendTypes.isDiscreteDomainHparamSpec(hparamInfo)) {
48+
domain = {type: DomainType.DISCRETE, values: hparamInfo.domainDiscrete};
49+
} else if (backendTypes.isIntervalDomainHparamSpec(hparamInfo)) {
50+
domain = {...hparamInfo.domainInterval, type: DomainType.INTERVAL};
51+
} else {
52+
domain = {
53+
type: DomainType.INTERVAL,
54+
minValue: -Infinity,
55+
maxValue: Infinity,
56+
};
57+
}
58+
return {
59+
description: hparamInfo.description,
60+
displayName: hparamInfo.displayName,
61+
name: hparamInfo.name,
62+
type: hparamInfo.type,
63+
domain,
64+
};
65+
}
66+
67+
function transformBackendMetricSpec(
68+
metricInfo: backendTypes.MetricSpec
69+
): MetricSpec {
70+
const {name, ...otherSpec} = metricInfo;
71+
return {
72+
...otherSpec,
73+
tag: name.tag,
74+
};
75+
}
76+
77+
declare interface GetExperimentHparamRequestPayload {
78+
experimentName: string;
79+
}
80+
3181
@Injectable()
3282
export class TBRunsDataSource implements RunsDataSource {
3383
constructor(private readonly http: TBHttpClient) {}
@@ -48,11 +98,112 @@ export class TBRunsDataSource implements RunsDataSource {
4898
}
4999

50100
fetchHparamsMetadata(experimentId: string): Observable<HparamsAndMetadata> {
51-
// Return a stub implementation.
52-
return of({
53-
hparamSpecs: [],
54-
metricSpecs: [],
55-
runToHparamsAndMetrics: {},
56-
});
101+
const requestPayload: GetExperimentHparamRequestPayload = {
102+
experimentName: experimentId,
103+
};
104+
return this.http
105+
.post<backendTypes.BackendHparamsExperimentResponse>(
106+
`/experiment/${experimentId}/${HPARAMS_HTTP_PATH_PREFIX}/experiment`,
107+
requestPayload
108+
)
109+
.pipe(
110+
map((response) => {
111+
const colParams: backendTypes.BackendListSessionGroupRequest['colParams'] =
112+
[];
113+
114+
for (const hparamInfo of response.hparamInfos) {
115+
colParams.push({hparam: hparamInfo.name});
116+
}
117+
for (const metricInfo of response.metricInfos) {
118+
colParams.push({metric: metricInfo.name});
119+
}
120+
121+
const listSessionRequestParams: backendTypes.BackendListSessionGroupRequest =
122+
{
123+
experimentName: experimentId,
124+
allowedStatuses: [
125+
backendTypes.RunStatus.STATUS_FAILURE,
126+
backendTypes.RunStatus.STATUS_RUNNING,
127+
backendTypes.RunStatus.STATUS_SUCCESS,
128+
backendTypes.RunStatus.STATUS_UNKNOWN,
129+
],
130+
colParams,
131+
startIndex: 0,
132+
// arbitrary large number so it does not get clipped.
133+
sliceSize: 1e6,
134+
};
135+
136+
return {
137+
experimentHparamsInfo: response,
138+
listSessionRequestParams,
139+
};
140+
}),
141+
mergeMap(({experimentHparamsInfo, listSessionRequestParams}) => {
142+
return this.http
143+
.post<backendTypes.BackendListSessionGroupResponse>(
144+
`/experiment/${experimentId}/${HPARAMS_HTTP_PATH_PREFIX}/session_groups`,
145+
JSON.stringify(listSessionRequestParams)
146+
)
147+
.pipe(
148+
map((sessionGroupsList) => {
149+
return {experimentHparamsInfo, sessionGroupsList};
150+
})
151+
);
152+
}),
153+
map(({experimentHparamsInfo, sessionGroupsList}) => {
154+
const runToHparamsAndMetrics: RunToHparamsAndMetrics = {};
155+
156+
// Reorganize the sessionGroup/session into run to <hparams,
157+
// metrics>.
158+
for (const sessionGroup of sessionGroupsList.sessionGroups) {
159+
const hparams: HparamValue[] = Object.entries(
160+
sessionGroup.hparams
161+
).map((keyValue) => {
162+
const [hparam, value] = keyValue;
163+
return {name: hparam, value};
164+
});
165+
166+
for (const session of sessionGroup.sessions) {
167+
for (const metricValue of session.metricValues) {
168+
const runName = metricValue.name.group
169+
? `${session.name}/${metricValue.name.group}`
170+
: session.name;
171+
const runId = `${experimentId}/${runName}`;
172+
const hparamsAndMetrics = runToHparamsAndMetrics[runId] || {
173+
metrics: [],
174+
hparams,
175+
};
176+
hparamsAndMetrics.metrics.push({
177+
tag: metricValue.name.tag,
178+
trainingStep: metricValue.trainingStep,
179+
value: metricValue.value,
180+
});
181+
runToHparamsAndMetrics[runId] = hparamsAndMetrics;
182+
}
183+
}
184+
}
185+
return {
186+
hparamSpecs: experimentHparamsInfo.hparamInfos.map(
187+
transformBackendHparamSpec
188+
),
189+
metricSpecs: experimentHparamsInfo.metricInfos.map(
190+
transformBackendMetricSpec
191+
),
192+
runToHparamsAndMetrics,
193+
};
194+
}),
195+
catchError((error) => {
196+
// HParams plugin return 400 when there are no hparams for an
197+
// experiment.
198+
if (error instanceof HttpErrorResponse && error.status === 400) {
199+
return of({
200+
hparamSpecs: [],
201+
metricSpecs: [],
202+
runToHparamsAndMetrics: {},
203+
});
204+
}
205+
return throwError(error);
206+
})
207+
);
57208
}
58209
}

0 commit comments

Comments
 (0)