-
Notifications
You must be signed in to change notification settings - Fork 749
Open
Labels
Description
these object JavaStatefulTensorDataLoader JavaStatefulTensorDataset JavaStatefulTensorDataLoader JavaStatefulTensorDataset I don't know how to use ,when I try to run two example ,but console is fault ,other dataset can use
import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.DataLoaderOptions;
import org.bytedeco.pytorch.RandomSampler; //{ ExampleIterator, ExampleVectorOptional,DataLoaderOptions, FullDataLoaderOptions, InputArchive, OutputArchive, SizeTOptional, SizeTVectorOptional, T_TensorT_TensorTensor_T_T, T_TensorTensor_T, T_TensorTensor_TOptional, TensorMapper, TensorVector, TransformerImpl, TransformerOptions, kCircular, kGELU, kReflect, kReplicate, kZeros, ChunkBatchDataset, ChunkRandomDataLoader, RandomSampler , SequentialSampler };
public class TestChunkDatas {
public static void main(String[] args) throws Exception {
try (PointerScope scope = new PointerScope()) {
long batch_size = 10;
long prefetch_count = 1;
JavaStatefulTensorDataset ds = new JavaStatefulTensorDataset(){
TensorExampleVector ex = new TensorExampleVector(
new TensorExample(Tensor.create(10.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(15.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(20.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(35.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(40.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(55.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(60.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(75.0, 30.0, 50.0, 80.0, 300.0))
);
@Override
public TensorExampleVectorOptional get_batch(long size) {
return new TensorExampleVectorOptional(ex);
}
@Override
public void reset() {
super.reset();
}
// @Override
// public TensorExampleVector get_batch(long size) {
// return ex;
// }
//
@Override
public SizeTOptional size() {
return new SizeTOptional(ex.size());
}
};
SequentialSampler sampler = new SequentialSampler(0);
DataLoaderOptions opts = new DataLoaderOptions(2);
opts.workers().put(5);
JavaStatefulTensorDataLoader data_loader = new JavaStatefulTensorDataLoader(ds, opts); //.map(new ExampleStack());
// SequentialSampler sampler = new SequentialSampler(0);
// ChunkMapTensorDataset data_set = new ChunkSharedTensorBatchDataset(
// new ChunkTensorDataset(data_reader, sampler, sampler,
// new ChunkDatasetOptions(prefetch_count, batch_size))).map(new TensorExampleStack());
// ChunkRandomTensorDataLoader data_loader = new ChunkRandomTensorDataLoader(
// data_set, new DataLoaderOptions(batch_size));
for (int epoch = 1; epoch <= 10; ++epoch) {
for (TensorExampleVectorIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) {
TensorExampleVector batch = it.access();
System.out.println(batch );
}
}
}
}
public static void mainrty(String[] args) throws Exception {
try (PointerScope scope = new PointerScope()) {
long batch_size = 10;
long prefetch_count = 1;
JavaStatefulTensorDataset ds = new JavaStatefulTensorDataset(){
TensorExampleVector ex = new TensorExampleVector(
new TensorExample(Tensor.create(10.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(15.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(20.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(35.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(40.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(55.0, 30.0, 50.0, 80.0, 300.0)),
new TensorExample(Tensor.create(60.0, 20.0, 50.0, 80.0, 100.0)),
new TensorExample(Tensor.create(75.0, 30.0, 50.0, 80.0, 300.0))
);
public TensorExample get(long index) {
return ex.get(index);
// return super.get(index);
}
@Override
public TensorExampleVectorOptional get_batch(long size) {
return new TensorExampleVectorOptional(ex );
}
@Override
public void reset() {
super.reset();
}
@Override
public SizeTOptional size() {
return new SizeTOptional(ex.size());
}
};
SequentialSampler sampler = new SequentialSampler(0);
DataLoaderOptions opts = new DataLoaderOptions(2);
opts.workers().put(5);
JavaStatefulTensorDataLoader data_loader = new JavaStatefulTensorDataLoader(ds, opts); //.map(new ExampleStack()); new DistributedRandomSampler(ds.size().get()),
// SequentialSampler sampler = new SequentialSampler(0);
// ChunkMapTensorDataset data_set = new ChunkSharedTensorBatchDataset(
// new ChunkTensorDataset(data_reader, sampler, sampler,
// new ChunkDatasetOptions(prefetch_count, batch_size))).map(new TensorExampleStack());
// ChunkRandomTensorDataLoader data_loader = new ChunkRandomTensorDataLoader(
// data_set, new DataLoaderOptions(batch_size));
for (int epoch = 1; epoch <= 10; ++epoch) {
for (TensorExampleVectorIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) {
TensorExampleVector batch = it.access();
System.out.println("hello stateful");
System.out.println(batch );
}
}
}
}
}
console
Exception in thread "main" java.lang.RuntimeException: java.lang.RuntimeException: Cannot call pure virtual function javacpp::StatefulDataset<torch::Tensor,torch::data::example::NoTarget>::reset().
at org.bytedeco.pytorch.JavaStatefulTensorDataLoaderBase.begin(Native Method)
at example.TestChunkDatas.main(TestChunkDatas.java:60)