-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
50 lines (43 loc) · 1.56 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
def load_model(model_path, device):
    """Load a whole pickled PyTorch model and prepare it for inference.

    The file at ``model_path`` must contain a full module object (written with
    ``torch.save(model)``), not a ``state_dict``: the loaded object is used
    directly as a module (``.to`` / ``.eval``).

    Args:
        model_path: Path to the saved model file.
        device: ``torch.device`` the tensors are mapped onto while loading and
            the model is moved to afterwards.

    Returns:
        The loaded model on ``device``, switched to evaluation mode.
    """
    import inspect

    load_kwargs = {"map_location": device}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module. Ask for full unpickling explicitly on versions that
    # support the keyword; older versions (< 1.13) lack it and already unpickle.
    # SECURITY: full unpickling can execute arbitrary code stored in the file,
    # so only load model files from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.to(device)
    model.eval()
    return model
def preprocess_image(image_path, *, slice_idx=5, size=(128, 128)):
    """Load one slice of a NIfTI volume and turn it into a model-ready batch.

    Args:
        image_path: Path to a file readable by ``nibabel.load``.
        slice_idx: Index along the third axis of the slice to extract
            (default 5, the slice this script has always used).
        size: ``(height, width)`` the slice is resized to (default 128x128).

    Returns:
        A float32 tensor with a leading batch dimension of 1; for a 3-D
        volume the result has shape ``(1, 1, *size)``.

    Raises:
        ValueError: If the image has fewer than three dimensions.
        IndexError: If ``slice_idx`` is out of range for the third axis.
    """
    # Index the data object directly rather than loading the full array.
    volume = nib.load(image_path).dataobj
    shape = volume.shape
    if len(shape) < 3:
        raise ValueError(
            f"expected an image with at least 3 dimensions, got shape {shape}"
        )
    depth = shape[2]
    if not -depth <= slice_idx < depth:
        raise IndexError(
            f"slice_idx {slice_idx} is out of range for {depth} slices "
            f"in {image_path}"
        )
    # No colour conversion happens here; the slice keeps its original dtype.
    img = Image.fromarray(np.asarray(volume[:, :, slice_idx]))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # NOTE: ToTensor rescales to [0, 1] only for uint8 input; other dtypes keep
    # their raw intensities. Integer slices come back as an int32 tensor, which
    # a float model cannot consume, so cast to float32 (no-op for float data).
    img_tensor = transform(img).float().unsqueeze(0)  # add the batch dimension
    return img_tensor
def predict(model, img_tensor, device):
    """Run ``model`` on ``img_tensor`` and return the arg-max label map.

    Gradient tracking is disabled for the forward pass.

    Args:
        model: Callable (typically an ``nn.Module``) whose output is reduced
            with an arg-max over dim 1.
        img_tensor: Input batch; moved to ``device`` before the forward pass.
        device: Target ``torch.device``.

    Returns:
        ``numpy.ndarray`` of indices from the arg-max over dim 1, with a
        leading size-1 dimension squeezed away, copied back to the CPU.
    """
    with torch.no_grad():
        scores = model(img_tensor.to(device))
        labels = scores.argmax(dim=1)
    return labels.squeeze(0).cpu().numpy()
def save_prediction(pred, output_path):
    """Write a 2-D prediction array to ``output_path`` as a grayscale image.

    No ``vmin``/``vmax`` are passed to ``plt.imsave``, so the gray colormap
    spans the array's own min..max range; the saved pixel values are a
    visualization, not the raw values of ``pred``.
    """
    plt.imsave(fname=output_path, arr=pred, cmap="gray")
def main(image_path, model_path, output_path):
    """Run the whole pipeline: load model, preprocess, predict, save.

    Uses CUDA when ``torch.cuda.is_available()`` reports a GPU, otherwise the
    CPU, and prints the output location when done.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    net = load_model(model_path, device)
    prediction = predict(net, preprocess_image(image_path), device)
    save_prediction(prediction, output_path)
    print(f"Prediction saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the value this script
    # previously hard-coded, so running it without arguments is unchanged.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run inference on one slice of a NIfTI volume and save "
        "the predicted label map as an image."
    )
    parser.add_argument(
        "--image-path",
        default="dataset/testing/patient101/patient101_frame01.nii.gz",
        help="NIfTI volume to run inference on (default: %(default)s)",
    )
    parser.add_argument(
        "--model-path",
        default="models/model-unetpp-0.012335.pt",
        help="pickled model file passed to torch.load (default: %(default)s)",
    )
    parser.add_argument(
        "--output-path",
        default="pred.png",
        help="image file the prediction is written to (default: %(default)s)",
    )
    args = parser.parse_args()
    main(args.image_path, args.model_path, args.output_path)