Machine Learning with CAD Data¶
Download the ABC dataset from https://deep-geometry.github.io/abc-dataset
Import¶
import meshplot as mp
import numpy as np
import igl
import yaml
from yaml import CLoader as Loader
Reading CAD Data¶
def read_model(obj_path, feat_path):
    """Load an ABC model: the OBJ mesh plus its YAML feature annotations.

    Args:
        obj_path: path to the ``.obj`` mesh file.
        feat_path: path to the matching ``.yml`` feature file.

    Returns:
        dict with keys "vertices", "face_indices", "normals" and
        "normal_indices" (taken from ``igl.read_obj``) and "features"
        (the parsed YAML document).
    """
    # read_obj yields six arrays; presumably (V, TC, N, F, FTC, FN) -- see the
    # libigl python docs. Texture coordinates and their indices are unused.
    v, _, n, f, _, ni = igl.read_obj(obj_path)
    # YAML is UTF-8 by spec; do not depend on the platform's locale encoding.
    # NOTE(review): `Loader` is yaml's CLoader, which is not a safe loader
    # (it can construct arbitrary Python objects). Only use it on trusted
    # files; consider yaml.CSafeLoader for downloaded datasets.
    with open(feat_path, encoding="utf-8") as fi:
        feat = yaml.load(fi, Loader=Loader)
    return {"vertices": v, "face_indices": f, "normals": n,
            "normal_indices": ni, "features": feat}
# Load the bundled test model and unpack the pieces used below.
m = read_model("data/test_trimesh.obj", "data/test_features.yml")
v = m["vertices"]
f = m["face_indices"]
feat = m["features"]
# Array shapes of the mesh (presumably (#V, 3) and (#F, 3) -- TODO confirm).
print(v.shape, f.shape)
# Top-level groups of the feature file (used below: "curves", "surfaces").
print(list(feat))
CAD Features: Surface Normals¶
from data.utils import get_averaged_normals
# A vertex can carry several normals in the OBJ file; average them so every
# vertex has exactly one normal.
av_normals = get_averaged_normals(m)
# Colour each vertex by the absolute components of its normal.
p = mp.plot(v, f, c=np.abs(av_normals))
# Draw every normal as a black segment starting at its vertex.
normal_tips = m["vertices"] + av_normals
p.add_lines(m["vertices"], normal_tips, shading={"line_color": "black"})
# Alternative: libigl's per-vertex normals (uniform weighting), drawn in red.
# normals = igl.per_vertex_normals(v, f)
# p.add_lines(m["vertices"], m["vertices"] + normals,
#             shading={"line_color": "red"})
CAD Features: Sharp Edges/Curves¶
# Retrieve the sharp features: every sharp curve is a polyline over mesh
# vertices, so emit one edge (vi[j], vi[j+1]) per consecutive vertex pair.
lines = []
for fe in feat["curves"]:
    if fe["sharp"]:
        vi = fe["vert_indices"]
        lines.extend([a, b] for a, b in zip(vi, vi[1:]))
# Visualize the sharp features. reshape(-1, 2) keeps the edge array
# well-formed, with shape (0, 2), when the model has no sharp curves.
p = mp.plot(v, f)
p.add_edges(v, np.array(lines, dtype=np.int64).reshape(-1, 2))
CAD Features: Sharp Edges/Curves (Per-Vertex Labels)
# Per-vertex binary labels, shape (num_vertices, 1): 1 for vertices on a
# sharp feature curve, 0 otherwise.
v_class = np.zeros((v.shape[0], 1))
for fe in feat["curves"]:
    if fe["sharp"]:
        v_class[fe["vert_indices"]] = 1
# Visualize the sharp features as a point cloud (negated labels as colours).
mp.plot(v, c=-v_class, shading={"point_size": 4.})
CAD Features: Surface Patch Types¶
# Integer label for every ABC surface type name.
t_map = {"Plane": 0, "Cylinder": 1, "Cone": 2, "Sphere": 3,
         "Torus": 4, "Bezier": 5, "BSpline": 6, "Revolution": 7,
         "Extrusion": 8, "Other": 9}
# One label per face. Faces not listed by any surface keep 0 (same as
# "Plane"); a face listed by several surfaces keeps the last one's label.
c1 = np.zeros(f.shape[0])
for surface in feat["surfaces"]:
    c1[surface["face_indices"]] = t_map[surface["type"]]
# Visualize the patch types (negated labels as colours).
mp.plot(v, f, -c1)
CAD Features: Surface Patch Types (Per-Vertex Labels)
# Same surface-type labels, but per vertex. A vertex shared by several
# surfaces keeps the label of the last surface listed.
c2 = np.zeros(v.shape[0])
for surface in feat["surfaces"]:
    c2[surface["vert_indices"]] = t_map[surface["type"]]
# Visualize the vertices (negated labels as colours).
mp.plot(v, c=-c2, shading={"point_size": 4.})
Machine Learning Setup¶
Installation¶
We use PyTorch and PyTorch Geometric, which can be installed as described in their respective documentation.
Import¶
import numpy as np
import meshplot as mp
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Dropout, Linear
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import DynamicEdgeConv
from data.utils import MLP
from data.utils import ABCDataset
Loading the CAD Data¶
# Training augmentation for the sharp-edge task: subsample 5000 points,
# jitter them, and rotate by a small random angle around each axis.
tf_train = T.Compose([
    T.FixedPoints(5000, replace=False),
    T.RandomTranslate(0.002),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2)
])
# Training augmentation for the normals task: identical, but WITHOUT the
# rotations. RandomRotate transforms `pos` only; the target normals live in
# `y` and would not be rotated along, silently corrupting the labels.
# Translation does not change normal directions, so it is kept.
tf_train_n = T.Compose([
    T.FixedPoints(5000, replace=False),
    T.RandomTranslate(0.002),
])
tf_test = T.Compose([T.FixedPoints(10000, replace=False)])
pre = T.NormalizeScale()
train_dataset_n = ABCDataset("data/ml/ABC", "Normals", True, tf_train_n, pre)
test_dataset_n = ABCDataset("data/ml/ABC", "Normals", False, tf_test, pre)
train_dataset_e = ABCDataset("data/ml/ABC", "Edges", True, tf_train, pre)
test_dataset_e = ABCDataset("data/ml/ABC", "Edges", False, tf_test, pre)
Statistics and Visualization¶
dataset = test_dataset_e
# Read once; assumes num_classes may be recomputed from labels on each
# access (PyG Dataset behaviour -- TODO confirm for ABCDataset).
num_classes = dataset.num_classes
print("Number of models:", len(dataset))
print("Number of classes:", num_classes)
# Label histogram. Each access applies the dataset transform (tf_test
# subsampling), so counts refer to the sampled points.
counts = [0] * num_classes
total = 0
for d in dataset:
    y = d.y.numpy()
    for i in range(num_classes):
        counts[i] += np.sum(y == i)
    total += y.shape[0]
for i, c in enumerate(counts):
    # c / total is a fraction; scale by 100 so the "%" suffix is correct.
    print("%0.2f%% labels are of class %i." % (100.0 * c / total, i))
# Example sharp-edge labels for one test model.
d = test_dataset_e[3]
v = d.pos.numpy()
y = d.y.numpy()
print("Shape of model:", v.shape)
print("Shape of labels:", y.shape)
mp.plot(v, c=-y, shading={"point_size": 0.15})
# Example ground-truth normals, coloured by their absolute components.
d = test_dataset_n[4]
v = d.pos.numpy()
y = d.y.numpy()
mp.plot(v, c=np.abs(y), shading={"point_size": 0.15})
Defining the Network (DGCNN)¶
class Net(torch.nn.Module):
    """DGCNN-style network that makes one prediction per input point.

    Three stacked DynamicEdgeConv layers compute features from ``data.pos``;
    their outputs are concatenated and passed through a per-point MLP head
    with dropout.

    Args:
        out_channels: size of each point's output (number of classes for
            "Edges"/"Types", vector dimension for "Normals").
        k: number of neighbours used by every DynamicEdgeConv layer.
        aggr: neighbourhood aggregation passed to DynamicEdgeConv.
        typ: task name. "Edges" and "Types" return log-probabilities
            (log_softmax over dim 1); "Normals" returns L2-normalised rows.

    Raises:
        ValueError: if ``typ`` is not one of the supported tasks.
    """

    _TASKS = ("Edges", "Types", "Normals")

    def __init__(self, out_channels, k=30, aggr='max',
                 typ='Edges'):
        super().__init__()
        # Validate up front so forward() can never fall through and
        # return None for an unknown task.
        if typ not in self._TASKS:
            raise ValueError("Unknown task type: %s" % typ)
        self.typ = typ
        # Input widths are doubled because EdgeConv's MLP sees a pair of
        # concatenated point features per edge.
        self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)
        self.conv3 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)
        # Fuses the three 64-wide layer outputs (concatenated) to 1024.
        self.lin1 = MLP([3 * 64, 1024])
        self.mlp = Sequential(MLP([1024, 256]), Dropout(0.5),
                              MLP([256, 128]), Dropout(0.5),
                              Linear(128, out_channels))

    def forward(self, data):
        """Return per-point outputs for a (batched) point cloud ``data``."""
        pos, batch = data.pos, data.batch
        x1 = self.conv1(pos, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out = self.lin1(torch.cat([x1, x2, x3], dim=1))
        out = self.mlp(out)
        if self.typ == "Normals":
            return F.normalize(out, p=2, dim=-1)
        # "Edges" / "Types": classification.
        return F.log_softmax(out, dim=1)
Visualizing the Nearest Neighbour Graph¶
# Pre-transform for the graph demo: keep 1000 points, normalise the scale,
# and connect every point to its 6 nearest neighbours.
tf_pre = T.Compose([
    T.FixedPoints(1000),
    T.NormalizeScale(),
    T.KNNGraph(k=6),
])
dataset = ABCDataset("data/ml/ABC_graph", "Edges", pre_transform=tf_pre)
# Index once so positions and edges come from the same sample, even if the
# dataset applies a (possibly random) transform on every access.
sample = dataset[0]
vd = sample.pos.numpy()
p = mp.plot(vd, shading={"point_size": 0.15})
# edge_index is transposed so each row is one (source, target) pair.
p.add_edges(vd, sample.edge_index.numpy().T)
Defining the Loss Function (Normals)¶
class Cosine_Loss(torch.nn.Module):
    """Sign-invariant loss between predicted and target normal vectors.

    forward(x, y) compares row i of ``x`` with row i of ``y`` and returns a
    pair ``(loss, angle)``:
      * loss  -- mean over rows of ``1 - dot(x_i, y_i) ** 2``; it is zero
        for parallel or anti-parallel rows, so normal orientation is ignored.
      * angle -- mean over rows of ``acos(|dot(x_i, y_i)|)`` in radians, the
        unoriented angular error (the dot product is clamped to [0, 1]).
    Rows are assumed to be unit length (not checked here).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        n_rows = x.shape[0]
        cos_sim = (x * y).sum(1)
        loss = torch.sum(1 - cos_sim.pow(2)) / n_rows
        cos_abs = torch.clamp(torch.abs(cos_sim), 0.0, 1.0)
        angle = torch.sum(torch.acos(cos_abs)) / n_rows
        return loss, angle


cosine_loss = Cosine_Loss()
Defining the Training Procedure¶
def train(loader, typ="Edges"):
    """Run one training epoch over ``loader`` and print running statistics.

    Uses the module-level ``model``, ``optimizer``, ``device`` and
    ``cosine_loss``. After every batch it prints the running mean of the
    per-batch losses and the running metric: point accuracy for
    "Edges"/"Types", mean angular error in degrees for "Normals" (still
    labelled "Accuracy" in the output).

    Raises:
        ValueError: if ``typ`` is not a supported task.
    """
    if typ not in ("Edges", "Types", "Normals"):
        raise ValueError("Unknown task type: %s" % typ)
    model.train()
    # Accumulators live across the whole epoch (they were previously reset
    # on every batch, so the "running" numbers only covered one batch).
    total_loss = 0.0
    total_angle = 0.0
    correct_nodes = total_nodes = 0
    for i, data in enumerate(loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        if typ == "Normals":
            loss, angle = cosine_loss(out, data.y)
        else:
            loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # Both losses are already means over the batch's points, so average
        # over batches seen rather than dividing by the batch size.
        total_loss += loss.item()
        if typ == "Normals":
            total_angle += angle.item() * 180 / np.pi
            acc = total_angle / (i + 1)
        else:
            pred = out.max(dim=1)[1]
            correct_nodes += pred.eq(data.y).sum().item()
            total_nodes += data.num_nodes
            acc = correct_nodes / total_nodes
        print('[Train {}/{}] Loss: {:.4f}, Accuracy: {:.4f}'.format(
            i + 1, len(loader), total_loss / (i + 1), acc))
Defining the Testing Procedure¶
def test(loader, typ="Edges"):
    """Evaluate the module-level ``model`` on ``loader`` without gradients.

    Returns:
        For "Edges"/"Types": fraction of correctly classified points over
        all batches. For "Normals": mean over batches of each batch's mean
        unoriented angular error, in degrees.

    Raises:
        ValueError: if ``typ`` is not a supported task, or if ``loader``
            yields no data.
    """
    if typ not in ("Edges", "Types", "Normals"):
        raise ValueError("Unknown task type: %s" % typ)
    model.eval()
    correct_nodes = total_nodes = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
        if typ == "Normals":
            _, angle = cosine_loss(out, data.y)
            # Here the accumulators hold degrees and a batch count.
            correct_nodes += angle.item() * 180 / np.pi
            total_nodes += 1
        else:
            pred = out.max(dim=1)[1]
            correct_nodes += pred.eq(data.y).sum().item()
            total_nodes += data.num_nodes
    if total_nodes == 0:
        raise ValueError("loader yielded no data to evaluate")
    return correct_nodes / total_nodes
Running the Training¶
typ = "Edges"
# (train, test) datasets for each task that has data in this notebook.
# Fail early with a clear message instead of a NameError on train_dataset.
datasets_by_typ = {
    "Edges": (train_dataset_e, test_dataset_e),
    "Normals": (train_dataset_n, test_dataset_n),
}
if typ not in datasets_by_typ:
    raise ValueError("No datasets defined for task type: %s" % typ)
train_dataset, test_dataset = datasets_by_typ[typ]
num_epochs = 1  # demo run: a single epoch
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(train_dataset.num_classes, k=30, typ=typ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Stepped once per epoch: multiply the learning rate by 0.8 every 20 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=20, gamma=0.8)
for epoch in range(1, num_epochs + 1):
    train(train_loader, typ=typ)
    acc = test(test_loader, typ=typ)
    print('Test: {:02d}, Accuracy: {:.4f}'.format(epoch, acc))
    # Checkpoint named "<epoch>_<test metric>.dat" in the working directory.
    torch.save(model.state_dict(), "%02i_%.2f.dat" % (epoch, acc))
    scheduler.step()
Loading a Pretrained Model - Edges¶
typ = "Edges"
test_dataset = test_dataset_e
state_file = "Edges_72_0.96.dat"
model = Net(test_dataset.num_classes, k=30, typ=typ)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a checkpoint saved on a GPU load on a CPU-only host,
# replacing the duplicated CUDA/CPU branches.
state = torch.load("data/ml/ABC/models/%s" % state_file, map_location=device)
model.load_state_dict(state)
model.to(device);
# Inference only: disable dropout so direct model(...) calls are deterministic.
model.eval();
Visualizing the Predicted Results - Edges
# Draw one random test model and predict its per-point sharp-edge labels.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
loader = iter(test_loader)
# Python 3 iterators are advanced with next(); DataLoader iterators in
# recent PyTorch have no .next() method.
d = next(loader)
# The forward pass below does not go through test(); switch off dropout
# explicitly, otherwise the prediction is stochastic.
model.eval()
with torch.no_grad():
    out = model(d.to(device))
v = d.pos.cpu().numpy()
y = d.y.cpu().numpy()
# Calculate accuracy over the whole test set
acc = test(test_loader, typ="Edges")
print('Accuracy: {:.4f}'.format(acc))
# Plot groundtruth
mp.plot(v, c=-y, shading={"point_size": 0.15})
# Predicted class = index of the largest log-probability per point.
e = out.max(dim=1)[1].cpu().numpy()
# Plot estimation
mp.plot(v, c=-e, shading={"point_size": 0.15})
Loading a Pretrained Model - Normals¶
typ = "Normals"
test_dataset = test_dataset_n
state_file = "Normals_44_12.52.dat"
model = Net(test_dataset.num_classes, k=30, typ=typ)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a checkpoint saved on a GPU load on a CPU-only host,
# replacing the duplicated CUDA/CPU branches.
state = torch.load("data/ml/ABC/models/%s" % state_file, map_location=device)
model.load_state_dict(state)
model.to(device);
# Inference only: disable dropout so direct model(...) calls are deterministic.
model.eval();
Visualizing the Predicted Results - Normals
# Draw one random test model and predict its per-point normals.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
loader = iter(test_loader)
# Python 3 iterators are advanced with next(); DataLoader iterators in
# recent PyTorch have no .next() method.
d = next(loader)
# Switch off dropout so the prediction below is deterministic.
model.eval()
with torch.no_grad():
    out = model(d.to(device))
v = d.pos.cpu().numpy()
y = d.y.cpu().numpy()
# Mean unoriented angular error for this model, in degrees
_, angle = cosine_loss(out, d.y)
print(angle.item() * 180 / np.pi)
# Plot groundtruth (colour = absolute normal components)
c1 = np.abs(y)
mp.plot(v, c=c1, shading={"point_size": 0.15})
n = out.cpu().numpy()
c2 = np.abs(n)
# Plot estimation
mp.plot(v, c=c2, shading={"point_size": 0.15})