import argparse
import copy
import csv
import os
import warnings

import numpy
import torch
import tqdm
import yaml
from torch.utils import data

from nets import nn
from utils import util
from utils.dataset import Dataset

warnings.filterwarnings("ignore")

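# The LR lambda below decays the multiplier linearly from 1.0 at epoch 0 to
# params['lrf'] at the final epoch, so the LR ends at lr0 * lrf (e.g. with
# lrf = 0.01 the final LR is 1% of the initial value).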
def learning_rate(args, params):
    def fn(x):
        return (1 - x / args.epochs) * (1.0 - params['lrf']) + params['lrf']

    return fn


def train(args, params):
    # Model
    model = nn.yolo_v8_n(len(params['names'].values())).cuda()

    # Optimizer
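    # Gradients are accumulated so the effective batch size stays near 64
    # samples (batch_size * world_size * accumulate); weight decay is scaled
    # by the same factor so regularization strength is independent of batch.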
    accumulate = max(round(64 / (args.batch_size * args.world_size)), 1)
    params['weight_decay'] *= args.batch_size * args.world_size * accumulate / 64

    p = [], [], []
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, torch.nn.Parameter):
            p[2].append(v.bias)
        if isinstance(v, torch.nn.BatchNorm2d):
            p[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, torch.nn.Parameter):
            p[0].append(v.weight)

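    # Three parameter groups: biases (p[2], the base group, no weight decay),
    # conv/linear weights (p[0], weight decay applied) and norm-layer weights
    # (p[1], no weight decay).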
    optimizer = torch.optim.SGD(p[2], params['lr0'], params['momentum'], nesterov=True)

    optimizer.add_param_group({'params': p[0], 'weight_decay': params['weight_decay']})
    optimizer.add_param_group({'params': p[1]})
    del p

    # Scheduler
    lr = learning_rate(args, params)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr, last_epoch=-1)

    # EMA
    ema = util.EMA(model) if args.local_rank == 0 else None
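    # The EMA holds an exponential moving average of the model weights; it is
    # kept on rank 0 only and is what gets evaluated and checkpointed below.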

    filenames = []
    with open('../Dataset/COCO/train2017.txt') as reader:
        for filename in reader.readlines():
            filename = filename.rstrip().split('/')[-1]
            filenames.append('../Dataset/COCO/images/train2017/' + filename)

    dataset = Dataset(filenames, args.input_size, True)

    if args.world_size <= 1:
        sampler = None
    else:
        sampler = data.distributed.DistributedSampler(dataset)

    loader = data.DataLoader(dataset, args.batch_size, sampler is None, sampler,
                             num_workers=8, pin_memory=True, collate_fn=Dataset.collate_fn)

    if args.world_size > 1:
        # DDP mode
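        # BatchNorm layers are converted to SyncBatchNorm so running statistics
        # are computed across all processes before wrapping the model in DDP.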
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(module=model,
                                                          device_ids=[args.local_rank],
                                                          output_device=args.local_rank)

    # Start training
    best = 0
    num_batch = len(loader)
    amp_scale = torch.cuda.amp.GradScaler()
    criterion = util.ComputeLoss(model, params)
    num_warmup = max(round(params['warmup_epochs'] * num_batch), 1000)
    with open('weights/step.csv', 'w') as f:
        if args.local_rank == 0:
            writer = csv.DictWriter(f, fieldnames=['epoch', 'mAP@50', 'mAP'])
            writer.writeheader()
        for epoch in range(args.epochs):
            model.train()

            if args.epochs - epoch == 10:
                loader.dataset.mosaic = False  # disable mosaic augmentation for the last 10 epochs

            m_loss = util.AverageMeter()
            if args.world_size > 1:
                sampler.set_epoch(epoch)
            p_bar = enumerate(loader)
            if args.local_rank == 0:
                print(('\n' + '%10s' * 3) % ('epoch', 'memory', 'loss'))
                p_bar = tqdm.tqdm(p_bar, total=num_batch)  # progress bar

            optimizer.zero_grad()

            for i, (samples, targets, _) in p_bar:
                x = i + num_batch * epoch  # number of iterations
                samples = samples.cuda().float() / 255
                targets = targets.cuda()

                # Warmup
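                # For the first max(warmup_epochs * num_batch, 1000) iterations,
                # accumulate, per-group LR and momentum are linearly interpolated
                # toward their nominal values; the bias group (j == 0) warms up
                # from warmup_bias_lr rather than from 0.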
                if x <= num_warmup:
                    xp = [0, num_warmup]
                    fp = [1, 64 / (args.batch_size * args.world_size)]
                    accumulate = max(1, numpy.interp(x, xp, fp).round())
                    for j, y in enumerate(optimizer.param_groups):
                        if j == 0:
                            fp = [params['warmup_bias_lr'], y['initial_lr'] * lr(epoch)]
                        else:
                            fp = [0.0, y['initial_lr'] * lr(epoch)]
                        y['lr'] = numpy.interp(x, xp, fp)
                        if 'momentum' in y:
                            fp = [params['warmup_momentum'], params['momentum']]
                            y['momentum'] = numpy.interp(x, xp, fp)

                # Forward
                with torch.cuda.amp.autocast():
                    outputs = model(samples)  # forward
                    loss = criterion(outputs, targets)

                m_loss.update(loss.item(), samples.size(0))

                loss *= args.batch_size  # loss scaled by batch_size
                loss *= args.world_size  # gradient averaged between devices in DDP mode

                # Backward
                amp_scale.scale(loss).backward()

                # Optimize
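                # The optimizer steps only every `accumulate` iterations, so
                # gradients from several small batches are summed before the
                # unscaled, clipped update is applied.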
                if x % accumulate == 0:
                    amp_scale.unscale_(optimizer)  # unscale gradients
                    util.clip_gradients(model)  # clip gradients
                    amp_scale.step(optimizer)  # optimizer.step
                    amp_scale.update()
                    optimizer.zero_grad()
                    if ema:
                        ema.update(model)

                # Log
                if args.local_rank == 0:
                    memory = f'{torch.cuda.memory_reserved() / 1E9:.3g}G'  # (GB)
                    s = ('%10s' * 2 + '%10.4g') % (f'{epoch + 1}/{args.epochs}', memory, m_loss.avg)
                    p_bar.set_description(s)

                del loss
                del outputs

            # Scheduler
            scheduler.step()

            if args.local_rank == 0:
                # mAP
                last = test(args, ema.ema)
                writer.writerow({'mAP': str(f'{last[1]:.3f}'),
                                 'epoch': str(epoch + 1).zfill(3),
                                 'mAP@50': str(f'{last[0]:.3f}')})
                f.flush()

                # Update best mAP
                if last[1] > best:
                    best = last[1]

                # Save model
                ckpt = {'model': copy.deepcopy(ema.ema).half()}

                # Save last, best and delete
                torch.save(ckpt, './weights/last.pt')
                if best == last[1]:
                    torch.save(ckpt, './weights/best.pt')
                del ckpt

    if args.local_rank == 0:
        util.strip_optimizer('./weights/best.pt')  # strip optimizers
        util.strip_optimizer('./weights/last.pt')  # strip optimizers

    torch.cuda.empty_cache()


@torch.no_grad()
def test(args, model=None):
    filenames = []
    with open('../Dataset/COCO/val2017.txt') as reader:
        for filename in reader.readlines():
            filename = filename.rstrip().split('/')[-1]
            filenames.append('../Dataset/COCO/images/val2017/' + filename)
    dataset = Dataset(filenames, args.input_size, False)
    loader = data.DataLoader(dataset, 4, False, num_workers=4,
                             pin_memory=True, collate_fn=Dataset.collate_fn)

    if model is None:
        model = torch.load('./weights/best.pt', map_location='cuda')['model'].float()

    model.half()

    # Configure
    model.eval()
    iou_v = torch.linspace(0.5, 0.95, 10).cuda()  # IoU vector for mAP@0.5:0.95
    n_iou = iou_v.numel()
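    # Ten thresholds: 0.50, 0.55, ..., 0.95, matching the COCO mAP protocol.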

    metrics = []
    m_pre, m_rec, map50, mean_ap = 0.0, 0.0, 0.0, 0.0
    p_bar = tqdm.tqdm(loader, desc=('%10s' * 3) % ('precision', 'recall', 'mAP'))
    for samples, targets, shapes in p_bar:
        samples = samples.cuda()
        targets = targets.cuda()
        samples = samples.half()  # uint8 to fp16/32
        samples = samples / 255  # 0 - 255 to 0.0 - 1.0
        _, _, height, width = samples.shape  # batch size, channels, height, width

        # Inference
        outputs = model(samples)

        # NMS
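        # A very low confidence threshold (0.001) is used so recall is not cut
        # off during evaluation; overlaps above IoU 0.65 are suppressed.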
        targets[:, 2:] *= torch.tensor((width, height, width, height)).cuda()  # to pixels
        outputs = util.non_max_suppression(outputs, 0.001, 0.65)

        # Metrics
        for i, output in enumerate(outputs):
            labels = targets[targets[:, 0] == i, 1:]
            correct = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()

            if output.shape[0] == 0:
                if labels.shape[0]:
                    metrics.append((correct, *torch.zeros((3, 0)).cuda()))
                continue

            detections = output.clone()
            util.scale(detections[:, :4], samples[i].shape[1:], shapes[i][0], shapes[i][1])

            # Evaluate
            if labels.shape[0]:
                tbox = labels[:, 1:5].clone()  # target boxes
                tbox[:, 0] = labels[:, 1] - labels[:, 3] / 2  # top left x
                tbox[:, 1] = labels[:, 2] - labels[:, 4] / 2  # top left y
                tbox[:, 2] = labels[:, 1] + labels[:, 3] / 2  # bottom right x
                tbox[:, 3] = labels[:, 2] + labels[:, 4] / 2  # bottom right y
                util.scale(tbox, samples[i].shape[1:], shapes[i][0], shapes[i][1])

                correct = numpy.zeros((detections.shape[0], iou_v.shape[0]))
                correct = correct.astype(bool)

                t_tensor = torch.cat((labels[:, 0:1], tbox), 1)
                iou = util.box_iou(t_tensor[:, 1:], detections[:, :4])
                correct_class = t_tensor[:, 0:1] == detections[:, 5]
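                # Greedy matching per IoU threshold: candidate pairs of the same
                # class are sorted by descending IoU, then deduplicated so every
                # detection and every label is matched at most once.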
                for j in range(len(iou_v)):
                    x = torch.where((iou >= iou_v[j]) & correct_class)
                    if x[0].shape[0]:
                        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
                        if x[0].shape[0] > 1:
                            matches = matches[matches[:, 2].argsort()[::-1]]
                            matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
                            matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
                        correct[matches[:, 1].astype(int), j] = True
                correct = torch.tensor(correct, dtype=torch.bool, device=iou_v.device)
            metrics.append((correct, output[:, 4], output[:, 5], labels[:, 0]))

    # Compute metrics
    metrics = [torch.cat(x, 0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics)

    # Print results
    print('%10.3g' * 3 % (m_pre, m_rec, mean_ap))

    # Return results
    model.float()  # for training
    return map50, mean_ap


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-size', default=640, type=int)
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--epochs', default=500, type=int)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--test', action='store_true')

    args = parser.parse_args()

    args.local_rank = int(os.getenv('LOCAL_RANK', 0))
    args.world_size = int(os.getenv('WORLD_SIZE', 1))

    if args.world_size > 1:
        torch.cuda.set_device(device=args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    if args.local_rank == 0:
        if not os.path.exists('weights'):
            os.makedirs('weights')

    util.setup_seed()
    util.setup_multi_processes()

    with open(os.path.join('utils', 'args.yaml'), errors='ignore') as f:
        params = yaml.safe_load(f)

    if args.train:
        train(args, params)
    if args.test:
        test(args)


if __name__ == "__main__":
    main()
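
# Typical invocations (a sketch, inferred from the LOCAL_RANK/WORLD_SIZE
# environment variables read above rather than documented in this file):
#   single GPU: python main.py --train
#   multi GPU:  torchrun --nproc_per_node 4 main.py --train
# Image lists are expected at ../Dataset/COCO/train2017.txt and val2017.txt.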