1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
class Train:
    """Train/evaluate MyNet on MyDataset with TensorBoard logging.

    The network has three heads: an object-presence logit (BCE), a
    bounding-box regression (MSE) and a class prediction (cross-entropy).
    Samples whose class target is negative carry no valid class label and
    are excluded from the classification loss/accuracy.

    __call__ runs up to 1000 epochs; each epoch optionally trains, saves a
    timestamped checkpoint under ``param/``, and evaluates.
    """

    def __init__(self, root, weight_path):
        self.summaryWriter = SummaryWriter('logs')
        self.train_dataset = MyDataset(root=root, is_train=True)
        self.test_dataset = MyDataset(root=root, is_train=False)
        self.train_dataLoader = DataLoader(self.train_dataset, batch_size=50, shuffle=True)
        self.test_dataLoader = DataLoader(self.test_dataset, batch_size=50, shuffle=True)

        self.net = MyNet().to(DEVICE)
        # Resume from an existing checkpoint when one is available.
        if os.path.exists(weight_path):
            self.net.load_state_dict(torch.load(weight_path))

        self.opt = optim.Adam(self.net.parameters())

        # One loss per network head.
        self.label_loss_fun = nn.BCEWithLogitsLoss()
        self.position_loss_fun = nn.MSELoss()
        self.sort_loss_fun = nn.CrossEntropyLoss()

        self.train = True   # run the training phase each epoch
        self.test = True    # run the evaluation phase each epoch

    def _forward_losses(self, img, label, position, sort):
        """Forward pass plus per-head losses.

        Returns (total_loss, out_label, out_sort_valid, sort_valid) where
        the last two are restricted to samples with a valid class target.
        """
        out_label, out_position, out_sort = self.net(img)
        label_loss = self.label_loss_fun(out_label, label)
        position_loss = self.position_loss_fun(out_position, position)

        # BUGFIX: build the validity mask BEFORE filtering `sort`. The
        # original filtered `sort` first and then re-evaluated the mask
        # expression against the filtered tensor when indexing `out_sort`,
        # misaligning predictions and targets whenever any sort < 0.
        valid = sort >= 0
        sort = sort[valid]
        out_sort = out_sort[valid]

        # BUGFIX: CrossEntropyLoss over zero samples is nan and would
        # poison backward(); skip the term when no valid target exists.
        if sort.numel() > 0:
            sort_loss = self.sort_loss_fun(out_sort, sort)
        else:
            sort_loss = 0.0

        total_loss = label_loss + position_loss + sort_loss
        return total_loss, out_label, out_sort, sort

    def __call__(self):
        index1, index2 = 0, 0
        for epoch in range(1000):
            if self.train:
                self.net.train()
                for i, (img, label, position, sort) in enumerate(self.train_dataLoader):
                    img = img.to(DEVICE)
                    label = label.to(DEVICE)
                    position = position.to(DEVICE)
                    sort = sort.to(DEVICE)

                    train_loss, _, _, _ = self._forward_losses(img, label, position, sort)

                    self.opt.zero_grad()
                    train_loss.backward()
                    self.opt.step()

                    if i % 10 == 0:
                        print(f'train_loss{i}===>', train_loss.item())
                        self.summaryWriter.add_scalar('train_loss', train_loss, index1)
                        index1 += 1

                # Save one timestamped checkpoint per epoch.
                data_time = str(datetime.datetime.now()).replace(' ', '-').replace(':', '_').replace('·', '_')
                save_dir = 'param'
                os.makedirs(save_dir, exist_ok=True)
                torch.save(self.net.state_dict(), f'{save_dir}/{data_time}-{epoch}.pt')

            if self.test:
                # BUGFIX: evaluation must run in eval mode (the original
                # called net.train() here) and without gradient tracking.
                self.net.eval()
                sum_sort_acc, sum_label_acc = 0, 0
                batches = 0
                with torch.no_grad():
                    for i, (img, label, position, sort) in enumerate(self.test_dataLoader):
                        img = img.to(DEVICE)
                        label = label.to(DEVICE)
                        position = position.to(DEVICE)
                        sort = sort.to(DEVICE)

                        test_loss, out_label, out_sort, sort = self._forward_losses(
                            img, label, position, sort)

                        # Threshold presence probabilities at 0.5.
                        # BUGFIX: dropped the deprecated torch.tensor(...)
                        # copy-construction around sigmoid.
                        out_label = (torch.sigmoid(out_label) >= 0.5).float()
                        label_acc = torch.mean(torch.eq(out_label, label).float())
                        sum_label_acc += label_acc

                        if out_sort.numel() > 0:
                            # BUGFIX: argmax over the class dimension; the
                            # original omitted dim=1 and argmaxed the
                            # flattened tensor. (softmax dropped: it does
                            # not change the argmax.)
                            pred_sort = torch.argmax(out_sort, dim=1)
                            sort_acc = torch.mean(torch.eq(sort, pred_sort).float())
                        else:
                            sort_acc = torch.tensor(0.0)
                        sum_sort_acc += sort_acc

                        batches += 1
                        if i % 10 == 0:
                            print(f'test_loss{i}===>', test_loss.item())
                            self.summaryWriter.add_scalar('test_loss', test_loss, index2)
                            index2 += 1

                # BUGFIX: average over the batch count, not the last
                # enumerate index `i` (off-by-one, and ZeroDivisionError
                # when the test loader yields a single batch).
                batches = max(batches, 1)
                avg_sort_acc = sum_sort_acc / batches
                print(f'avg_sort_acc {epoch}====>', avg_sort_acc)
                self.summaryWriter.add_scalar('avg_sort_acc', avg_sort_acc, epoch)

                avg_label_acc = sum_label_acc / batches
                print(f'avg_label_acc {epoch}====>', avg_label_acc)
                self.summaryWriter.add_scalar('avg_label_acc', avg_label_acc, epoch)
|