|
| 1 | +import setGPU |
| 2 | +import os |
| 3 | +import csv |
| 4 | +import pickle |
| 5 | +import re |
| 6 | +from sentence_transformers import SentenceTransformer |
| 7 | +from os import path as osp |
| 8 | +from tqdm import tqdm |
| 9 | +import argparse |
| 10 | +from architecture import LLMChat |
| 11 | +from utils import load_file, retrieve_topk, generate_code_snippet, save_scenic_code |
| 12 | + |
| 13 | +# no need for faiss currently |
| 14 | +# import faiss |
| 15 | + |
| 16 | +parser = argparse.ArgumentParser(description="Set up configurations for your script.") |
| 17 | +parser.add_argument('--port_ip', type=int, default=2000, help='Port IP address (default: 2000)') |
| 18 | +parser.add_argument('--topk', type=int, default=3, help='Top K value (default: 3) for retrieval') |
| 19 | +parser.add_argument('--model', type=str, default='gpt-4o', help="Model name (default: 'gpt-4o'), also support transformers model") |
| 20 | +parser.add_argument('--use_llm', action='store_true', help='if use llm for generating new snippets') |
| 21 | +args = parser.parse_args() |
| 22 | + |
| 23 | +port_ip = args.port_ip |
| 24 | +topk = args.topk |
| 25 | +use_llm = args.use_llm |
| 26 | + |
| 27 | +llm_model = LLMChat(args.model) |
| 28 | +local_path = osp.abspath(osp.dirname(osp.dirname(osp.realpath(__file__)))) |
| 29 | +extraction_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'extraction.txt')) |
| 30 | +behavior_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'behavior.txt')) |
| 31 | +geometry_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'geometry.txt')) |
| 32 | +spawn_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'spawn.txt')) |
| 33 | +scenario_descriptions = load_file(osp.join(local_path, 'retrieve', 'scenario_descriptions.txt')).split('\n') |
| 34 | +encoder = SentenceTransformer('sentence-transformers/sentence-t5-large', device='cuda') |
| 35 | + |
| 36 | +# Load the database |
| 37 | +with open(osp.join(local_path, 'retrieve/database_v1.pkl'), 'rb') as file: |
| 38 | + database = pickle.load(file) |
| 39 | + |
| 40 | +behavior_descriptions = database['behavior']['description'] |
| 41 | +geometry_descriptions = database['geometry']['description'] |
| 42 | +spawn_descriptions = database['spawn']['description'] |
| 43 | +behavior_snippets = database['behavior']['snippet'] |
| 44 | +geometry_snippets = database['geometry']['snippet'] |
| 45 | +spawn_snippets = database['spawn']['snippet'] |
| 46 | + |
| 47 | +behavior_embeddings = encoder.encode(behavior_descriptions, device='cuda', convert_to_tensor=True) |
| 48 | +geometry_embeddings = encoder.encode(geometry_descriptions, device='cuda', convert_to_tensor=True) |
| 49 | +spawn_embeddings = encoder.encode(spawn_descriptions, device='cuda', convert_to_tensor=True) |
| 50 | + |
| 51 | +## This is the head for scenic file, you can modify the carla map or ego model here |
| 52 | +head = '''param map = localPath(f'../maps/{Town}.xodr') |
| 53 | +param carla_map = Town |
| 54 | +model scenic.simulators.carla.model |
| 55 | +EGO_MODEL = "vehicle.lincoln.mkz_2017" |
| 56 | +''' |
| 57 | + |
| 58 | +log_file_path = osp.join(local_path, 'safebench', 'scenario', 'scenario_data', 'scenic_data', 'dynamic_scenario', 'dynamic_log.csv') |
| 59 | + |
| 60 | +with open(log_file_path, mode='w', newline='') as file: |
| 61 | + log_writer = csv.writer(file) |
| 62 | + log_writer.writerow(['Scenario', 'AdvObject', 'Behavior Description', 'Behavior Snippet', 'Geometry Description', 'Geometry Snippet', 'Spawn Description', 'Spawn Snippet', 'Success']) |
| 63 | + |
| 64 | + for q, current_scenario in tqdm(enumerate(scenario_descriptions)): |
| 65 | + messages = [ |
| 66 | + {"role": "system", "content": "You are a helpful assistant."}, |
| 67 | + {"role": "user", "content": extraction_prompt.format(scenario=current_scenario)}, |
| 68 | + ] |
| 69 | + |
| 70 | + response = llm_model.generate(messages) |
| 71 | + |
| 72 | + try: |
| 73 | + match = re.search(r"Adversarial Object:(.*?)Behavior:(.*?)Geometry:(.*?)Spawn Position:(.*)", response, re.DOTALL) |
| 74 | + if not match: |
| 75 | + raise ValueError("Failed to extract components from the response") |
| 76 | + |
| 77 | + current_adv_object, current_behavior, current_geometry, current_spawn = [s.strip() for s in match.groups()] |
| 78 | + |
| 79 | + # Retrieve the top K relevant snippets |
| 80 | + top_behavior_descriptions, top_behavior_snippets = retrieve_topk(encoder, topk, behavior_descriptions, behavior_snippets, behavior_embeddings, current_behavior) |
| 81 | + top_geometry_descriptions, top_geometry_snippets = retrieve_topk(encoder, topk, geometry_descriptions, geometry_snippets, geometry_embeddings, current_geometry) |
| 82 | + top_spawn_descriptions, top_spawn_snippets = retrieve_topk(encoder, topk, spawn_descriptions, spawn_snippets, spawn_embeddings, current_spawn) |
| 83 | + |
| 84 | + # Generate code snippets using the LLM |
| 85 | + generated_behavior_code = generate_code_snippet( |
| 86 | + llm_model, behavior_prompt, top_behavior_descriptions, top_behavior_snippets, current_behavior, topk, use_llm |
| 87 | + ) |
| 88 | + |
| 89 | + generated_geometry_code = generate_code_snippet( |
| 90 | + llm_model, geometry_prompt, top_geometry_descriptions, top_geometry_snippets, current_geometry, topk, use_llm |
| 91 | + ) |
| 92 | + |
| 93 | + generated_spawn_code = generate_code_snippet( |
| 94 | + llm_model, spawn_prompt, top_spawn_descriptions, top_spawn_snippets, current_spawn, topk, use_llm |
| 95 | + ) |
| 96 | + |
| 97 | + # Log the results |
| 98 | + log_writer.writerow([current_scenario, current_adv_object, current_behavior, generated_behavior_code, current_geometry, generated_geometry_code, current_spawn, generated_spawn_code, 1]) |
| 99 | + |
| 100 | + Town, generated_geometry_code = generated_geometry_code.split('\n', 1) |
| 101 | + scenic_code = '\n'.join([f"'''{current_scenario}'''", Town, head, generated_behavior_code, generated_geometry_code, generated_spawn_code.format(AdvObject=current_adv_object)]) |
| 102 | + save_scenic_code(local_path, port_ip, scenic_code, q) |
| 103 | + |
| 104 | + except Exception as e: |
| 105 | + log_writer.writerow([current_scenario, '', '', '', '', '', '', '', 0]) |
| 106 | + print(f"Failure for scenario: {current_scenario} - Error: {e}") |
0 commit comments