-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
66 lines (55 loc) · 2.01 KB
/
main.py
File metadata and controls
66 lines (55 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
import pandas as pd
import numpy.fft as fft
import matplotlib.pyplot as plt
import pickle
import os
import math
import h5py
import random
import cv2
import gc
from tqdm import tqdm
from multiprocessing import Pool
import gc, glob, os
from concurrent.futures import ProcessPoolExecutor
import torch
import torch.nn as nn
from scipy.stats import norm
from timm import create_model
from gravit.normalize import normalize
from gravit.model import Model
from gravit.preprocess import preprocess
# Module-level model instance, constructed as soon as this module is imported.
# NOTE(review): the ``__main__`` block below rebinds ``model`` to the loaded
# checkpoint, so this instance is unused when run as a script. Workers started
# by ProcessPoolExecutor under the "spawn" start method re-import this module
# and build this Model again in every worker. Consider removing it if nothing
# imports ``model`` from here.
model = Model()
## Load the data
def dataload(filepath):
    """Read one HDF5 sample and place both detectors' SFTs on a shared slot grid.

    The file stem (basename without extension) is used as the top-level HDF5
    group name. For each detector ("H1", then "L1"), the ``timestamps_GPS``
    values are divided by 1800 and rounded to integer slot indices. Both
    detectors are shifted by the same offset so that the earliest slot across
    the pair is 0.

    Each detector's ``SFTs`` are read as complex128 and passed through
    ``normalize``. Their columns are then scattered into a float32 grid of
    shape (2, 360, 5760) that is pre-filled with NaN. Slots >= 5760 are dropped.
    NOTE(review): assumes ``normalize`` returns a 2-D array with 360 rows whose
    columns align with the timestamps; confirm against gravit.normalize.

    Returns:
        tuple: ``(fid, grid, h1_mean, l1_mean)``, where ``grid[0]`` is H1,
        ``grid[1]`` is L1, and each mean is ``.mean()`` of that detector's
        full ``normalize`` output.
    """
    detectors = ("H1", "L1")
    grid = np.full([2, 360, 5760], np.nan, dtype=np.float32)
    with h5py.File(filepath, "r") as handle:
        fid = os.path.splitext(os.path.basename(filepath))[0]
        group = handle[fid]

        # Read both timestamp arrays first: the shared origin needs both.
        slot_indices = [
            (np.asarray(group[det]["timestamps_GPS"]) / 1800).round().astype(np.int64)
            for det in detectors
        ]
        origin = min(idx.min() for idx in slot_indices)

        means = []
        for row, (det, idx) in enumerate(zip(detectors, slot_indices)):
            shifted = idx - origin
            sft = normalize(np.asarray(group[det]["SFTs"], np.complex128))
            in_range = shifted < 5760
            grid[row][:, shifted[in_range]] = sft[:, in_range]
            means.append(sft.mean())
        gc.collect()
    return (fid, grid, *means)
@torch.no_grad()
def inference(model, path, num_workers=2):
    """Score every ``*.hdf5`` file in *path* with *model*.

    Files are visited in sorted filename order. ``dataload`` runs in a process
    pool, while ``preprocess`` and the model run in this process.

    Args:
        model: a ``torch.nn.Module``. It is switched to eval mode here, so
            layers such as dropout and batch-norm behave as at inference time.
        path: directory containing the ``.hdf5`` files.
        num_workers: number of loader processes (default 2, matching the
            previously hard-coded value).

    Returns:
        tuple: ``(ids, scores)``. ``ids`` is the list of file stems.
        ``scores`` is a float32 NumPy array with one value per file: the
        softmax probability at index 1 of the last dim, averaged over dim 0
        of the batch built by ``preprocess`` (presumably TTA copies, per the
        variable name).

    Raises:
        FileNotFoundError: if *path* contains no ``.hdf5`` files. Previously
            this surfaced as an opaque ``torch.stack`` error on an empty list.
    """
    files = sorted(glob.glob(os.path.join(path, "*.hdf5")))
    if not files:
        raise FileNotFoundError(f"no .hdf5 files found in {path!r}")

    # no_grad alone does not disable dropout or batch-norm batch statistics.
    model.eval()

    ids, scores = [], []
    with ProcessPoolExecutor(num_workers) as pool:
        # Executor.map yields results in input order, so ids stay aligned
        # with scores.
        for fid, grid, h1_mean, l1_mean in pool.map(dataload, files):
            tta = preprocess(64, grid, h1_mean, l1_mean)
            ids.append(fid)
            scores.append(model(tta).softmax(-1)[..., 1].mean(0))
    return ids, torch.stack(scores, 0).cpu().float().numpy()
def get_model(checkpoint_path):
    """Build the inference model from *checkpoint_path* and put it in eval mode.

    The original script called ``get_model`` without defining or importing it,
    which raised a NameError. This helper supplies it.

    NOTE(review): the checkpoint format is not visible here. A pickled
    ``nn.Module`` is used as-is; anything else is treated as a ``state_dict``
    for ``gravit.model.Model``. Adjust this if the training script wraps the
    weights (e.g. ``{"model": state_dict}``). Only load trusted files:
    ``torch.load`` unpickles its input.

    NOTE(review): the model is moved to CUDA when available, on the assumption
    that ``preprocess`` places its tensors on the GPU (``inference`` calls
    ``.cpu()`` on the results). Confirm the devices match.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, nn.Module):
        net = checkpoint
    else:
        net = Model()
        net.load_state_dict(checkpoint)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return net.to(device).eval()


if __name__ == "__main__":
    model = get_model("model_best.pth")
    fid, infer = inference(
        model, "../input/g2net-detecting-continuous-gravitational-waves/test"
    )
    result = pd.DataFrame.from_dict({"id": fid, "target": infer})
    print(result)
    # Uncomment to write the submission file:
    # result.to_csv("submission.csv", index=False)