diff --git a/scripts/default_parser.py b/scripts/default_parser.py index df8007d..6796c39 100644 --- a/scripts/default_parser.py +++ b/scripts/default_parser.py @@ -168,6 +168,8 @@ def init_parser(): help='cmc ranks') parser.add_argument('--rerank', action='store_true', help='use person re-ranking (by Zhong et al. CVPR2017)') + parser.add_argument('--combine-method', type=str, default='none', choices=["none", "mean"], + help='use combine method of [none | mean]') parser.add_argument('--visrank', action='store_true', help='visualize ranked results, only available in evaluation mode') diff --git a/scripts/main.py b/scripts/main.py index a5c90a6..defef76 100755 --- a/scripts/main.py +++ b/scripts/main.py @@ -60,7 +60,8 @@ def build_engine(args, datamanager, model, optimizer, scheduler, experiment=expe scheduler=scheduler, use_cpu=args.use_cpu, label_smooth=args.label_smooth, - experiment=experiment + experiment=experiment, + combine_method=args.combine_method ) else: engine = torchreid.engine.ImageTripletEngine( @@ -73,7 +74,8 @@ def build_engine(args, datamanager, model, optimizer, scheduler, experiment=expe scheduler=scheduler, use_cpu=args.use_cpu, label_smooth=args.label_smooth, - experiment=experiment + experiment=experiment, + combine_method=args.combine_method ) else: diff --git a/torchreid/engine/engine.py b/torchreid/engine/engine.py index 39cd683..3fc6f4c 100644 --- a/torchreid/engine/engine.py +++ b/torchreid/engine/engine.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import torchreid -from torchreid.utils import AverageMeter, visualize_ranked_results, visualize_cam, save_checkpoint, re_ranking +from torchreid.utils import AverageMeter, visualize_ranked_results, visualize_cam, save_checkpoint, re_ranking, combine_by_id from torchreid.losses import DeepSupervision from torchreid import metrics @@ -32,7 +32,7 @@ class Engine(object): use_cpu (bool, optional): use cpu. Default is False. """ - def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=False, experiment=None): + def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=False, experiment=None, combine_method="mean"): self.datamanager = datamanager self.model = model self.optimizer = optimizer @@ -40,6 +40,7 @@ def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=F self.use_gpu = (torch.cuda.is_available() and not use_cpu) self.writer = None self.experiment = experiment + self.combine_method = combine_method # check attributes if not isinstance(self.model, nn.Module): @@ -300,6 +301,18 @@ def _evaluate(self, arch, epoch, dataset_name='', queryloader=None, galleryloade gf = torch.cat(gf, 0) g_pids = np.asarray(g_pids) g_camids = np.asarray(g_camids) + + # gf = gf.numpy() + # unique_ids = set(g_pids) + # new_g_pids = [] + # gf_by_id = np.empty((len(unique_ids), gf.shape[-1])) + # for i, gid in enumerate(unique_ids): + # gf_by_id[i] = np.mean(gf[np.asarray(g_pids) == gid], axis=0) + # new_g_pids.append(gid) + # gf = torch.tensor(gf_by_id, dtype=torch.float) + # g_pids = np.array(new_g_pids) + + gf, g_pids = combine_by_id(gf, g_pids, self.combine_method) print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1))) print('Speed: {:.4f} sec/batch'.format(batch_time.avg)) diff --git a/torchreid/engine/image/softmax.py b/torchreid/engine/image/softmax.py index 03fe0c9..facaacc 100644 --- a/torchreid/engine/image/softmax.py +++ b/torchreid/engine/image/softmax.py @@ -63,8 +63,8 @@ class ImageSoftmaxEngine(engine.Engine): """ def __init__(self, datamanager, model, optimizer, scheduler=None, use_cpu=False, - label_smooth=True, experiment=None): - super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment) + label_smooth=True, experiment=None, combine_method="mean"): + super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment, combine_method) self.criterion = CrossEntropyLoss( num_classes=self.datamanager.num_train_pids, diff --git a/torchreid/engine/image/triplet.py b/torchreid/engine/image/triplet.py index 2d0c4fd..406f378 100644 --- a/torchreid/engine/image/triplet.py +++ b/torchreid/engine/image/triplet.py @@ -70,8 +70,8 @@ class ImageTripletEngine(engine.Engine): def __init__(self, datamanager, model, optimizer, margin=0.3, weight_t=1, weight_x=1, scheduler=None, use_cpu=False, - label_smooth=True, experiment=None): - super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment) + label_smooth=True, experiment=None, combine_method="mean"): + super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment, combine_method) self.weight_t = weight_t self.weight_x = weight_x diff --git a/torchreid/utils/__init__.py b/torchreid/utils/__init__.py index fb04b6e..01e6769 100644 --- a/torchreid/utils/__init__.py +++ b/torchreid/utils/__init__.py @@ -7,4 +7,5 @@ from .reidtools import * from .torchtools import * from .rerank import re_ranking +from .multi_image import combine_by_id from .model_complexity import compute_model_complexity diff --git a/torchreid/utils/multi_image.py b/torchreid/utils/multi_image.py new file mode 100644 index 0000000..0c04b38 --- /dev/null +++ b/torchreid/utils/multi_image.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +__all__ = ['combine_by_id'] + +def combine_by_id(gf, g_pids, method): + """ + transforms features of same bag to a bag embedding + """ + if method == "none": + print("Does not combine by id") + return gf, g_pids + elif method == "mean": + print("Calculating mean by id ...") + gf = gf.numpy() + unique_ids = set(g_pids) + new_g_pids = [] + gf_by_id = np.empty((len(unique_ids), gf.shape[-1])) + for i, gid in enumerate(unique_ids): + gf_by_id[i] = np.mean(gf[np.asarray(g_pids) == gid], axis=0) + new_g_pids.append(gid) + gf = torch.tensor(gf_by_id, dtype=torch.float) + g_pids = np.array(new_g_pids) + return gf, g_pids + elif method == "self_attention": + # TODO: self attention + return distmat + elif method == "multi_head_attention": + # TODO: multi-headed attention + return distmat + else: + raise ValueError('Must be valid combine-method')