Skip to content

Commit

Permalink
Support for custom prompty for user simulation
Browse files — browse the repository at this point in the history
  • Loading branch information
Nagkumar Arkalgud committed Jul 15, 2024
1 parent a546027 commit 6aef44b
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/promptflow-evals/promptflow/evals/synthetic/task_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(self, azure_ai_project: Dict[str, Any], credential=None):
self.azure_ai_project["api_version"] = "2024-02-15-preview"
self.credential = credential

async def build_query(self, *, user_persona, conversation_history, user_simulator_prompty):
async def build_query(
self, *, user_persona, conversation_history, user_simulator_prompty, user_simulator_prompty_kwargs
):
# make a call to llm with user_persona and query
prompty_model_config = {"configuration": self.azure_ai_project}
prompty_model_config.update(
Expand All @@ -86,9 +88,13 @@ async def build_query(self, *, user_persona, conversation_history, user_simulato
if not user_simulator_prompty:
current_dir = os.path.dirname(__file__)
prompty_path = os.path.join(current_dir, "_prompty", "task_simulate_with_persona.prompty")
_flow = load_flow(source=prompty_path, model=prompty_model_config)
else:
raise NotImplementedError("Custom prompty not supported yet")
_flow = load_flow(source=prompty_path, model=prompty_model_config)
_flow = load_flow(
source=user_simulator_prompty,
model=prompty_model_config,
**user_simulator_prompty_kwargs,
)
response = _flow(user_persona=user_persona, conversation_history=conversation_history)
except Exception as e:
print("Something went wrong running the prompty")
Expand All @@ -114,6 +120,8 @@ async def __call__(
query_response_generating_prompty: str = None,
user_simulator_prompty: str = None,
api_call_delay_sec: float = 1,
query_response_generating_prompty_kwargs: Dict[str, Any] = {},
user_simulator_prompty_kwargs: Dict[str, Any] = {},
**kwargs,
):
if num_queries != len(user_persona):
Expand All @@ -129,7 +137,8 @@ async def __call__(
prompty_path = os.path.join(current_dir, "_prompty", "task_query_response.prompty")
_flow = load_flow(source=prompty_path, model=prompty_model_config)
else:
query_response_generating_prompty_kwargs = {**kwargs}
if not query_response_generating_prompty_kwargs:
query_response_generating_prompty_kwargs = {**kwargs}
_flow = load_flow(
source=query_response_generating_prompty,
model=prompty_model_config,
Expand Down Expand Up @@ -170,6 +179,7 @@ async def __call__(
max_conversation_turns=max_conversation_turns,
user_persona=user_persona_item,
user_simulator_prompty=user_simulator_prompty,
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
target=target,
api_call_delay_sec=api_call_delay_sec,
progress_bar=progress_bar,
Expand All @@ -192,6 +202,7 @@ async def complete_conversation(
max_conversation_turns,
user_persona,
user_simulator_prompty,
user_simulator_prompty_kwargs,
target,
api_call_delay_sec,
progress_bar,
Expand Down Expand Up @@ -224,6 +235,7 @@ async def complete_conversation(
user_persona=user_persona,
conversation_history=conversation_history.to_conv_history(),
user_simulator_prompty=user_simulator_prompty,
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
)
await asyncio.sleep(api_call_delay_sec)
# Append user simulator's response
Expand Down

0 comments on commit 6aef44b

Please sign in to comment.