jskvrna committed
Commit c22c8c5 · 1 Parent(s): 118c9a5

Adds inference script and utils

Adds a script (`script.py`) for running inference on the dataset; when running locally rather than on the test server, it first downloads the dataset from the Hugging Face Hub.
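
A minimal sketch of that fallback (the repo id below is an assumption inferred from the cache path in the script; the actual value is read from `params.json`):

    from pathlib import Path
    from huggingface_hub import snapshot_download

    data_path = Path('/tmp/data')  # directory the test server provides
    if not data_path.exists():     # running locally: fetch the dataset first
        snapshot_download(
            repo_id="usm3d/hoho25k_test_x",  # assumption; script.py uses params['dataset']
            local_dir=str(data_path),
            repo_type="dataset",
        )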

It also adds utility functions for reading COLMAP reconstructions and for returning a minimal empty solution when a prediction fails.
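
A hedged usage sketch of these utilities (the `safe_predict` helper and its `entry` argument are illustrative, not part of the commit; field names follow the rest of this diff):

    from utils import read_colmap_rec, empty_solution
    from predict import predict_wireframe

    def safe_predict(entry):
        rec = read_colmap_rec(entry['colmap_binary'])  # pycolmap.Reconstruction from zipped bytes
        print(f"{len(rec.points3D)} COLMAP points loaded")
        try:
            vertices, edges = predict_wireframe(entry)
        except Exception:
            vertices, edges = empty_solution()  # fall back to 2 vertices, 1 edge
        return vertices, edges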

Renames `test.py` to `train.py` and adds prediction with an error fallback, optional visualization, and score averaging over the dataset.
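
For orientation, the evaluation loop in `train.py` now roughly follows this pattern (a sketch, not the exact file; `ds` is assumed to be the dataset loaded earlier in the script, and only fields visible in the diff are used):

    import numpy as np
    from hoho2025.metric_helper import hss
    from predict import predict_wireframe
    from utils import empty_solution

    scores_f1 = []
    for a in ds['train']:  # assumption: ds is loaded earlier in train.py
        try:
            pred_vertices, pred_edges = predict_wireframe(a)
        except Exception:
            pred_vertices, pred_edges = empty_solution()
        score = hss(pred_vertices, pred_edges,
                    a['wf_vertices'], a['wf_edges'],
                    vert_thresh=0.5, edge_thresh=0.5)
        scores_f1.append(score.f1)
    print(f"Mean F1: {np.mean(scores_f1):.4f}")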

Adds `.gitignore` file to exclude unnecessary files.

.gitignore ADDED
@@ -0,0 +1,4 @@
+
+.vscode/launch.json
+
+__pycache__/
__pycache__/predict.cpython-310.pyc DELETED
Binary file (2.17 kB)
 
__pycache__/utils.cpython-310.pyc DELETED
Binary file (541 Bytes)
 
__pycache__/visu.cpython-310.pyc DELETED
Binary file (8.86 kB)
 
predict.py CHANGED
@@ -9,7 +9,7 @@ def convert_entry_to_human_readable(entry):
         if 'colmap' in k:
             out[k] = read_colmap_rec(v)
         elif k in ['wf_vertices', 'wf_edges', 'K', 'R', 't', 'depth']:
-            out[k] = np.array(v)
+            out[k] = v
         else:
             out[k]=v
     out['__key__'] = entry['order_id']
@@ -56,10 +56,9 @@ def predict_wireframe(entry) -> Tuple[np.ndarray, List[int]]:
 
     # Merge vertices from all images
     all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 0.5)
-
-
    all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d, keep_largest=False)
    all_3d_vertices_clean, connections_3d_clean = prune_too_far(all_3d_vertices_clean, connections_3d_clean, colmap_rec, th = 4.0)
+
 
    if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
        print (f'Not enough vertices or connections in the 3D vertices')
script.py ADDED
@@ -0,0 +1,101 @@
+from pathlib import Path
+from tqdm import tqdm
+import pandas as pd
+import numpy as np
+from datasets import load_dataset
+from typing import Dict
+from joblib import Parallel, delayed
+import os
+import json
+import gc
+
+from utils import empty_solution
+from predict import predict_wireframe
+
+if __name__ == "__main__":
+    print ("------------ Loading dataset------------ ")
+    param_path = Path('params.json')
+    print(param_path)
+    with param_path.open() as f:
+        params = json.load(f)
+    print(params)
+    import os
+
+    print('pwd:')
+    os.system('pwd')
+    print(os.system('ls -lahtr'))
+    print('/tmp/data/')
+    print(os.system('ls -lahtr /tmp/data/'))
+    print('/tmp/data/data')
+    print(os.system('ls -lahtrR /tmp/data/data'))
+
+
+    data_path_test_server = Path('/tmp/data')
+    data_path_local = Path().home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'
+
+    if data_path_test_server.exists():
+        # data_path = data_path_test_server
+        TEST_ENV = True
+    else:
+        # data_path = data_path_local
+        TEST_ENV = False
+        from huggingface_hub import snapshot_download
+        _ = snapshot_download(
+            repo_id=params['dataset'],
+            local_dir="/tmp/data",
+            repo_type="dataset",
+        )
+    data_path = data_path_test_server
+
+
+    print(data_path)
+
+    # dataset = load_dataset(params['dataset'], trust_remote_code=True, use_auth_token=params['token'])
+    # data_files = {
+    #     "validation": [str(p) for p in [*data_path.rglob('*validation*.arrow')]+[*data_path.rglob('*public*/**/*.tar')]],
+    #     "test": [str(p) for p in [*data_path.rglob('*test*.arrow')]+[*data_path.rglob('*private*/**/*.tar')]],
+    # }
+    data_files = {
+        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
+        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
+    }
+    print(data_files)
+    dataset = load_dataset(
+        str(data_path / 'hoho25k_test_x.py'),
+        data_files=data_files,
+        trust_remote_code=True,
+        writer_batch_size=100
+    )
+
+    print('load with webdataset')
+
+
+    print(dataset, flush=True)
+
+    print('------------ Now you can do your solution ---------------')
+    solution = []
+
+    def process_sample(sample, i):
+        try:
+            pred_vertices, pred_edges = predict_wireframe(sample)
+        except:
+            pred_vertices, pred_edges = empty_solution()
+        if i %10 == 0:
+            gc.collect()
+        return {
+            'order_id': sample['order_id'],
+            'wf_vertices': pred_vertices.tolist(),
+            'wf_edges': pred_edges
+        }
+    num_cores = 4
+
+    for subset_name in dataset.keys():
+        print (f"Predicting {subset_name}")
+        for i, sample in enumerate(tqdm(dataset[subset_name])):
+            res = process_sample(sample, i)
+            solution.append(res)
+
+    print('------------ Saving results ---------------')
+    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
+    sub.to_parquet("submission.parquet")
+    print("------------ Done ------------ ")
test.py → train.py RENAMED
@@ -7,7 +7,7 @@ import io
 import open3d as o3d
 
 from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color
-from utils import read_colmap_rec
+from utils import read_colmap_rec, empty_solution
 
 #from hoho2025.example_solutions import predict_wireframe
 from hoho2025.metric_helper import hss
@@ -19,18 +19,25 @@ scores_hss = []
 scores_f1 = []
 scores_iou = []
 
+show_visu = False
+
 idx = 0
 for a in ds['train']:
     colmap = read_colmap_rec(a['colmap_binary'])
-    pred_vertices, pred_edges = predict_wireframe(a)
 
-    pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
-    wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
-    wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
-    bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
+    try:
+        pred_vertices, pred_edges = predict_wireframe(a)
+    except:
+        pred_vertices, pred_edges = empty_solution()
+
+    if show_visu:
+        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
+        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
+        wireframe2 = plot_wireframe_local(None, pred_vertices, pred_edges, None, color='rgb(255, 0, 0)')
+        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
 
-    visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
-    o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
+        visu_all = [pcd] + geometries + wireframe + bpo_cams + wireframe2
+        o3d.visualization.draw_geometries(visu_all, window_name="3D Reconstruction")
 
     score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
     print(f"Score: {score}")
@@ -38,3 +45,9 @@ for a in ds['train']:
     scores_f1.append(score.f1)
     scores_iou.append(score.iou)
 
+for i in range(10):
+    print("END OF DATASET")
+    print(f"Mean HSS: {np.mean(scores_hss):.4f}")
+    print(f"Mean F1: {np.mean(scores_f1):.4f}")
+    print(f"Mean IoU: {np.mean(scores_iou):.4f}")
+
utils.py CHANGED
@@ -1,6 +1,8 @@
 import pycolmap
 import tempfile,zipfile
 import io
+import numpy as np
+from typing import Dict
 
 def read_colmap_rec(colmap_data):
     with tempfile.TemporaryDirectory() as tmpdir:
@@ -8,4 +10,22 @@ def read_colmap_rec(colmap_data):
             zf.extractall(tmpdir)  # unpacks cameras.txt, images.txt, etc. to tmpdir
         # Now parse with pycolmap
         rec = pycolmap.Reconstruction(tmpdir)
-        return rec
+        return rec
+
+def empty_solution():
+    '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
+    return np.zeros((2,3)), [(0, 1)]
+
+class Sample(Dict):
+    def pick_repr_data(self, x):
+        if hasattr(x, 'shape'):
+            return x.shape
+        if isinstance(x, (str, float, int)):
+            return x
+        if isinstance(x, list):
+            return [type(x[0])] if len(x) > 0 else []
+        return type(x)
+
+    def __repr__(self):
+        # return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
+        return str({k: self.pick_repr_data(v) for k,v in self.items()})
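
A hedged example of what the new `Sample` wrapper prints when debugging an entry (the dictionary contents below are made up for illustration):

    import numpy as np
    from utils import Sample

    entry = Sample({'K': np.zeros((3, 3)), 'order_id': 'abc123', 'wf_edges': [(0, 1)]})
    print(entry)  # {'K': (3, 3), 'order_id': 'abc123', 'wf_edges': [<class 'tuple'>]}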