import os from tqdm import tqdm import pickle import argparse import pathlib import json import time import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import numpy as np import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from metrics import ConfusionMatrix import data_transforms import argparse import random import traceback """ Model """ class STN3d(nn.Module): def __init__(self, in_channels): super(STN3d, self).__init__() self.conv_layers = nn.Sequential( nn.Conv1d(in_channels, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True), nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU(inplace=True), nn.Conv1d(128, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU(inplace=True) ) self.linear_layers = nn.Sequential( nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, 9) ) self.iden = torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)).reshape(1, 9) def forward(self, x): batchsize = x.size()[0] x = self.conv_layers(x) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) x = self.linear_layers(x) iden = self.iden.repeat(batchsize, 1).to(x.device) x = x + iden x = x.view(-1, 3, 3) return x class STNkd(nn.Module): def __init__(self, k=64): super(STNkd, self).__init__() self.conv_layers = nn.Sequential( nn.Conv1d(k, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True), nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU(inplace=True), nn.Conv1d(128, 1024, 1), nn.BatchNorm1d(1024), nn.ReLU(inplace=True) ) self.linear_layers = nn.Sequential( nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, k * k) ) self.k = k self.iden = torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)).reshape(1, self.k * self.k) def forward(self, x): batchsize = x.size()[0] x = self.conv_layers(x) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) x = self.linear_layers(x) iden = self.iden.repeat(batchsize, 1).to(x.device) x = x + iden x = x.view(-1, self.k, self.k) return x class PointNetEncoder(nn.Module): def __init__(self, global_feat=True, feature_transform=False, in_channels=3): super(PointNetEncoder, self).__init__() self.stn = STN3d(in_channels) self.conv_layer1 = nn.Sequential( nn.Conv1d(in_channels, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True), nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True) ) self.conv_layer2 = nn.Sequential( nn.Conv1d(64, 64, 1), nn.BatchNorm1d(64), nn.ReLU(inplace=True) ) self.conv_layer3 = nn.Sequential( nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU(inplace=True) ) self.conv_layer4 = nn.Sequential( nn.Conv1d(128, 1024, 1), nn.BatchNorm1d(1024) ) self.global_feat = global_feat self.feature_transform = feature_transform if self.feature_transform: self.fstn = STNkd(k=64) def forward(self, x): B, D, N = x.size() trans = self.stn(x) x = x.transpose(2, 1) if D > 3: feature = x[:, :, 3:] x = x[:, :, :3] x = torch.bmm(x, trans) if D > 3: x = torch.cat([x, feature], dim=2) x = x.transpose(2, 1) x = self.conv_layer1(x) if self.feature_transform: trans_feat = self.fstn(x) x = x.transpose(2, 1) x = torch.bmm(x, trans_feat) x = x.transpose(2, 1) else: trans_feat = None pointfeat = x x = self.conv_layer2(x) x = self.conv_layer3(x) x = self.conv_layer4(x) x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) # Construct graph and compute context-aware features graph = construct_graph(x, args.k) context_features = compute_context_aware_features(x, graph) x = x + context_features if self.global_feat: return x, trans, trans_feat else: x = x.view(-1, 1024, 1).repeat(1, 1, N) return torch.cat([x, pointfeat], 1), trans, trans_feat def construct_graph(points, k): """ Construct a dynamic graph where nodes represent points and edges capture semantic similarities. """ # Compute pairwise distances dist = torch.cdist(points, points) # Get the top k neighbors _, indices = torch.topk(dist, k, largest=False, dim=1) return indices def compute_context_aware_features(points, graph, normalization_method='mean'): """ Compute context-aware feature adjustments using the constructed graph. """ # Initialize context-aware features context_features = torch.zeros_like(points) for i in range(points.size(0)): neighbors = graph[i] if normalization_method == 'mean': context_features[i] = points[neighbors].mean(dim=0) elif normalization_method == 'max': context_features[i] = points[neighbors].max(dim=0)[0] elif normalization_method == 'min': context_features[i] = points[neighbors].min(dim=0)[0] elif normalization_method == 'std': context_features[i] = points[neighbors].std(dim=0) else: raise ValueError("Unknown normalization method: {}".format(normalization_method)) return context_features def feature_transform_reguliarzer(trans): d = trans.size()[1] I = torch.eye(d)[None, :, :] if trans.is_cuda: I = I.cuda() loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2))) return loss class Model(nn.Module): def __init__(self, in_channels=3, num_classes=40, scale=0.001): super().__init__() self.mat_diff_loss_scale = scale self.backbone = PointNetEncoder(global_feat=True, feature_transform=True, in_channels=in_channels) self.cls_head = nn.Sequential( nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, 256), nn.Dropout(p=0.4), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, num_classes) ) def forward(self, x, gts): x, trans, trans_feat = self.backbone(x) x = self.cls_head(x) x = F.log_softmax(x, dim=1) loss = F.nll_loss(x, gts) mat_diff_loss = feature_transform_reguliarzer(trans_feat) total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale return total_loss, x """ dataset and normalization """ def pc_normalize(pc): centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc**2, axis=1))) pc = pc / m return pc class ModelNetDataset(Dataset): def __init__(self, data_root, num_category, num_points, split='train'): self.root = data_root self.npoints = num_points self.uniform = True self.use_normals = True self.num_category = num_category if self.num_category == 10: self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt') else: self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') self.cat = [line.rstrip() for line in open(self.catfile)] self.classes = dict(zip(self.cat, range(len(self.cat)))) shape_ids = {} if self.num_category == 10: shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))] shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))] else: shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] assert (split == 'train' or split == 'test') shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i in range(len(shape_ids[split]))] print('The size of %s data is %d' % (split, len(self.datapath))) if self.uniform: self.data_path = os.path.join(data_root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints)) else: self.data_path = os.path.join(data_root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints)) print('Load processed data from %s...' % self.data_path) with open(self.data_path, 'rb') as f: self.list_of_points, self.list_of_labels = pickle.load(f) def __len__(self): return len(self.datapath) def __getitem__(self, index): point_set, label = self.list_of_points[index], self.list_of_labels[index] point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) if not self.use_normals: point_set = point_set[:, 0:3] return point_set, label[0] def seed_everything(seed=11): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def main(args): seed_everything(args.seed) final_infos = {} all_results = {} pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True) datasets, dataloaders = {}, {} for split in ['train', 'test']: datasets[split] = ModelNetDataset(args.data_root, args.num_category, args.num_points, split) dataloaders[split] = DataLoader(datasets[split], batch_size=args.batch_size, shuffle=(split == 'train'), drop_last=(split == 'train'), num_workers=8) model = Model(in_channels=args.in_channels).cuda() optimizer = torch.optim.Adam( model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4 ) scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=20, gamma=0.7 ) train_losses = [] print("Training model...") model.train() global_step = 0 cur_epoch = 0 best_oa = 0 best_acc = 0 start_time = time.time() for epoch in tqdm(range(args.max_epoch), desc='training'): model.train() cm = ConfusionMatrix(num_classes=len(datasets['train'].classes)) for points, target in tqdm(dataloaders['train'], desc=f'epoch {cur_epoch}/{args.max_epoch}'): # data transforms points = points.data.numpy() points = data_transforms.random_point_dropout(points) points[:, :, 0:3] = data_transforms.random_scale_point_cloud(points[:, :, 0:3]) points[:, :, 0:3] = data_transforms.shift_point_cloud(points[:, :, 0:3]) points = torch.from_numpy(points).transpose(2, 1).contiguous() points, target = points.cuda(), target.long().cuda() loss, logits = model(points, target) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1, norm_type=2) optimizer.step() model.zero_grad() logs = {"loss": loss.detach().item()} train_losses.append(loss.detach().item()) cm.update(logits.argmax(dim=1), target) scheduler.step() end_time = time.time() training_time = end_time - start_time macc, overallacc, accs = cm.all_acc() print(f"iter: {global_step}/{args.max_epoch*len(dataloaders['train'])}, \ train_macc: {macc}, train_oa: {overallacc}") if (cur_epoch % args.val_per_epoch == 0 and cur_epoch != 0) or cur_epoch == (args.max_epoch - 1): model.eval() cm = ConfusionMatrix(num_classes=datasets['test'].num_category) pbar = tqdm(enumerate(dataloaders['test']), total=dataloaders['test'].__len__()) # with torch.no_grad(): for idx, (points, target) in pbar: points, target = points.cuda(), target.long().cuda() points = points.transpose(2, 1).contiguous() loss, logits = model(points, target) cm.update(logits.argmax(dim=1), target) tp, count = cm.tp, cm.count macc, overallacc, accs = cm.cal_acc(tp, count) print(f"iter: {global_step}/{args.max_epoch*len(dataloaders['train'])}, \ val_macc: {macc}, val_oa: {overallacc}") if overallacc > best_oa: best_oa = overallacc best_acc = macc best_epoch = cur_epoch torch.save(model.state_dict(), os.path.join(args.out_dir, 'best.pth')) cur_epoch += 1 print(f"finish epoch {cur_epoch} training") final_infos = { "modelnet" + str(args.num_category):{ "means":{ "best_oa": best_oa, "best_acc": best_acc, "epoch": best_epoch } } } with open(os.path.join(args.out_dir, "final_info.json"), "w") as f: json.dump(final_infos, f) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--out_dir", type=str, default="run_0") parser.add_argument("--in_channels", type=int, default=6) parser.add_argument("--num_points", type=int, default=1024) parser.add_argument("--num_category", type=int, choices=[10, 40], default=40) parser.add_argument("--data_root", type=str, default='./datasets/modelnet40') parser.add_argument("--learning_rate", type=float, default=1e-3) parser.add_argument("--max_epoch", type=int, default=200) parser.add_argument("--val_per_epoch", type=int, default=5) parser.add_argument("--k", type=int, default=5, help="Number of neighbors for graph construction") parser.add_argument("--seed", type=int, default=666) args = parser.parse_args() try: main(args) except Exception as e: print("Original error in subprocess:", flush=True) traceback.print_exc(file=open(os.path.join(args.out_dir, "traceback.log"), "w")) raise