Skip to content

Commit b990e53

Browse files
authored
更新architecture.py (#107)
* Update hutbthesis_main.tex * Add files via upload * 添加scripts目录 包含一些运行脚本 * retrieve文件夹 retrieve 文件夹在项目中用于存储与场景描述检索相关的数据和代码,帮助从数据库中获取并生成动态场景。 * 添加 undergraduate/demo 文件夹 * Add files via upload * Add files via upload * 添加docs文件夹 * Add files via upload * 添加safebench.egg-info文件夹 * 添加 safebench 文件夹 * 添加 undergraduate/tools 文件夹 * 添加 undergraduate/Scenic 文件夹 * 为 Python 包 safebench 创建一个可安装的包 为 Python 包 safebench 创建一个可安装的包,并指定了依赖项(gym 和 pygame)以及包的基本信息。 * 这是原作者的一个许可证 MIT 许可证是一个非常宽松的开源许可协议,允许他人自由地使用和修改代码,只要他们保留版权声明并且了解代码的“原样”提供,不承担任何责任。 * 包含用于构建和运行 Docker 容器的相关配置文件 docker 文件夹包含用于构建和运行 Docker 容器的相关配置文件(如 Dockerfile 和 run_docker.sh),目的是创建一个可重复的环境来运行项目或应用,通常用于依赖管理和环境配置。 * 用于存储和管理与自然语言描述生成场景相关的数据和代码。 用于存储和管理与自然语言描述生成场景相关的数据和代码。 * Update README.md * 输出文件outputs * 更新存放量化评估的文件夹 * Update README.md * Update README.md * Update README.md * Delete scene/undergraduate/demo/.idea directory * Delete scene/undergraduate/retrieve/__pycache__ directory * Delete scene/undergraduate/safebench.egg-info directory * Delete scene/undergraduate/outputs/screenshots directory * Delete scene/undergraduate/setup.py * Delete scene/undergraduate/safebench/__pycache__ directory * Delete scene/undergraduate/scripts/__pycache__ directory * Delete scene/undergraduate/safebench/agent/__pycache__ directory * Delete scene/undergraduate/Scenic directory * Delete scene/undergraduate/demo directory * Delete scene/undergraduate/docker directory * Delete scene/undergraduate/evaluation directory * Delete scene/undergraduate/docs directory * Delete scene/undergraduate/retrieve directory * Delete scene/undergraduate/safebench directory * Delete scene/undergraduate/scripts directory * Delete scene/undergraduate/tools directory * Delete scene/undergraduate/LICENSE * 增加代码文件夹 一些脚本 一些场景生成文件 一些量化评估文件 * Delete scene/retrieve/__pycache__ directory * Delete scene/retrieve/database_v1.pkl * Delete scene/scripts/__pycache__ directory * Update latexmkrc * Update hutbthesis_main.tex * 增加自动化编译 * Update CICID.yml 调整格式 * Update CICID.yml * Update hutbthesis_main.tex * 忽视一些文件 * 添加safebench * Delete scene/safebench directory * Create safebench * Delete scene/safebench * Create carla_runner.py * Add files via upload * 更新代码文件 * Create README.md * Update README.md * 增加util文件夹 包含项目中的辅助工具和常用函数 * Delete scene/safebench/util/__pycache__ directory * Delete scene/safebench/util directory * Update architecture.py
1 parent 3d4d264 commit b990e53

1 file changed

Lines changed: 41 additions & 15 deletions

File tree

scene/retrieve/architecture.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,60 @@
22
import openai
33
import torch
44
import transformers
5-
os.environ["OPENAI_API_KEY"] = 'sk-proj-xxx'
6-
7-
class LLMChat():
8-
def __init__(self, model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'):
5+
6+
# 设置 OpenAI API 密钥
7+
os.environ["OPENAI_API_KEY"] = 'sk-...s0kA' # 请替换为你的实际 API 密钥
8+
9+
10+
class LLMChat:
11+
def __init__(self, model_name='gpt-4', use_gpu=True):
12+
"""
13+
初始化聊天模型。
14+
- model_name: 使用的模型名称(例如 'gpt-4' 或 HuggingFace 模型名称)。
15+
- use_gpu: 是否使用 GPU(默认是 True)。
16+
"""
917
super(LLMChat, self).__init__()
1018
self.model_name = model_name
19+
self.use_gpu = use_gpu
20+
21+
# 如果是 GPT 模型,使用 OpenAI API
1122
if model_name.startswith('gpt'):
12-
self.client = openai.OpenAI()
23+
openai.api_key = os.environ["OPENAI_API_KEY"] # 设置 OpenAI API 密钥
1324
else:
25+
# 设置 HuggingFace 模型并选择 GPU 或 CPU
26+
self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
1427
self.pipeline = transformers.pipeline(
1528
"text-generation",
1629
model=model_name,
17-
model_kwargs={"torch_dtype": torch.bfloat16},
18-
device="cuda",
30+
model_kwargs={"torch_dtype": torch.float32}, # 使用 float32 避免兼容性问题
31+
device=0 if self.device == "cuda" else -1, # 如果使用 GPU,选择 GPU 设备
1932
)
2033

21-
def generate(self, messages, max_new_tokens = 500):
34+
def generate(self, messages, max_new_tokens=500):
35+
"""
36+
生成文本。
37+
- messages: 输入消息(适用于 GPT 模型和 HuggingFace 模型)。
38+
- max_new_tokens: 最大生成的 tokens 数量。
39+
"""
2240
if self.model_name.startswith('gpt'):
23-
response = self.client.chat.completions.create(
41+
# 对于 OpenAI GPT 模型,使用新的聊天 API
42+
response = openai.ChatCompletion.create(
2443
model=self.model_name,
25-
messages=messages,
26-
temperature=0,
44+
messages=messages, # 将消息数组直接传递给 chat 模型
45+
temperature=0.7, # 控制生成文本的多样性
46+
max_tokens=max_new_tokens,
2747
)
28-
return response.choices[0].message.content
48+
return response['choices'][0]['message']['content'].strip() # 返回生成的文本
49+
2950
else:
51+
# 对于 HuggingFace 模型,使用文本生成 pipeline
52+
if isinstance(messages, str):
53+
# 如果输入是单个字符串,构建消息格式
54+
messages = [{"role": "user", "content": messages}]
55+
3056
outputs = self.pipeline(
31-
messages,
57+
messages[0]['content'], # 传递用户消息的内容
3258
max_new_tokens=max_new_tokens,
33-
do_sample=False
59+
do_sample=True, # 启用采样,生成更多样的文本
3460
)
35-
return outputs[0]["generated_text"][-1]
61+
return outputs[0]["generated_text"].strip() # 返回生成的文本

0 commit comments

Comments
 (0)