Got an error when using lazy. #33
Comments
Oh, the code where I use lazy is like this:

```python
import numpy as np
import torch
from torch.autograd import Variable

def get_batches(sz, pad=0):
    for i in range(0, len(datatmp), sz):
        srcdata = []
        trgdata = []
        for j in range(sz):
            srcdata.append(datatmp[i + j][0])  # appended item is a list
            trgdata.append(datatmp[i + j][1])  # same as above
        src_max_seq_length = max(len(s) for s in srcdata)
        trg_max_seq_length = max(len(t) for t in trgdata)
        # pad src to src_max_seq_length
        for k in range(len(srcdata)):
            srcdata[k] = srcdata[k] + [pad] * (src_max_seq_length - len(srcdata[k]))
        # pad trg to trg_max_seq_length
        for k in range(len(trgdata)):
            trgdata[k] = trgdata[k] + [pad] * (trg_max_seq_length - len(trgdata[k]))
        sr = np.array(srcdata)
        tg = np.array(trgdata)
        src = Variable(torch.from_numpy(sr), requires_grad=False).long()
        trg = Variable(torch.from_numpy(tg), requires_grad=False).long()
        (src, trg) = lazy(src, trg, batch=0)  # <- the error occurs here
        yield Batch(src, trg, pad)
```
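As an aside, the manual padding and copy loops above can be replaced by `torch.nn.utils.rnn.pad_sequence`. A minimal sketch (the `pad_batch` helper name is my own, and this does not touch the repo's `lazy`/`Batch` API):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_batch(seqs, pad=0):
    # seqs: a list of variable-length lists of token ids
    tensors = [torch.tensor(s, dtype=torch.long) for s in seqs]
    # pad_sequence stacks along a new batch dimension, filling short
    # sequences with `pad` up to the length of the longest one
    return pad_sequence(tensors, batch_first=True, padding_value=pad)

src = pad_batch([[1, 2, 3], [4, 5]])
print(src.shape)  # torch.Size([2, 3])
```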
Hmm, sorry for the late response. It seems to me that you're using PyTorch 0.4.*, right? I didn't test versions < 1, so I'm not sure where the issue comes from. If I had to guess, it's perhaps because of the mismatch between the API of
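For context: since PyTorch 0.4 the `Variable` wrapper is deprecated, because tensors carry autograd state themselves. A small sketch of the simpler conversion (behavior I'd expect on PyTorch >= 0.4, shown with dummy data):

```python
import numpy as np
import torch

# Plain tensors do not require grad by default, so the Variable
# wrapper with requires_grad=False is redundant on modern PyTorch:
sr = np.zeros((2, 5))
src = torch.from_numpy(sr).long()

print(src.requires_grad)          # False
print(isinstance(src, torch.Tensor))  # True
```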
I'm sorry, but I get the same error with torch 2.0.0 and with just torch.from_numpy (not using Variable).
I see. It seems to be an oversight on my part, where I didn't handle the broadcasting mechanism for primitives. Thanks for the feedback!
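To illustrate the failure mode being described: a bare Python int has no `.size()` method, so code that assumes tensor inputs raises exactly the reported AttributeError. A hedged sketch of a guard (the `ensure_tensor` helper is hypothetical, not part of the library):

```python
import torch

def ensure_tensor(x):
    # Hypothetical guard: wrap Python primitives (int/float) so that
    # downstream code can safely call .size() or broadcast them
    if isinstance(x, torch.Tensor):
        return x
    return torch.tensor(x)

# Without the guard, a primitive triggers the reported error:
#   x = 0; x.size()  ->  AttributeError: 'int' object has no attribute 'size'
print(ensure_tensor(3).size())                  # torch.Size([])
print(ensure_tensor(torch.zeros(2, 4)).size())  # torch.Size([2, 4])
```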
I'm doing an NMT task. I use my own data-loading function rather than a torch Dataset, and I got an "'int' object has no attribute 'size'" error.
Here's my data-loading code:
PS: the code is adapted from 'The Annotated Transformer'.