mcosarinsky commited on
Commit
8d63b09
·
1 Parent(s): 4e79318
README.md CHANGED
@@ -1,12 +1,15 @@
1
  ---
2
- title: CheXmask U
3
- emoji: 🏆
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.0.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
+ title: Chest x-ray HybridGNet Segmentation
3
+ emoji:
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
+ license: gpl-3.0
11
  ---
12
 
13
+ Demo of the HybridGNet model with two image-to-graph skip connections from: https://arxiv.org/abs/2203.10977
+ Original HybridGNet model: https://arxiv.org/abs/2106.09832
+ The training procedure was taken from: https://arxiv.org/abs/2211.07395
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio demo for CheXmask-U: landmark-based chest X-ray segmentation
with per-landmark uncertainty estimation."""
import gradio as gr

from utils.segmentation import segment

# ------------------------- GRADIO -------------------------
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("""
# CheXmask-U: Uncertainty in Landmark-based Anatomical Segmentation

Demo of the **uncertainty estimation framework** proposed in the paper "CheXmask-U: Quantifying uncertainty in landmark-based anatomical segmentation for X-ray images".
The demonstration performs landmark-based segmentation (lungs and heart) and quantifies the uncertainty in the predicted position of each anatomical landmark.

### 📝 Instructions

1. **Upload** a chest X-ray image (PA or AP view) in PNG or JPEG format, or select an example image.
2. **Explore Prediction Variability**: You can either use your mouse to draw on the input image and perform inpainting in different regions or adjust the Gaussian Noise Std Dev slider to simulate image corruption.
3. Click on **"Segment Image"**.
    * The output image will display the segmentation overlay where the color gradient indicates the node-wise predictive uncertainty.
    * The Results file output will contain the coordinates and per-node uncertainty estimates.

Note: Pre-processing is not needed, it will be done automatically and removed after the segmentation.
""")

        with gr.Tab("Segment Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Grayscale input; the sketch tool lets users mask regions
                    # so the backend can explore inpainting-based variability.
                    image_input = gr.Image(
                        type="numpy",
                        tool="sketch",
                        image_mode="L",
                        height=450,
                    )

                    # Std dev of additive Gaussian noise used to simulate
                    # image corruption (passed through to `segment`).
                    noise_slider = gr.Slider(
                        label="Gaussian Noise Std Dev",
                        minimum=0.0,
                        maximum=0.25,
                        step=0.01,
                        value=0.0
                    )

                    with gr.Row():
                        clear_button = gr.Button("Clear")
                        image_button = gr.Button("Segment Image")

                    gr.Examples(inputs=image_input, examples=[
                        'utils/example1.jpg', 'utils/example2.jpg',
                        'utils/example3.png', 'utils/example4.jpg'
                    ])

                with gr.Column(scale=2):
                    # Rendered overlay with the uncertainty colormap, plus a
                    # downloadable file of coordinates / per-node estimates.
                    image_output = gr.Image(type="filepath", height=450)
                    results = gr.File()

        # FIX: "Universial" -> "Universal" and missing colon after "Source"
        # in item 3 (user-visible license attributions).
        gr.Markdown("""
Example images extracted from Wikipedia, released under:
1. CC0 Universal Public Domain. Source: https://commons.wikimedia.org/wiki/File:Normal_posteroanterior_(PA)_chest_radiograph_(X-ray).jpg
2. Creative Commons Attribution-Share Alike 4.0 International. Source: https://commons.wikimedia.org/wiki/File:Chest_X-ray.jpg
3. Creative Commons Attribution 3.0 Unported. Source: https://commons.wikimedia.org/wiki/File:Implantable_cardioverter_defibrillator_chest_X-ray.jpg
4. Creative Commons Attribution-Share Alike 3.0 Unported. Source: https://commons.wikimedia.org/wiki/File:Medical_X-Ray_imaging_PRD06_nevit.jpg
""")

        # Clearing resets both the input canvas and the previous output.
        clear_button.click(lambda: None, None, image_input, queue=False)
        clear_button.click(lambda: None, None, image_output, queue=False)
        image_button.click(
            segment,
            inputs=[image_input, noise_slider],
            outputs=[image_output, results],
            queue=False
        )
    demo.launch(share=True)
models/HybridGNet2IGSC.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.modelUtils import ChebConv, Pool, residualBlock
5
+ import torchvision.ops.roi_align as roi_align
6
+ import numpy as np
7
+
8
class EncoderConv(nn.Module):
    """Convolutional VAE encoder.

    Produces the latent mean and log-variance along with the two deepest
    feature maps (conv6, conv5), which the decoder later samples through
    its image-to-graph skip connections.
    """

    def __init__(self, latents = 64, hw = 32):
        super(EncoderConv, self).__init__()

        self.latents = latents
        self.c = 4  # base channel multiplier

        # Channel widths per stage: c * [2, 4, 8, 16, 32]
        self.size = self.c * np.array([2, 4, 8, 16, 32], dtype=np.intc)

        self.maxpool = nn.MaxPool2d(2)

        # Six residual stages; stages 1-5 are each followed by 2x pooling.
        # setattr keeps the original attribute names (dconv_down1..6) so
        # pretrained state dicts still load.
        stages = [(1, self.size[0]),
                  (self.size[0], self.size[1]),
                  (self.size[1], self.size[2]),
                  (self.size[2], self.size[3]),
                  (self.size[3], self.size[4]),
                  (self.size[4], self.size[4])]
        for idx, (cin, cout) in enumerate(stages, start=1):
            setattr(self, "dconv_down%d" % idx, residualBlock(cin, cout))

        flat_features = self.size[4] * hw * hw
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=self.latents)
        self.fc_logvar = nn.Linear(in_features=flat_features, out_features=self.latents)

    def forward(self, x):
        """Encode `x`; returns (mu, logvar, conv6, conv5)."""
        pool = self.maxpool

        x = pool(self.dconv_down1(x))
        x = pool(self.dconv_down2(x))
        x = pool(self.dconv_down3(x))
        conv4 = self.dconv_down4(x)
        conv5 = self.dconv_down5(pool(conv4))
        conv6 = self.dconv_down6(pool(conv5))

        # Flatten the deepest map into one feature vector per sample.
        flat = conv6.view(conv6.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat), conv6, conv5
53
+
54
+
55
class SkipBlock(nn.Module):
    """Image-to-graph skip connection (IGSC).

    Predicts an intermediate 2-D position for every graph node, samples the
    given convolutional feature map around those positions with RoI-align,
    and concatenates node features, sampled features and positions.
    """

    def __init__(self, in_filters, window):
        super(SkipBlock, self).__init__()

        self.window = window
        # Order-1 graph conv (no bias) mapping node features -> (x, y).
        self.graphConv_pre = ChebConv(in_filters, 2, 1, bias = False)

    def lookup(self, pos, layer, salida = (1,1)):
        """RoI-align `layer` around per-node positions `pos` (in [0, 1])."""
        batch = pos.shape[0]
        nodes = pos.shape[1]
        side = layer.shape[-1]

        # Scale normalized coordinates to feature-map pixels.
        pos = pos * side

        half_x = (self.window[0] // 2) * 1.0
        half_y = (self.window[1] // 2) * 1.0

        # One (nodes x 4) box tensor [x1, y1, x2, y2] per batch element.
        boxes = [
            torch.cat([
                pos[b, :, 0].reshape(-1, 1) - half_x,
                pos[b, :, 1].reshape(-1, 1) - half_y,
                pos[b, :, 0].reshape(-1, 1) + (half_x + 1.0),
                pos[b, :, 1].reshape(-1, 1) + (half_y + 1.0),
            ], axis = 1)
            for b in range(0, batch)
        ]

        sampled = roi_align(layer, boxes, output_size = salida, aligned=True)
        return sampled.view([batch, nodes, -1])

    def forward(self, x, adj, conv_layer):
        """Returns (concat(x, sampled features, positions), positions)."""
        pos = self.graphConv_pre(x, adj)
        sampled = self.lookup(pos, conv_layer)

        return torch.cat((x, sampled, pos), axis = 2), pos
96
+
97
+
98
class Hybrid(nn.Module):
    """HybridGNet with two image-to-graph skip connections (IGSC).

    A convolutional VAE encoder (EncoderConv) feeds a Chebyshev graph-conv
    decoder that regresses 2-D landmark coordinates on a fixed graph;
    two SkipBlocks sample encoder feature maps at intermediate estimates.
    """

    def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
        super(Hybrid, self).__init__()

        self.config = config
        # Encoder applies 5 max-poolings -> deepest map side = inputsize / 32.
        hw = config['inputsize'] // 32
        self.z = config['latents']
        self.encoder = EncoderConv(latents = self.z, hw = hw)
        # If True, the latent is sampled even in eval mode (see forward()).
        self.eval_sampling = config['eval_sampling']

        self.downsample_matrices = downsample_matrices
        self.upsample_matrices = upsample_matrices
        self.adjacency_matrices = adjacency_matrices
        # NOTE(review): kld_weight is not read anywhere in this file;
        # presumably consumed by an external training loss — confirm.
        self.kld_weight = 1e-5

        n_nodes = config['n_nodes']
        self.filters = config['filters']
        self.K = 6  # Chebyshev polynomial order for the graph convolutions
        self.window = (3,3)  # RoI-align window used by the skip blocks

        # Generate the fully connected layer for the decoder
        outshape = self.filters[-1] * n_nodes[-1]
        self.dec_lin = torch.nn.Linear(self.z, outshape)

        # Instance norms applied after each upsampling graph-conv stage.
        self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
        self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
        self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
        self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
        self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])

        # Feature width contributed by each skip block (deepest encoder width).
        outsize1 = self.encoder.size[4]
        outsize2 = self.encoder.size[4]

        # Store graph convolution layers
        self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
        self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)

        self.SC_1 = SkipBlock(self.filters[4], self.window)

        # "+ 2" accounts for the (x, y) position concatenated by the skip block.
        self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K)
        self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)

        self.SC_2 = SkipBlock(self.filters[2], self.window)

        self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K)
        # Final order-1 conv, no bias, producing the output coordinates.
        self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False)

        self.pool = Pool()

        self.reset_parameters()

    def reset_parameters(self):
        # N(0, 0.1) init, matching the custom ChebConv initialization.
        torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)

    def sampling(self, mu, log_var):
        # Reparameterization trick: z = mu + std * eps, eps ~ N(0, I).
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def encode(self, x):
        """Encode the input and return latent representations and skip connections"""
        mu, log_var, conv6, conv5 = self.encoder(x)
        return mu, log_var, conv6, conv5

    def decode(self, z, conv6, conv5):
        """Decode from latent space using skip connections"""
        x = self.dec_lin(z)
        x = F.relu(x)

        # Reshape to (batch, nodes, filters) for the graph convolutions.
        x = x.reshape(x.shape[0], -1, self.filters[-1])

        x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
        x = self.normalization6u(x)
        x = F.relu(x)

        x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
        x = self.normalization5u(x)
        x = F.relu(x)

        # First IGSC: samples conv6 at intermediate positions pos1.
        x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6)

        x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
        x = self.normalization4u(x)
        x = F.relu(x)

        # Unpool to the finer graph resolution.
        x = self.pool(x, self.upsample_matrices[0])

        x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
        x = self.normalization3u(x)
        x = F.relu(x)

        # Second IGSC: samples conv5 at intermediate positions pos2.
        x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5)

        x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
        x = self.normalization2u(x)
        x = F.relu(x)

        x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices()) # No relu and no bias

        return x, pos1, pos2

    def forward(self, x):
        """Full forward pass (both encoding and decoding)"""
        # mu / log_var are stored on self so an external loss can read them.
        self.mu, self.log_var, conv6, conv5 = self.encode(x)

        # Sample when training (or when eval_sampling requests it);
        # otherwise use the deterministic posterior mean.
        if self.training or self.eval_sampling:
            z = self.sampling(self.mu, self.log_var)
        else:
            z = self.mu

        return self.decode(z, conv6, conv5)
models/__pycache__/HybridGNet2IGSC.cpython-310.pyc ADDED
Binary file (6.26 kB). View file
 
models/__pycache__/HybridGNet2IGSC.cpython-39.pyc ADDED
Binary file (6.23 kB). View file
 
models/__pycache__/modelUtils.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
models/__pycache__/modelUtils.cpython-39.pyc ADDED
Binary file (2.68 kB). View file
 
models/modelUtils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch_geometric.nn.conv import MessagePassing
2
+ from torch_geometric.nn.conv.cheb_conv import ChebConv
3
+ from torch_geometric.nn.inits import zeros, normal
4
+
5
+ # We change the default initialization from zeros to a normal distribution
6
# We change the default initialization from zeros to a normal distribution.
# NOTE: this class deliberately shadows the imported torch_geometric
# ChebConv, so all ChebConv instances built from this module use it.
class ChebConv(ChebConv):
    def reset_parameters(self):
        # Each per-order linear layer: N(0, 0.1) instead of the default
        # (the default would be lin.reset_parameters()).
        for lin in self.lins:
            normal(lin, mean = 0, std = 0.1)
        # Bias: N(0, 0.1) instead of the default zeros(self.bias).
        # NOTE(review): callers construct some ChebConvs with bias=False;
        # assumes torch_geometric's `normal` tolerates a None bias — confirm.
        normal(self.bias, mean = 0, std = 0.1)
13
+
14
+ # Pooling from COMA: https://github.com/pixelite1201/pytorch_coma/blob/master/layers.py
15
# Pooling from COMA: https://github.com/pixelite1201/pytorch_coma/blob/master/layers.py
class Pool(MessagePassing):
    """Graph (un)pooling implemented as a sparse matrix product via message passing."""

    def __init__(self):
        # source_to_target is the default value for flow, but is specified here for explicitness
        super(Pool, self).__init__(flow='source_to_target')

    def forward(self, x, pool_mat, dtype=None):
        # `dtype` is unused; kept for call-site compatibility.
        # Transpose the sparse pooling matrix so its indices map
        # source nodes -> target nodes for propagate().
        pool_mat = pool_mat.transpose(0, 1)
        out = self.propagate(edge_index=pool_mat._indices(), x=x, norm=pool_mat._values(), size=pool_mat.size())
        return out

    def message(self, x_j, norm):
        # Weight each incoming node feature by the matrix entry.
        return norm.view(1, -1, 1) * x_j
27
+
28
+
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
class residualBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> Conv, twice).

    When the channel count or stride changes, a 1x1 conv + BN projects the
    identity branch; otherwise the input is added back unchanged.

    NOTE(review): `stride` only affects the projection shortcut — the main
    path always uses stride 1, so stride != 1 would yield mismatched
    shapes at the addition. Callers in this repo use the default.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Controls the stride.
        """
        super(residualBlock, self).__init__()

        # Projection shortcut only when shapes differ; None means identity.
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels, track_running_stats=False))
        else:
            self.skip = None

        self.block = nn.Sequential(nn.BatchNorm2d(in_channels, track_running_stats=False),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels, out_channels, 3, padding=1),
                                   nn.BatchNorm2d(out_channels, track_running_stats=False),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(out_channels, out_channels, 3, padding=1)
                                   )

    def forward(self, x):
        """Return relu(block(x) + shortcut(x))."""
        residual = x if self.skip is None else self.skip(x)
        return F.relu(self.block(x) + residual)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ numpy==1.25.0
3
+ opencv-python==4.8.0.74
4
+ scipy==1.10.1
5
+ torch_geometric==2.3.0
6
+ torchvision==0.15.2
tmp/.gitkeep ADDED
File without changes
weights/weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c5658bf5ff990fbcccde968a198bb48b57db22ffeaca32a837d6031a10395e2
3
+ size 70083499