Skip to content

Commit 7a84253

Browse files
authored
上传代码
1 parent 30e38b9 commit 7a84253

1 file changed

Lines changed: 373 additions & 0 deletions

File tree

sim/Target_detection_code.py

Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
import os
2+
import numpy as np
3+
import torch
4+
import torch.nn as nn
5+
import cv2
6+
import pygame
7+
from pathlib import Path
8+
from queue import Queue
9+
from torch.utils.data import Dataset, DataLoader
10+
import carla
11+
12+
# 创建YOLOv5工作目录并克隆仓库
13+
os.system('git clone https://github.com/ultralytics/yolov5')
14+
os.chdir('yolov5')
15+
16+
# 安装依赖
17+
os.system('pip install -r requirements.txt')
18+
19+
# 创建虚拟环境并安装核心依赖
20+
os.system('python -m venv carla_env')
21+
os.system('source carla_env/bin/activate')
22+
os.system('pip install carla pygame numpy matplotlib')
23+
os.system('pip install torch torchvision tensorboard')
24+
25+
# 启动Carla服务器
26+
def launch_carla_server():
27+
os.system('./CarlaUE4.sh Town01 -windowed -ResX=800 -ResY=600')
28+
29+
# 连接Carla客户端
30+
def connect_carla():
31+
client = carla.Client('localhost', 2000)
32+
client.set_timeout(10.0)
33+
world = client.get_world()
34+
return world
35+
36+
# 生成车辆
37+
def spawn_vehicle(world):
38+
blueprint = world.get_blueprint_library().find('vehicle.tesla.model3')
39+
spawn_point = world.get_map().get_spawn_points()[0]
40+
vehicle = world.spawn_actor(blueprint, spawn_point)
41+
return vehicle
42+
43+
# 加载YOLOv5模型
44+
def load_yolov5_model():
45+
FILE = Path(__file__).resolve()
46+
ROOT = FILE.parents[0] # YOLOv5 根目录
47+
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # 加载预训练模型
48+
return model
49+
50+
# 处理图像并进行目标检测
51+
def process_image(image, model):
52+
img_array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
53+
img_array = np.reshape(img_array, (image.height, image.width, 4))
54+
img_array = img_array[:, :, :3] # 仅保留RGB通道
55+
img = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
56+
57+
# 使用YOLOv5进行目标检测
58+
results = model(img)
59+
detections = results.pandas().xyxy[0] # 结果转换为pandas DataFrame
60+
61+
# 在 Carla 中绘制检测结果
62+
for _, detection in detections.iterrows():
63+
xmin, ymin, xmax, ymax = int(detection.xmin), int(detection.ymin), int(detection.xmax), int(detection.ymax)
64+
label = f"{detection.name} {detection.confidence:.2f}"
65+
color = (0, 255, 0) # 绿色边界框
66+
img_array = cv2.rectangle(img_array, (xmin, ymin), (xmax, ymax), color, 2)
67+
img_array = cv2.putText(img_array, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
68+
69+
# 显示检测结果
70+
cv2.imshow("YOLOv5 Detection", img_array)
71+
cv2.waitKey(1)
72+
73+
# 附加传感器
74+
def attach_sensors(world, vehicle):
75+
blueprint_library = world.get_blueprint_library()
76+
77+
# RGB相机配置
78+
cam_bp = blueprint_library.find('sensor.camera.rgb')
79+
cam_bp.set_attribute('image_size_x', '800')
80+
cam_bp.set_attribute('image_size_y', '600')
81+
cam_bp.set_attribute('fov', '110')
82+
83+
# 生成RGB相机
84+
cam = world.spawn_actor(cam_bp, carla.Transform(), attach_to=vehicle)
85+
86+
# 监听相机数据
87+
cam.listen(lambda data: process_image(data, model))
88+
89+
return cam
90+
91+
# 数据记录器
92+
class SensorDataRecorder:
93+
def __init__(self):
94+
self.image_queue = Queue(maxsize=100)
95+
self.control_queue = Queue(maxsize=100)
96+
self.sync_counter = 0
97+
98+
def record_image(self, image):
99+
self.image_queue.put(image)
100+
self.sync_counter += 1
101+
102+
def record_control(self, control):
103+
self.control_queue.put(control)
104+
105+
def save_episode(self, episode_id):
106+
images = []
107+
controls = []
108+
while not self.image_queue.empty():
109+
images.append(self.image_queue.get())
110+
while not self.control_queue.empty():
111+
controls.append(self.control_queue.get())
112+
113+
np.savez(f'expert_data/episode_{episode_id}.npz',
114+
images=np.array(images),
115+
controls=np.array(controls))
116+
117+
# 手动控制车辆
118+
def manual_control(vehicle):
119+
keys = pygame.key.get_pressed()
120+
control = carla.VehicleControl()
121+
control.throttle = 0.5 * keys[pygame.K_UP]
122+
control.brake = 1.0 * keys[pygame.K_DOWN]
123+
control.steer = 2.0 * (keys[pygame.K_RIGHT] - keys[pygame.K_LEFT])
124+
vehicle.apply_control(control)
125+
return control
126+
127+
# 图像增强
128+
def augment_image(image):
129+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
130+
hsv[:, :, 2] = np.clip(hsv[:, :, 2] * np.random.uniform(0.8, 1.2), 0, 255)
131+
M = cv2.getRotationMatrix2D((400, 300), np.random.uniform(-5, 5), 1)
132+
augmented = cv2.warpAffine(hsv, M, (800, 600))
133+
return cv2.cvtColor(augmented, cv2.COLOR_HSV2BGR)
134+
135+
# 自动驾驶模型
136+
class AutonomousDriver(nn.Module):
137+
def __init__(self):
138+
super().__init__()
139+
self.conv_layers = nn.Sequential(
140+
nn.Conv2d(3, 24, 5, stride=2),
141+
nn.ReLU(),
142+
nn.Conv2d(24, 32, 5, stride=2),
143+
nn.ReLU(),
144+
nn.Conv2d(32, 64, 3),
145+
nn.ReLU(),
146+
nn.Flatten()
147+
)
148+
149+
self.fc_layers = nn.Sequential(
150+
nn.Linear(64 * 94 * 70, 512),
151+
nn.ReLU(),
152+
nn.Linear(512, 256),
153+
nn.ReLU(),
154+
nn.Linear(256, 3) # throttle, brake, steer
155+
)
156+
157+
def forward(self, x):
158+
x = self.conv_layers(x)
159+
return self.fc_layers(x)
160+
161+
# 训练模型
162+
def train_model(model, dataloader, epochs=50):
163+
criterion = nn.MSELoss()
164+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
165+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
166+
model.to(device)
167+
168+
for epoch in range(epochs):
169+
total_loss = 0
170+
# 数据集类
171+
class DrivingDataset(Dataset):
172+
def __init__(self, data_dir, transform=None):
173+
self.files = glob.glob(f'{data_dir}/*.npz')
174+
self.transform = transform
175+
176+
def __len__(self):
177+
return len(self.files) * 100 # 假设每个episode有100帧
178+
179+
def __getitem__(self, idx):
180+
file_idx = idx // 100
181+
frame_idx = idx % 100
182+
data = np.load(self.files[file_idx])
183+
image = data['images'][frame_idx].transpose(2, 0, 1).astype(np.float32) / 255.0 # 转换为CHW格式
184+
control = data['controls'][frame_idx].astype(np.float32)
185+
186+
if self.transform:
187+
image = self.transform(image)
188+
189+
return torch.tensor(image), torch.tensor(control)
190+
191+
# 评估模型
192+
def evaluate_model(model, world, episodes=10):
193+
metrics = {
194+
'collision_rate': 0,
195+
'route_completion': 0,
196+
'traffic_violations': 0,
197+
'control_smoothness': 0
198+
}
199+
200+
control_filter = ControlFilter()
201+
202+
for _ in range(episodes):
203+
vehicle = spawn_vehicle(world)
204+
recorder = SensorDataRecorder()
205+
206+
# 附加传感器
207+
cam_bp = world.get_blueprint_library().find('sensor.camera.rgb')
208+
cam_bp.set_attribute('image_size_x', '800')
209+
cam_bp.set_attribute('image_size_y', '600')
210+
cam_bp.set_attribute('fov', '110')
211+
cam = world.spawn_actor(cam_bp, carla.Transform(), attach_to=vehicle)
212+
cam.listen(lambda data: recorder.record_image(data))
213+
214+
while True:
215+
# 获取控制输入(示例:手动控制)
216+
manual_control(vehicle)
217+
218+
# 检查碰撞
219+
if vehicle.get_transform().location.distance(vehicle.get_world().get_map().get_spawn_points()[0].location) < 5.0:
220+
metrics['route_completion'] += 1
221+
break
222+
223+
# 检查交通违规(示例)
224+
if random.random() < 0.01:
225+
metrics['traffic_violations'] += 1
226+
227+
# 控制平滑度
228+
control = vehicle.get_control()
229+
smoothed_control = control_filter.smooth(control)
230+
vehicle.apply_control(smoothed_control)
231+
232+
# 终止条件
233+
if metrics['collision_rate'] >= episodes:
234+
break
235+
236+
return calculate_safety_scores(metrics)
237+
238+
# 模型量化
239+
def quantize_model(model):
240+
model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
241+
torch.ao.quantization.prepare(model, inplace=True)
242+
torch.ao.quantization.convert(model, inplace=True)
243+
return model
244+
245+
# 控制信号平滑
246+
class ControlFilter:
247+
def __init__(self, alpha=0.8):
248+
self.prev_control = None
249+
self.alpha = alpha
250+
251+
def smooth(self, current_control):
252+
if self.prev_control is None:
253+
self.prev_control = current_control
254+
return current_control
255+
256+
smoothed = self.alpha * self.prev_control + (1 - self.alpha) * current_control
257+
self.prev_control = smoothed
258+
return smoothed
259+
260+
# 导出模型
261+
def export_model(model, output_path):
262+
model.eval()
263+
traced_model = torch.jit.trace(model, torch.randn(1, 3, 600, 800))
264+
traced_model.save(output_path)
265+
266+
# 加载模型
267+
def load_deployed_model(model_path):
268+
model = AutonomousDriver()
269+
model.load_state_dict(torch.load(model_path))
270+
return model
271+
272+
# 自动驾驶主循环
273+
def autonomous_driving_loop(world):
274+
model = load_deployed_model('deployed_model.pt')
275+
vehicle = spawn_vehicle(world)
276+
control_filter = ControlFilter()
277+
278+
while True:
279+
# 获取相机图像
280+
image = get_camera_image(world, vehicle)
281+
preprocessed = preprocess_image(image)
282+
283+
# 模型推理
284+
with torch.no_grad():
285+
control = model(preprocessed)
286+
287+
# 控制信号平滑
288+
smoothed_control = control_filter.smooth(control)
289+
290+
# 执行控制
291+
vehicle.apply_control(smoothed_control)
292+
293+
# 安全监控
294+
if detect_critical_situation(vehicle):
295+
trigger_emergency_stop(vehicle)
296+
297+
# 辅助函数:获取相机图像
298+
def get_camera_image(world, vehicle):
299+
# 这里需要根据你的传感器设置获取最新的相机图像
300+
# 示例代码:
301+
# for actor in world.get_actors():
302+
# if actor.type_id == 'sensor.camera.rgb' and actor.parent.id == vehicle.id:
303+
# return actor.queue.get()
304+
pass
305+
306+
# 辅助函数:预处理图像
307+
def preprocess_image(image):
308+
# 图像预处理逻辑
309+
img_array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
310+
img_array = np.reshape(img_array, (image.height, image.width, 4))
311+
img_array = img_array[:, :, :3] # 仅保留RGB通道
312+
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
313+
img_array = cv2.resize(img_array, (800, 600))
314+
return torch.from_numpy(img_array.transpose(2, 0, 1).astype(np.float32) / 255.0).unsqueeze(0)
315+
316+
# 辅助函数:检测紧急情况
317+
def detect_critical_situation(vehicle):
318+
# 紧急情况检测逻辑
319+
# 示例:检测车辆是否静止超过一定时间
320+
velocity = vehicle.get_velocity()
321+
if velocity.x == 0 and velocity.y == 0 and velocity.z == 0:
322+
return True
323+
return False
324+
325+
# 辅助函数:触发紧急停止
326+
def trigger_emergency_stop(vehicle):
327+
# 紧急停止逻辑
328+
control = carla.VehicleControl()
329+
control.brake = 1.0
330+
control.hand_brake = True
331+
vehicle.apply_control(control)
332+
333+
# 辅助函数:计算安全分数
334+
def calculate_safety_scores(metrics):
335+
# 安全分数计算逻辑
336+
total_score = 0
337+
total_score -= metrics['collision_rate'] * 10
338+
total_score += metrics['route_completion'] * 5
339+
total_score -= metrics['traffic_violations'] * 3
340+
total_score += metrics['control_smoothness'] * 2
341+
return total_score
342+
343+
if __name__ == '__main__':
344+
# 启动Carla服务器
345+
server_thread = Thread(target=launch_carla_server)
346+
server_thread.start()
347+
time.sleep(5) # 等待服务器启动
348+
349+
# 连接Carla客户端
350+
world = connect_carla()
351+
352+
# 加载YOLOv5模型
353+
model = load_yolov5_model()
354+
355+
# 训练模型
356+
dataset = DrivingDataset('expert_data')
357+
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
358+
autodrive_model = AutonomousDriver()
359+
train_model(autodrive_model, dataloader, epochs=50)
360+
361+
# 导出模型
362+
export_model(autodrive_model, 'deployed_model.pt')
363+
364+
# 量化模型
365+
quantized_model = quantize_model(autodrive_model)
366+
export_model(quantized_model, 'quantized_model.pt')
367+
368+
# 启动自动驾驶
369+
autonomous_driving_loop(world)
370+
371+
# 评估模型
372+
metrics = evaluate_model(autodrive_model, world)
373+
print(f'Model Safety Score: {metrics}')

0 commit comments

Comments
 (0)