stack-3d/src/stack3d/formats/unet.py

136 lines
No EOL
4.4 KiB
Python

""" Format conversion for Russ Bate's unet format.
"""
from __future__ import print_function
import inspect
import os
import csv
from subprocess import call
from argparse import ArgumentParser
import scipy.io
import networkx as nx
import stack3d.formats
def load_skeleton(path):
    """
    Load a skeleton (stored as a pickle, per the original notes) from *path*.

    A ``VesselTree`` from ``unet_core.vessel_analysis`` is created, asked to
    load the file, and its ``skeleton`` attribute is returned.

    NOTE(review): if ``VesselTree.load_skeleton`` unpickles the file, only
    load trusted files -- unpickling can execute arbitrary code.
    """
    # Imported here rather than at module level so the script can be run
    # with both Python 2 and 3.
    from unet_core.vessel_analysis import VesselTree
    tree = VesselTree()
    tree.load_skeleton(path)
    loaded = tree.skeleton
    return loaded
def skeleton_to_vtp(path):
    """
    Convert a skeleton file to VTK polydata made of line segments.

    The skeleton is exported to CSV by running this module as a separate
    ``python3`` process (``--format csv --field diameter``). The CSV is then
    read back, and each data row ``x0,y0,z0,x1,y1,z1,diameter`` becomes one
    two-point ``vtkLine`` whose cell scalar is the diameter. The result is
    passed through ``vtkCleanPolyData`` before being returned.

    :param path: skeleton file accepted by ``load_skeleton``.
    :returns: the cleaned ``vtkPolyData``.
    :raises subprocess.CalledProcessError: if the CSV export process fails.
    """
    import subprocess
    import sys
    import tempfile
    # Avoid vtk import when method not used to reduce dependency overhead.
    # It is imported before the export so a missing vtk fails immediately.
    import vtk

    # The export re-runs this very module as a script; map a compiled .pyc
    # (Python 2) back to its source file.
    script_path = os.path.splitext(os.path.realpath(__file__))[0] + ".py"

    # Unique temp file in the input's directory (where the fixed-name temp
    # file used to go), so concurrent conversions cannot overwrite each other.
    fd, temp_output_path = tempfile.mkstemp(
        prefix="temp_network_csv_", suffix=".csv",
        dir=os.path.dirname(os.path.realpath(path)))
    os.close(fd)

    polydata = vtk.vtkPolyData()
    points = vtk.vtkPoints()
    cells = vtk.vtkCellArray()
    data = vtk.vtkDoubleArray()
    try:
        # Argument list without a shell: paths containing spaces or quotes
        # are passed verbatim and cannot be interpreted as shell syntax.
        subprocess.check_call([
            "python3", script_path,
            "--input", path,
            "--output", temp_output_path,
            "--format", "csv",
            "--field", "diameter",
        ])
        # The csv module needs binary mode on Python 2, but text mode with
        # newline="" on Python 3 (binary mode makes csv.reader raise there).
        if sys.version_info[0] >= 3:
            csvfile = open(temp_output_path, "r", newline="")
        else:
            csvfile = open(temp_output_path, "rb")
        with csvfile:
            reader = csv.reader(csvfile, delimiter=',', quotechar='|')
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue
                id0 = points.InsertNextPoint([float(x) for x in row[0:3]])
                id1 = points.InsertNextPoint([float(x) for x in row[3:6]])
                line = vtk.vtkLine()
                line.GetPointIds().SetId(0, id0)
                line.GetPointIds().SetId(1, id1)
                cells.InsertNextCell(line)
                data.InsertNextTuple1(float(row[6]))
    finally:
        os.remove(temp_output_path)

    polydata.SetPoints(points)
    polydata.SetLines(cells)
    polydata.GetCellData().SetScalars(data)
    clean = vtk.vtkCleanPolyData()
    clean.SetInputData(polydata)
    clean.Update()
    return clean.GetOutput()
def skeleton_to_matlab(skeleton, field, output_path):
    """
    Convert the skeleton's network description to matlab sparse matrix format.

    Each connected component is written to its own file,
    ``<output_path>/skeleton_<field>_Comp_<idx>.mat``. The file holds the
    component's weighted adjacency matrix (``nx.to_scipy_sparse_matrix``)
    under the key ``"skeleton"``.

    Edge weights are the ``diameter`` of the ``"branch"`` stored on each edge.
    They are written into the existing edge data as ``"weight"``, so the
    component graphs are modified in place.

    :param skeleton: object whose ``components`` each expose a networkx
        multigraph ``graph`` with a ``"branch"`` attribute on its edges.
    :param field: name embedded in the output file names.
    :param output_path: directory that receives the ``.mat`` files (it is
        not created here).
    """
    for idx, component in enumerate(skeleton.components):
        graph = component.graph
        # Edge data is keyed per parallel edge (the graph is a multigraph).
        # Calling add_edge here would insert an extra parallel edge for every
        # direction visited, and to_scipy_sparse_matrix would sum those with
        # the unweighted original edge (counted as 1). Instead, update the
        # existing edges in place.
        # NOTE(review): ``field`` only changes the file name; the weight is
        # always the diameter, even for "length"/"tortuosity". Confirm the
        # branch attribute names before deriving the weight from ``field``.
        for _u, _v, _key, edge_data in graph.edges(keys=True, data=True):
            edge_data["weight"] = edge_data["branch"].diameter
        output_file = os.path.join(
            output_path, "skeleton_" + field + "_Comp_" + str(idx) + ".mat")
        scipy.io.savemat(output_file,
                         {"skeleton": nx.to_scipy_sparse_matrix(graph)})
def skeleton_to_csv(skeleton, output_path):
    """
    Convert the skeleton's network description to csv format.

    Writes a header row followed by one row per straight segment. For every
    branch yielded by ``skeleton.skeleton_branch_iter()``, each pair of
    consecutive points becomes ``x0,y0,z0,x1,y1,z1,diameter``, using the
    branch's diameter. Branches with fewer than two points contribute no rows.

    :param skeleton: object providing ``skeleton_branch_iter()``, which yields
        ``(_, _, branch)`` triples; each branch has ``points`` (objects with
        ``x``, ``y``, ``z``) and ``diameter``.
    :param output_path: path of the csv file to create (overwritten).
    """
    with open(output_path, "w") as f:
        # The header must end with a newline; without it the first segment
        # is glued onto the header line and lost by readers that skip it.
        f.write("P0 - x, P0 - y, P0 - z, P1 - x, P1 - y, P1 - z, diameter\n")
        for _, _, branch in skeleton.skeleton_branch_iter():
            diameter = str(branch.diameter)
            prev = None
            for idx, point in enumerate(branch.points):
                if idx > 0:
                    f.write(",".join([
                        str(prev.x), str(prev.y), str(prev.z),
                        str(point.x), str(point.y), str(point.z),
                        diameter,
                    ]) + "\n")
                prev = point
if __name__ == "__main__":
    # Command-line entry point: load a skeleton and export it as csv (one
    # file) or as one .mat file per connected component.
    parser = ArgumentParser()
    # Both paths are required; previously omitting them failed later with an
    # unrelated error inside the loaders/writers.
    parser.add_argument('--input', type=str, required=True,
                        help='Input skeleton file.')
    parser.add_argument('--output', type=str, required=True,
                        help='Output directory (mat) or output file path (csv).')
    parser.add_argument('--format', type=str,
                        help='Output format.',
                        choices=['csv', 'mat'],
                        default='mat')
    parser.add_argument('--field', type=str,
                        help='Output field.',
                        choices=['diameter', "length", "tortuosity"],
                        default='diameter')
    args = parser.parse_args()
    skeleton = load_skeleton(args.input)
    if args.format == "csv":
        skeleton_to_csv(skeleton, args.output)
    else:
        # argparse `choices` guarantees the only other value is "mat".
        skeleton_to_matlab(skeleton, args.field, args.output)