Skip to content

I has test all javacpp -pytorch dataset and dataloader ,but JavaStateful meet error? #1559

@mullerhai

Description

@mullerhai

these object JavaStatefulTensorDataLoader JavaStatefulTensorDataset JavaStatefulTensorDataLoader JavaStatefulTensorDataset I don't know how to use ,when I try to run two example ,but console is fault ,other dataset can use

import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.DataLoaderOptions;
import org.bytedeco.pytorch.RandomSampler; //{ ExampleIterator, ExampleVectorOptional,DataLoaderOptions, FullDataLoaderOptions, InputArchive, OutputArchive, SizeTOptional, SizeTVectorOptional, T_TensorT_TensorTensor_T_T, T_TensorTensor_T, T_TensorTensor_TOptional, TensorMapper, TensorVector, TransformerImpl, TransformerOptions, kCircular, kGELU, kReflect, kReplicate, kZeros, ChunkBatchDataset, ChunkRandomDataLoader, RandomSampler , SequentialSampler };

public class TestChunkDatas {

    public static void main(String[] args) throws Exception {
        try (PointerScope scope = new PointerScope()) {
            long batch_size = 10;
            long prefetch_count = 1;

            JavaStatefulTensorDataset ds = new JavaStatefulTensorDataset(){
                TensorExampleVector ex = new TensorExampleVector(
                        new TensorExample(Tensor.create(10.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(15.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(20.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(35.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(40.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(55.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(60.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(75.0, 30.0, 50.0, 80.0, 300.0))

                );


                @Override
                public TensorExampleVectorOptional get_batch(long size) {
                    return new TensorExampleVectorOptional(ex);
                }

                @Override
                public void reset() {
                    super.reset();
                }
                //                @Override
//                public TensorExampleVector get_batch(long size) {
//                    return ex;
//                }
//
                @Override
                public SizeTOptional size() {
                    return new SizeTOptional(ex.size());
                }
            };

            SequentialSampler sampler = new SequentialSampler(0);
            DataLoaderOptions opts = new DataLoaderOptions(2);
            opts.workers().put(5);
            JavaStatefulTensorDataLoader data_loader = new JavaStatefulTensorDataLoader(ds, opts); //.map(new ExampleStack());
//            SequentialSampler sampler = new SequentialSampler(0);
//            ChunkMapTensorDataset data_set = new ChunkSharedTensorBatchDataset(
//                    new ChunkTensorDataset(data_reader, sampler, sampler,
//                            new ChunkDatasetOptions(prefetch_count, batch_size))).map(new TensorExampleStack());
//            ChunkRandomTensorDataLoader data_loader = new ChunkRandomTensorDataLoader(
//                    data_set, new DataLoaderOptions(batch_size));
            for (int epoch = 1; epoch <= 10; ++epoch) {
                for (TensorExampleVectorIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) {
                    TensorExampleVector batch = it.access();
                    System.out.println(batch );
                }
            }
        }
    }

    public static void mainrty(String[] args) throws Exception {
        try (PointerScope scope = new PointerScope()) {
            long batch_size = 10;
            long prefetch_count = 1;

            JavaStatefulTensorDataset ds = new JavaStatefulTensorDataset(){


                TensorExampleVector ex = new TensorExampleVector(
                        new TensorExample(Tensor.create(10.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(15.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(20.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(35.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(40.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(55.0, 30.0, 50.0, 80.0, 300.0)),
                        new TensorExample(Tensor.create(60.0, 20.0, 50.0, 80.0, 100.0)),
                        new TensorExample(Tensor.create(75.0, 30.0, 50.0, 80.0, 300.0))

                );

                public TensorExample get(long index) {
                    return ex.get(index);
//                    return super.get(index);
                }

                @Override
                public TensorExampleVectorOptional get_batch(long size) {
                    return new TensorExampleVectorOptional(ex );
                }

                @Override
                public void reset() {
                    super.reset();
                }

                @Override
                public SizeTOptional size() {
                    return new SizeTOptional(ex.size());
                }
            };

            SequentialSampler sampler = new SequentialSampler(0);
            DataLoaderOptions opts = new DataLoaderOptions(2);
            opts.workers().put(5);
            JavaStatefulTensorDataLoader data_loader = new JavaStatefulTensorDataLoader(ds,  opts); //.map(new ExampleStack()); new DistributedRandomSampler(ds.size().get()),
//            SequentialSampler sampler = new SequentialSampler(0);
//            ChunkMapTensorDataset data_set = new ChunkSharedTensorBatchDataset(
//                    new ChunkTensorDataset(data_reader, sampler, sampler,
//                            new ChunkDatasetOptions(prefetch_count, batch_size))).map(new TensorExampleStack());
//            ChunkRandomTensorDataLoader data_loader = new ChunkRandomTensorDataLoader(
//                    data_set, new DataLoaderOptions(batch_size));
            for (int epoch = 1; epoch <= 10; ++epoch) {
                for (TensorExampleVectorIterator it = data_loader.begin(); !it.equals(data_loader.end()); it = it.increment()) {
                    TensorExampleVector batch = it.access();
                    System.out.println("hello stateful");
                    System.out.println(batch );
                }
            }
        }
    }
}

console

Exception in thread "main" java.lang.RuntimeException: java.lang.RuntimeException: Cannot call pure virtual function javacpp::StatefulDataset<torch::Tensor,torch::data::example::NoTarget>::reset().
	at org.bytedeco.pytorch.JavaStatefulTensorDataLoaderBase.begin(Native Method)
	at example.TestChunkDatas.main(TestChunkDatas.java:60) 

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions