Skip to content

Commit

Permalink
Merge pull request #7 from johnzhang1999/multi-image
Browse files Browse the repository at this point in the history
Multi image
  • Loading branch information
johnzhang1999 authored Jul 12, 2019
2 parents 2845068 + e1b9a8b commit e447e5f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 8 deletions.
2 changes: 2 additions & 0 deletions scripts/default_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def init_parser():
help='cmc ranks')
parser.add_argument('--rerank', action='store_true',
help='use person re-ranking (by Zhong et al. CVPR2017)')
parser.add_argument('--combine-method', type=str, default='none', choices=["none", "mean"],
help='use combine method of [none | mean]')

parser.add_argument('--visrank', action='store_true',
help='visualize ranked results, only available in evaluation mode')
Expand Down
6 changes: 4 additions & 2 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def build_engine(args, datamanager, model, optimizer, scheduler, experiment=expe
scheduler=scheduler,
use_cpu=args.use_cpu,
label_smooth=args.label_smooth,
experiment=experiment
experiment=experiment,
combine_method=args.combine_method
)
else:
engine = torchreid.engine.ImageTripletEngine(
Expand All @@ -73,7 +74,8 @@ def build_engine(args, datamanager, model, optimizer, scheduler, experiment=expe
scheduler=scheduler,
use_cpu=args.use_cpu,
label_smooth=args.label_smooth,
experiment=experiment
experiment=experiment,
combine_method=args.combine_method
)

else:
Expand Down
17 changes: 15 additions & 2 deletions torchreid/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.nn import functional as F

import torchreid
from torchreid.utils import AverageMeter, visualize_ranked_results, visualize_cam, save_checkpoint, re_ranking
from torchreid.utils import AverageMeter, visualize_ranked_results, visualize_cam, save_checkpoint, re_ranking, combine_by_id
from torchreid.losses import DeepSupervision
from torchreid import metrics

Expand All @@ -32,14 +32,15 @@ class Engine(object):
use_cpu (bool, optional): use cpu. Default is False.
"""

def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=False, experiment=None):
def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=False, experiment=None, combine_method="mean"):
self.datamanager = datamanager
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.use_gpu = (torch.cuda.is_available() and not use_cpu)
self.writer = None
self.experiment = experiment
self.combine_method = combine_method

# check attributes
if not isinstance(self.model, nn.Module):
Expand Down Expand Up @@ -300,6 +301,18 @@ def _evaluate(self, arch, epoch, dataset_name='', queryloader=None, galleryloade
gf = torch.cat(gf, 0)
g_pids = np.asarray(g_pids)
g_camids = np.asarray(g_camids)

# gf = gf.numpy()
# unique_ids = set(g_pids)
# new_g_pids = []
# gf_by_id = np.empty((len(unique_ids), gf.shape[-1]))
# for i, gid in enumerate(unique_ids):
# gf_by_id[i] = np.mean(gf[np.asarray(g_pids) == gid], axis=0)
# new_g_pids.append(gid)
# gf = torch.tensor(gf_by_id, dtype=torch.float)
# g_pids = np.array(new_g_pids)

gf, g_pids = combine_by_id(gf, g_pids, self.combine_method)
print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

print('Speed: {:.4f} sec/batch'.format(batch_time.avg))
Expand Down
4 changes: 2 additions & 2 deletions torchreid/engine/image/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class ImageSoftmaxEngine(engine.Engine):
"""

def __init__(self, datamanager, model, optimizer, scheduler=None, use_cpu=False,
label_smooth=True, experiment=None):
super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment)
label_smooth=True, experiment=None, combine_method="mean"):
super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment, combine_method)

self.criterion = CrossEntropyLoss(
num_classes=self.datamanager.num_train_pids,
Expand Down
4 changes: 2 additions & 2 deletions torchreid/engine/image/triplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class ImageTripletEngine(engine.Engine):

def __init__(self, datamanager, model, optimizer, margin=0.3,
weight_t=1, weight_x=1, scheduler=None, use_cpu=False,
label_smooth=True, experiment=None):
super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment)
label_smooth=True, experiment=None, combine_method="mean"):
super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu, experiment, combine_method)

self.weight_t = weight_t
self.weight_x = weight_x
Expand Down
1 change: 1 addition & 0 deletions torchreid/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .reidtools import *
from .torchtools import *
from .rerank import re_ranking
from .multi_image import combine_by_id
from .model_complexity import compute_model_complexity
31 changes: 31 additions & 0 deletions torchreid/utils/multi_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import torch
__all__ = ['combine_by_id']

def combine_by_id(gf, g_pids, method):
"""
transforms features of same bag to a bag embedding
"""
if method == "none":
print("Does not combine by id")
return gf, g_pids
elif method == "mean":
print("Calculating mean by id ...")
gf = gf.numpy()
unique_ids = set(g_pids)
new_g_pids = []
gf_by_id = np.empty((len(unique_ids), gf.shape[-1]))
for i, gid in enumerate(unique_ids):
gf_by_id[i] = np.mean(gf[np.asarray(g_pids) == gid], axis=0)
new_g_pids.append(gid)
gf = torch.tensor(gf_by_id, dtype=torch.float)
g_pids = np.array(new_g_pids)
return gf, g_pids
elif method == "self_attention":
# TODO: self attention
return distmat
elif method == "multi_head_attention":
# TODO: multi-headed attention
return distmat
else:
raise ValueError('Must be valid combine-method')

0 comments on commit e447e5f

Please sign in to comment.