@@ -13,21 +13,71 @@ See the License for the specific language governing permissions and
13
13
limitations under the License.
14
14
==============================================================================*/
15
15
import { Injectable } from '@angular/core' ;
16
- import { Observable , of } from 'rxjs' ;
17
- import { map } from 'rxjs/operators' ;
18
- import { TBHttpClient } from '../../webapp_data_source/tb_http_client' ;
16
+ import { Observable , of , throwError } from 'rxjs' ;
17
+ import { catchError , map , mergeMap } from 'rxjs/operators' ;
19
18
import {
19
+ HttpErrorResponse ,
20
+ TBHttpClient ,
21
+ } from '../../webapp_data_source/tb_http_client' ;
22
+ import * as backendTypes from './runs_backend_types' ;
23
+ import {
24
+ Domain ,
25
+ DomainType ,
20
26
HparamsAndMetadata ,
27
+ HparamSpec ,
28
+ HparamValue ,
29
+ MetricSpec ,
21
30
Run ,
22
31
RunsDataSource ,
32
+ RunToHparamsAndMetrics ,
23
33
} from './runs_data_source_types' ;
24
34
35
+ const HPARAMS_HTTP_PATH_PREFIX = 'data/plugin/hparams' ;
36
+
25
37
type BackendGetRunsResponse = string [ ] ;
26
38
27
39
function runToRunId ( run : string , experimentId : string ) {
28
40
return `${ experimentId } /${ run } ` ;
29
41
}
30
42
43
+ function transformBackendHparamSpec (
44
+ hparamInfo : backendTypes . HparamSpec
45
+ ) : HparamSpec {
46
+ let domain : Domain ;
47
+ if ( backendTypes . isDiscreteDomainHparamSpec ( hparamInfo ) ) {
48
+ domain = { type : DomainType . DISCRETE , values : hparamInfo . domainDiscrete } ;
49
+ } else if ( backendTypes . isIntervalDomainHparamSpec ( hparamInfo ) ) {
50
+ domain = { ...hparamInfo . domainInterval , type : DomainType . INTERVAL } ;
51
+ } else {
52
+ domain = {
53
+ type : DomainType . INTERVAL ,
54
+ minValue : - Infinity ,
55
+ maxValue : Infinity ,
56
+ } ;
57
+ }
58
+ return {
59
+ description : hparamInfo . description ,
60
+ displayName : hparamInfo . displayName ,
61
+ name : hparamInfo . name ,
62
+ type : hparamInfo . type ,
63
+ domain,
64
+ } ;
65
+ }
66
+
67
+ function transformBackendMetricSpec (
68
+ metricInfo : backendTypes . MetricSpec
69
+ ) : MetricSpec {
70
+ const { name, ...otherSpec } = metricInfo ;
71
+ return {
72
+ ...otherSpec ,
73
+ tag : name . tag ,
74
+ } ;
75
+ }
76
+
77
+ declare interface GetExperimentHparamRequestPayload {
78
+ experimentName : string ;
79
+ }
80
+
31
81
@Injectable ( )
32
82
export class TBRunsDataSource implements RunsDataSource {
33
83
constructor ( private readonly http : TBHttpClient ) { }
@@ -48,11 +98,112 @@ export class TBRunsDataSource implements RunsDataSource {
48
98
}
49
99
50
100
fetchHparamsMetadata ( experimentId : string ) : Observable < HparamsAndMetadata > {
51
- // Return a stub implementation.
52
- return of ( {
53
- hparamSpecs : [ ] ,
54
- metricSpecs : [ ] ,
55
- runToHparamsAndMetrics : { } ,
56
- } ) ;
101
+ const requestPayload : GetExperimentHparamRequestPayload = {
102
+ experimentName : experimentId ,
103
+ } ;
104
+ return this . http
105
+ . post < backendTypes . BackendHparamsExperimentResponse > (
106
+ `/experiment/${ experimentId } /${ HPARAMS_HTTP_PATH_PREFIX } /experiment` ,
107
+ requestPayload
108
+ )
109
+ . pipe (
110
+ map ( ( response ) => {
111
+ const colParams : backendTypes . BackendListSessionGroupRequest [ 'colParams' ] =
112
+ [ ] ;
113
+
114
+ for ( const hparamInfo of response . hparamInfos ) {
115
+ colParams . push ( { hparam : hparamInfo . name } ) ;
116
+ }
117
+ for ( const metricInfo of response . metricInfos ) {
118
+ colParams . push ( { metric : metricInfo . name } ) ;
119
+ }
120
+
121
+ const listSessionRequestParams : backendTypes . BackendListSessionGroupRequest =
122
+ {
123
+ experimentName : experimentId ,
124
+ allowedStatuses : [
125
+ backendTypes . RunStatus . STATUS_FAILURE ,
126
+ backendTypes . RunStatus . STATUS_RUNNING ,
127
+ backendTypes . RunStatus . STATUS_SUCCESS ,
128
+ backendTypes . RunStatus . STATUS_UNKNOWN ,
129
+ ] ,
130
+ colParams,
131
+ startIndex : 0 ,
132
+ // arbitrary large number so it does not get clipped.
133
+ sliceSize : 1e6 ,
134
+ } ;
135
+
136
+ return {
137
+ experimentHparamsInfo : response ,
138
+ listSessionRequestParams,
139
+ } ;
140
+ } ) ,
141
+ mergeMap ( ( { experimentHparamsInfo, listSessionRequestParams} ) => {
142
+ return this . http
143
+ . post < backendTypes . BackendListSessionGroupResponse > (
144
+ `/experiment/${ experimentId } /${ HPARAMS_HTTP_PATH_PREFIX } /session_groups` ,
145
+ JSON . stringify ( listSessionRequestParams )
146
+ )
147
+ . pipe (
148
+ map ( ( sessionGroupsList ) => {
149
+ return { experimentHparamsInfo, sessionGroupsList} ;
150
+ } )
151
+ ) ;
152
+ } ) ,
153
+ map ( ( { experimentHparamsInfo, sessionGroupsList} ) => {
154
+ const runToHparamsAndMetrics : RunToHparamsAndMetrics = { } ;
155
+
156
+ // Reorganize the sessionGroup/session into run to <hparams,
157
+ // metrics>.
158
+ for ( const sessionGroup of sessionGroupsList . sessionGroups ) {
159
+ const hparams : HparamValue [ ] = Object . entries (
160
+ sessionGroup . hparams
161
+ ) . map ( ( keyValue ) => {
162
+ const [ hparam , value ] = keyValue ;
163
+ return { name : hparam , value} ;
164
+ } ) ;
165
+
166
+ for ( const session of sessionGroup . sessions ) {
167
+ for ( const metricValue of session . metricValues ) {
168
+ const runName = metricValue . name . group
169
+ ? `${ session . name } /${ metricValue . name . group } `
170
+ : session . name ;
171
+ const runId = `${ experimentId } /${ runName } ` ;
172
+ const hparamsAndMetrics = runToHparamsAndMetrics [ runId ] || {
173
+ metrics : [ ] ,
174
+ hparams,
175
+ } ;
176
+ hparamsAndMetrics . metrics . push ( {
177
+ tag : metricValue . name . tag ,
178
+ trainingStep : metricValue . trainingStep ,
179
+ value : metricValue . value ,
180
+ } ) ;
181
+ runToHparamsAndMetrics [ runId ] = hparamsAndMetrics ;
182
+ }
183
+ }
184
+ }
185
+ return {
186
+ hparamSpecs : experimentHparamsInfo . hparamInfos . map (
187
+ transformBackendHparamSpec
188
+ ) ,
189
+ metricSpecs : experimentHparamsInfo . metricInfos . map (
190
+ transformBackendMetricSpec
191
+ ) ,
192
+ runToHparamsAndMetrics,
193
+ } ;
194
+ } ) ,
195
+ catchError ( ( error ) => {
196
+ // HParams plugin return 400 when there are no hparams for an
197
+ // experiment.
198
+ if ( error instanceof HttpErrorResponse && error . status === 400 ) {
199
+ return of ( {
200
+ hparamSpecs : [ ] ,
201
+ metricSpecs : [ ] ,
202
+ runToHparamsAndMetrics : { } ,
203
+ } ) ;
204
+ }
205
+ return throwError ( error ) ;
206
+ } )
207
+ ) ;
57
208
}
58
209
}
0 commit comments