-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
59ec2df
commit bf5da12
Showing
11 changed files
with
143 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,22 @@ | ||
name: mvbreid | ||
channels: | ||
- pytorch | ||
# - conda-forge | ||
dependencies: | ||
- python=3.7 | ||
- python>3.7 | ||
- numpy | ||
- Cython | ||
- matplotlib | ||
- tensorboardX | ||
- opencv | ||
- h5py | ||
- Pillow | ||
- six | ||
- scipy | ||
- pip | ||
- gcc_linux-64 | ||
- pytorch | ||
- torchvision | ||
- cudatoolkit=10.0 | ||
- pip: | ||
- comet_ml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,104 @@ | ||
import numpy as np | ||
import torch | ||
__all__ = ['combine_by_id'] | ||
from torch import nn | ||
__all__ = ['CombineMultipleImages'] | ||
|
||
def combine_by_id(gf, g_pids, method): | ||
|
||
class CombineMultipleImages: | ||
""" | ||
transforms features of same bag to a bag embedding | ||
Both returned gf and g_pids are numpy array of float32 | ||
""" | ||
if method == "none": | ||
print("Does not combine by id") | ||
def __init__(self, method, embed_dim, input_count, trainloader, encoder): | ||
self.encoder = encoder | ||
self.trainloader = trainloader | ||
if method == "none": | ||
self.fn = Identity() | ||
elif method == "mean": | ||
self.fn = Mean() | ||
elif method == "feed_forward": | ||
self.fn = FeedForward(embed_dim, input_count) | ||
elif method == "self_attention": | ||
self.fn = SelfAttention(embed_dim, input_count) | ||
|
||
def train(self): | ||
self.fn.train(self.encoder, self.trainloader) | ||
|
||
def __call__(self, gf, g_pids, g_camids): | ||
return self.fn(gf, g_pids, g_camids) | ||
|
||
|
||
class CombineFunction: | ||
def train(self, encoder, dataloader): | ||
pass | ||
|
||
def __call__(self, gf, g_pids, g_camids): | ||
raise NotImplementedError | ||
|
||
|
||
class Identity(CombineFunction): | ||
def __call__(self, gf, g_pids, g_camids): | ||
return gf, g_pids | ||
elif method == "mean": | ||
print("Calculating mean by id ...") | ||
|
||
|
||
class Mean(CombineFunction): | ||
def __call__(self, gf, g_pids, g_camids): | ||
gf = gf.numpy() | ||
unique_ids = set(g_pids) | ||
new_g_pids = [] | ||
gf_by_id = np.empty((len(unique_ids), gf.shape[-1])) | ||
for i, gid in enumerate(unique_ids): | ||
gf_by_id[i] = np.mean(gf[np.asarray(g_pids) == gid], axis=0) | ||
new_g_pids.append(gid) | ||
gf = torch.tensor(gf_by_id, dtype=torch.float) | ||
gf = np.array(gf_by_id) | ||
g_pids = np.array(new_g_pids) | ||
return gf, g_pids | ||
elif method == "self_attention": | ||
# TODO: self attention | ||
return distmat | ||
elif method == "multi_head_attention": | ||
# TODO: multi-headed attention | ||
return distmat | ||
else: | ||
raise ValueError('Must be valid combine-method') | ||
|
||
|
||
class FeedForward(CombineFunction): # TODO: | ||
def __init__(self, embed_dim, input_count): | ||
super().__init__() | ||
self.model = FeedForwardNN(embed_dim, input_count) | ||
|
||
def train(self, encoder, dataloader): | ||
for data in dataloader: | ||
imgs = data[0] | ||
pids = data[1] | ||
cam_ids = data[2] | ||
# print(len(data)) | ||
# exit() | ||
|
||
def __call__(self, gf, g_pids, g_camids): | ||
result = self.model(gf, g_pids, g_camids) | ||
# Some modification on result | ||
return result | ||
|
||
|
||
class SelfAttention(CombineFunction): | ||
def __init__(self, embed_dim, input_count): | ||
self.model = SelfAttentionNN(input_dim, output_dim, input_count) | ||
|
||
def train(self, dataloader): | ||
pass | ||
|
||
def __call__(self, gf, g_pids, g_camids): | ||
result = self.model(gf, g_pids, g_camids) | ||
# Some modification on result | ||
return result | ||
|
||
|
||
class FeedForwardNN(nn.Module): | ||
def __init__(self, embed_dim, input_count): | ||
super().__init__() | ||
self.fc1 = nn.Linear(embed_dim * input_count, embed_dim * input_count) | ||
self.fc2 = nn.Linear(embed_dim * input_count, embed_dim) | ||
|
||
def forward(self, x): | ||
pass | ||
|
||
|
||
class SelfAttentionNN(nn.Module): | ||
def __init__(self, embed_dim, input_count): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
pass |