Skip to content

Commit 246a4c3

Browse files
authored
更新retrieve场景生成文件夹 (#110)
* Update hutbthesis_main.tex * Add files via upload * 添加scripts目录 包含一些运行脚本 * retrieve文件夹 retrieve 文件夹在项目中用于存储与场景描述检索相关的数据和代码,帮助从数据库中获取并生成动态场景。 * 添加 undergraduate/demo 文件夹 * Add files via upload * Add files via upload * 添加docs文件夹 * Add files via upload * 添加safebench.egg-info文件夹 * 添加 safebench 文件夹 * 添加 undergraduate/tools 文件夹 * 添加 undergraduate/Scenic 文件夹 * 为 Python 包 safebench 创建一个可安装的包 为 Python 包 safebench 创建一个可安装的包,并指定了依赖项(gym 和 pygame)以及包的基本信息。 * 这是原作者的一个许可证 MIT 许可证是一个非常宽松的开源许可协议,允许他人自由地使用和修改代码,只要他们保留版权声明并且了解代码的“原样”提供,不承担任何责任。 * 包含用于构建和运行 Docker 容器的相关配置文件 docker 文件夹包含用于构建和运行 Docker 容器的相关配置文件(如 Dockerfile 和 run_docker.sh),目的是创建一个可重复的环境来运行项目或应用,通常用于依赖管理和环境配置。 * 用于存储和管理与自然语言描述生成场景相关的数据和代码。 用于存储和管理与自然语言描述生成场景相关的数据和代码。 * Update README.md * 输出文件outputs * 更新存放量化评估的文件夹 * Update README.md * Update README.md * Update README.md * Delete scene/undergraduate/demo/.idea directory * Delete scene/undergraduate/retrieve/__pycache__ directory * Delete scene/undergraduate/safebench.egg-info directory * Delete scene/undergraduate/outputs/screenshots directory * Delete scene/undergraduate/setup.py * Delete scene/undergraduate/safebench/__pycache__ directory * Delete scene/undergraduate/scripts/__pycache__ directory * Delete scene/undergraduate/safebench/agent/__pycache__ directory * Delete scene/undergraduate/Scenic directory * Delete scene/undergraduate/demo directory * Delete scene/undergraduate/docker directory * Delete scene/undergraduate/evaluation directory * Delete scene/undergraduate/docs directory * Delete scene/undergraduate/retrieve directory * Delete scene/undergraduate/safebench directory * Delete scene/undergraduate/scripts directory * Delete scene/undergraduate/tools directory * Delete scene/undergraduate/LICENSE * 增加代码文件夹 一些脚本 一些场景生成文件 一些量化评估文件 * Delete scene/retrieve/__pycache__ directory * Delete scene/retrieve/database_v1.pkl * Delete scene/scripts/__pycache__ directory * Update latexmkrc * Update hutbthesis_main.tex * 增加自动化编译 * Update CICID.yml 调整格式 * Update CICID.yml * Update hutbthesis_main.tex * 忽视一些文件 * 添加safebench * Delete scene/safebench directory * Create safebench * Delete scene/safebench * Create carla_runner.py * Add files via upload * 更新代码文件 * Create README.md * Update README.md * 增加util文件夹 包含项目中的辅助工具和常用函数 * Delete scene/safebench/util/__pycache__ directory * Delete scene/safebench/util directory * Update architecture.py * Update retrieve.py * Update architecture.py * Update architecture.py * Update architecture.py * Update architecture.py
1 parent ba45d97 commit 246a4c3

1 file changed

Lines changed: 28 additions & 4 deletions

File tree

scene/retrieve/retrieve.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,59 @@
1-
import setGPU
21
import os
2+
import setGPU
33
import csv
44
import pickle
55
import re
6-
from sentence_transformers import SentenceTransformer
6+
from sentence_transformers import SentenceTransformer, models
77
from os import path as osp
88
from tqdm import tqdm
99
import argparse
1010
from architecture import LLMChat
1111
from utils import load_file, retrieve_topk, generate_code_snippet, save_scenic_code
1212

13+
1314
# no need for faiss currently
1415
# import faiss
1516

17+
# Argument parsing
1618
parser = argparse.ArgumentParser(description="Set up configurations for your script.")
1719
parser.add_argument('--port_ip', type=int, default=2000, help='Port IP address (default: 2000)')
1820
parser.add_argument('--topk', type=int, default=3, help='Top K value (default: 3) for retrieval')
1921
parser.add_argument('--model', type=str, default='gpt-4o', help="Model name (default: 'gpt-4o'), also support transformers model")
2022
parser.add_argument('--use_llm', action='store_true', help='if use llm for generating new snippets')
2123
args = parser.parse_args()
2224

25+
# Configuration
2326
port_ip = args.port_ip
2427
topk = args.topk
2528
use_llm = args.use_llm
2629

30+
# LLM model initialization
2731
llm_model = LLMChat(args.model)
2832
local_path = osp.abspath(osp.dirname(osp.dirname(osp.realpath(__file__))))
2933
extraction_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'extraction.txt'))
3034
behavior_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'behavior.txt'))
3135
geometry_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'geometry.txt'))
3236
spawn_prompt = load_file(osp.join(local_path, 'retrieve', 'prompts', 'spawn.txt'))
3337
scenario_descriptions = load_file(osp.join(local_path, 'retrieve', 'scenario_descriptions.txt')).split('\n')
34-
encoder = SentenceTransformer('sentence-transformers/sentence-t5-large', device='cuda')
38+
39+
# 🔥 修改开始:本地加载 sentence-t5-large 模型
40+
model_dir = r"D:\sceneMain\chatScene\models\sentence-t5-large"
41+
if not os.path.exists(model_dir):
42+
raise FileNotFoundError(f"本地模型路径不存在:{model_dir}")
43+
44+
required_files = ["config.json", "pytorch_model.bin"]
45+
for filename in required_files:
46+
if not os.path.exists(os.path.join(model_dir, filename)):
47+
raise FileNotFoundError(f"缺少必要的文件: {filename}{model_dir} 中")
48+
49+
word_embedding_model = models.Transformer(model_dir, max_seq_length=512)
50+
pooling_model = models.Pooling(
51+
word_embedding_model.get_word_embedding_dimension(),
52+
pooling_mode='mean'
53+
)
54+
encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cuda')
55+
print("✅ 成功加载本地 sentence-t5-large 模型!")
56+
# 🔥 修改结束
3557

3658
# Load the database
3759
with open(osp.join(local_path, 'retrieve/database_v1.pkl'), 'rb') as file:
@@ -48,7 +70,7 @@
4870
geometry_embeddings = encoder.encode(geometry_descriptions, device='cuda', convert_to_tensor=True)
4971
spawn_embeddings = encoder.encode(spawn_descriptions, device='cuda', convert_to_tensor=True)
5072

51-
## This is the head for scenic file, you can modify the carla map or ego model here
73+
# This is the head for scenic file, you can modify the carla map or ego model here
5274
head = '''param map = localPath(f'../maps/{Town}.xodr')
5375
param carla_map = Town
5476
model scenic.simulators.carla.model
@@ -57,10 +79,12 @@
5779

5880
log_file_path = osp.join(local_path, 'safebench', 'scenario', 'scenario_data', 'scenic_data', 'dynamic_scenario', 'dynamic_log.csv')
5981

82+
# Write log results
6083
with open(log_file_path, mode='w', newline='') as file:
6184
log_writer = csv.writer(file)
6285
log_writer.writerow(['Scenario', 'AdvObject', 'Behavior Description', 'Behavior Snippet', 'Geometry Description', 'Geometry Snippet', 'Spawn Description', 'Spawn Snippet', 'Success'])
6386

87+
# Process each scenario description
6488
for q, current_scenario in tqdm(enumerate(scenario_descriptions)):
6589
messages = [
6690
{"role": "system", "content": "You are a helpful assistant."},

0 commit comments

Comments
 (0)