@@ -160,29 +160,49 @@ def test_wrap_conversation_pipeline():
         framework="pt",
     )
     conv_pipe = wrap_conversation_pipeline(init_pipeline)
-    data = {
-        "past_user_inputs": ["Which movie is the best ?"],
-        "generated_responses": ["It's Die Hard for sure."],
-        "text": "Can you explain why?",
-    }
+    data = [
+        {
+            "role": "user",
+            "content": "Which movie is the best ?"
+        },
+        {
+            "role": "assistant",
+            "content": "It's Die Hard for sure."
+        },
+        {
+            "role": "user",
+            "content": "Can you explain why?"
+        }
+    ]
     res = conv_pipe(data)
-    assert "conversation" in res
-    assert "generated_text" in res
+    assert "content" in res.messages[-1]
 
 
 @require_torch
 def test_wrapped_pipeline():
     with tempfile.TemporaryDirectory() as tmpdirname:
-        storage_dir = _load_repository_from_hf("hf-internal-testing/tiny-random-blenderbot", tmpdirname, framework="pytorch")
+        storage_dir = _load_repository_from_hf(
+            repository_id="microsoft/DialoGPT-small",
+            target_dir=tmpdirname,
+            framework="pytorch"
+        )
         conv_pipe = get_pipeline("conversational", storage_dir.as_posix())
-        data = {
-            "past_user_inputs": ["Which movie is the best ?"],
-            "generated_responses": ["It's Die Hard for sure."],
-            "text": "Can you explain why?",
-        }
+        data = [
+            {
+                "role": "user",
+                "content": "Which movie is the best ?"
+            },
+            {
+                "role": "assistant",
+                "content": "It's Die Hard for sure."
+            },
+            {
+                "role": "user",
+                "content": "Can you explain why?"
+            }
+        ]
         res = conv_pipe(data)
-        assert "conversation" in res
-        assert "generated_text" in res
+        assert "content" in res.messages[-1]
 
 
 def test_local_custom_pipeline():