From e7af94a224bcb3d7b5db014788473319fb45f1a6 Mon Sep 17 00:00:00 2001 From: Aleksandr Rigachnyi Date: Tue, 4 Jun 2024 18:19:44 +0200 Subject: [PATCH] Replace SessionManager by SessionFactory --- .../libs/daemon/common/bootstrap.cpp | 18 +- .../libs/endpoints/endpoint_manager.cpp | 196 ++++----- .../libs/endpoints/endpoint_manager.h | 2 +- .../libs/endpoints/endpoint_manager_ut.cpp | 134 +++--- cloud/blockstore/libs/endpoints/public.h | 7 +- .../libs/endpoints/service_endpoint_ut.cpp | 153 ++++--- .../libs/endpoints/session_manager.cpp | 382 ++++-------------- .../libs/endpoints/session_manager.h | 57 ++- .../libs/endpoints/session_manager_ut.cpp | 30 +- 9 files changed, 387 insertions(+), 592 deletions(-) diff --git a/cloud/blockstore/libs/daemon/common/bootstrap.cpp b/cloud/blockstore/libs/daemon/common/bootstrap.cpp index 8c1595d587f..2a5d7b7a81f 100644 --- a/cloud/blockstore/libs/daemon/common/bootstrap.cpp +++ b/cloud/blockstore/libs/daemon/common/bootstrap.cpp @@ -329,13 +329,13 @@ void TBootstrapBase::Init() STORAGE_INFO("StorageProvider initialized"); - TSessionManagerOptions sessionManagerOptions; - sessionManagerOptions.StrictContractValidation + TSessionFactoryOptions sessionFactoryOptions; + sessionFactoryOptions.StrictContractValidation = Configs->ServerConfig->GetStrictContractValidation(); - sessionManagerOptions.DefaultClientConfig + sessionFactoryOptions.DefaultClientConfig = Configs->EndpointConfig->GetClientConfig(); - sessionManagerOptions.HostProfile = Configs->HostPerformanceProfile; - sessionManagerOptions.TemporaryServer = Configs->Options->TemporaryServer; + sessionFactoryOptions.HostProfile = Configs->HostPerformanceProfile; + sessionFactoryOptions.TemporaryServer = Configs->Options->TemporaryServer; if (!KmsKeyProvider) { KmsKeyProvider = CreateKmsKeyProviderStub(); @@ -345,7 +345,7 @@ void TBootstrapBase::Init() Logging, CreateEncryptionKeyProvider(KmsKeyProvider)); - auto sessionManager = CreateSessionManager( + auto sessionFactory = CreateSessionFactory( Timer, Scheduler, Logging, @@ -357,9 +357,9 @@ void TBootstrapBase::Init() StorageProvider, encryptionClientFactory, Executor, - sessionManagerOptions); + sessionFactoryOptions); - STORAGE_INFO("SessionManager initialized"); + STORAGE_INFO("SessionFactory initialized"); THashMap endpointListeners; @@ -571,7 +571,7 @@ void TBootstrapBase::Init() ServerStats, Executor, EndpointEventHandler, - std::move(sessionManager), + std::move(sessionFactory), std::move(endpointStorage), std::move(endpointListeners), std::move(nbdDeviceFactory), diff --git a/cloud/blockstore/libs/endpoints/endpoint_manager.cpp b/cloud/blockstore/libs/endpoints/endpoint_manager.cpp index 8dbdb1143cc..a1e3bba60ff 100644 --- a/cloud/blockstore/libs/endpoints/endpoint_manager.cpp +++ b/cloud/blockstore/libs/endpoints/endpoint_manager.cpp @@ -393,8 +393,8 @@ struct TRequestState struct TEndpoint { std::shared_ptr Request; + IEndpointSessionPtr Session; NBD::IDevicePtr Device; - NProto::TVolume Volume; }; //////////////////////////////////////////////////////////////////////////////// @@ -408,7 +408,7 @@ class TEndpointManager final const ILoggingServicePtr Logging; const IServerStatsPtr ServerStats; const TExecutorPtr Executor; - const ISessionManagerPtr SessionManager; + const ISessionFactoryPtr SessionFactory; const IEndpointStoragePtr EndpointStorage; const THashMap EndpointListeners; const NBD::IDeviceFactoryPtr NbdDeviceFactory; @@ -448,7 +448,7 @@ class TEndpointManager final IVolumeStatsPtr volumeStats, IServerStatsPtr serverStats, TExecutorPtr executor, - ISessionManagerPtr sessionManager, + ISessionFactoryPtr sessionFactory, IEndpointStoragePtr endpointStorage, THashMap listeners, NBD::IDeviceFactoryPtr nbdDeviceFactory, @@ -456,7 +456,7 @@ class TEndpointManager final : Logging(std::move(logging)) , ServerStats(std::move(serverStats)) , Executor(std::move(executor)) - , SessionManager(std::move(sessionManager)) + , SessionFactory(std::move(sessionFactory)) , EndpointStorage(std::move(endpointStorage)) , EndpointListeners(std::move(listeners)) , NbdDeviceFactory(std::move(nbdDeviceFactory)) @@ -578,24 +578,28 @@ class TEndpointManager final NProto::TError SwitchEndpointImpl( TCallContextPtr ctx, - std::shared_ptr request); + std::shared_ptr request, + const IEndpointSession& session); NProto::TError AlterEndpoint( + const TEndpoint& endpoint, TCallContextPtr ctx, - NProto::TStartEndpointRequest newReq, - NProto::TStartEndpointRequest oldReq); + NProto::TStartEndpointRequest newReq); NProto::TError RestartListenerEndpoint( + const TEndpoint& endpoint, TCallContextPtr ctx, - const NProto::TStartEndpointRequest& request); + const NProto::THeaders& headers); NProto::TError OpenAllEndpointSockets( const NProto::TStartEndpointRequest& request, - const TSessionInfo& sessionInfo); + const NProto::TVolume& volume, + const IEndpointSession& session); NProto::TError OpenEndpointSocket( const NProto::TStartEndpointRequest& request, - const TSessionInfo& sessionInfo); + const NProto::TVolume& volume, + const IEndpointSession& session); void CloseAllEndpointSockets(const NProto::TStartEndpointRequest& request); void CloseEndpointSocket(const NProto::TStartEndpointRequest& request); @@ -612,13 +616,14 @@ class TEndpointManager final void DetachFileDevice(const TString& device); - template - void RemoveSession(TCallContextPtr ctx, const T& request) + void RemoveSession( + IEndpointSession& session, + TCallContextPtr ctx, + const NProto::THeaders& headers) { - auto future = SessionManager->RemoveSession( + auto future = session.Remove( std::move(ctx), - request.GetUnixSocketPath(), - request.GetHeaders()); + headers); if (const auto& error = Executor->WaitFor(future); HasError(error)) { STORAGE_ERROR("Failed to remove session: " << FormatError(error)); @@ -756,11 +761,11 @@ NProto::TStartEndpointResponse TEndpointManager::StartEndpointImpl( auto it = Endpoints.find(socketPath); if (it != Endpoints.end()) { - const auto& endpoint = it->second; + auto endpoint = it->second; if (!NFs::Exists(socketPath)) { // restart listener endpoint to recreate the socket - auto error = RestartListenerEndpoint(ctx, *endpoint.Request); + auto error = RestartListenerEndpoint(endpoint, ctx, request->GetHeaders()); if (HasError(error)) { return TErrorResponse(error); } @@ -771,35 +776,42 @@ NProto::TStartEndpointResponse TEndpointManager::StartEndpointImpl( } } - auto error = AlterEndpoint(std::move(ctx), *request, *endpoint.Request); + auto error = AlterEndpoint(endpoint, ctx, *request); if (HasError(error)) { return TErrorResponse(error); } + auto describeFuture = endpoint.Session->Describe(ctx, request->GetHeaders()); + auto mountResponse = Executor->WaitFor(describeFuture); + if (HasError(mountResponse)) { + return TErrorResponse(mountResponse.GetError()); + } + NProto::TStartEndpointResponse response; response.MutableError()->CopyFrom(error); - response.MutableVolume()->CopyFrom(endpoint.Volume); + response.MutableVolume()->CopyFrom(mountResponse.GetVolume()); response.SetNbdDeviceFile(endpoint.Request->GetNbdDeviceFile()); return response; } - auto future = SessionManager->CreateSession(ctx, *request); - auto [sessionInfo, error] = Executor->WaitFor(future); + NProto::TVolume volume; + auto future = SessionFactory->CreateSession(ctx, *request, volume); + auto [session, error] = Executor->WaitFor(future); if (HasError(error)) { return TErrorResponse(error); } - error = OpenAllEndpointSockets(*request, sessionInfo); + error = OpenAllEndpointSockets(*request, volume, *session); if (HasError(error)) { - RemoveSession(std::move(ctx), *request); + RemoveSession(*session, std::move(ctx), request->GetHeaders()); return TErrorResponse(error); } - auto deviceOrError = StartNbdDevice(request, restoring, sessionInfo.Volume); + auto deviceOrError = StartNbdDevice(request, restoring, volume); error = deviceOrError.GetError(); if (HasError(error)) { CloseAllEndpointSockets(*request); - RemoveSession(std::move(ctx), *request); + RemoveSession(*session, std::move(ctx), request->GetHeaders()); return TErrorResponse(error); } auto device = deviceOrError.ExtractResult(); @@ -814,15 +826,15 @@ NProto::TStartEndpointResponse TEndpointManager::StartEndpointImpl( } ReleaseNbdDevice(request->GetNbdDeviceFile(), restoring); CloseAllEndpointSockets(*request); - RemoveSession(std::move(ctx), *request); + RemoveSession(*session, std::move(ctx), request->GetHeaders()); return TErrorResponse(error); } } TEndpoint endpoint = { .Request = request, + .Session = session, .Device = device, - .Volume = sessionInfo.Volume, }; if (auto c = ServerStats->GetEndpointCounter(request->GetIpcType())) { @@ -832,16 +844,17 @@ NProto::TStartEndpointResponse TEndpointManager::StartEndpointImpl( STORAGE_VERIFY(inserted, TWellKnownEntityTypes::ENDPOINT, socketPath); NProto::TStartEndpointResponse response; - response.MutableVolume()->CopyFrom(sessionInfo.Volume); + response.MutableVolume()->CopyFrom(volume); response.SetNbdDeviceFile(request->GetNbdDeviceFile()); return response; } NProto::TError TEndpointManager::AlterEndpoint( + const TEndpoint& endpoint, TCallContextPtr ctx, - NProto::TStartEndpointRequest newReq, - NProto::TStartEndpointRequest oldReq) + NProto::TStartEndpointRequest newReq) { + auto oldReq = *endpoint.Request; const auto& socketPath = newReq.GetUnixSocketPath(); // NBS-3018 @@ -880,25 +893,16 @@ NProto::TError TEndpointManager::AlterEndpoint( << " has already been started with other args"); } - auto future = SessionManager->AlterSession( + auto future = endpoint.Session->Alter( ctx, - socketPath, newReq.GetVolumeAccessMode(), newReq.GetVolumeMountMode(), newReq.GetMountSeqNumber(), newReq.GetHeaders()); - if (const auto& error = Executor->WaitFor(future); HasError(error)) { - return error; - } - - auto getSessionFuture = - SessionManager->GetSession(ctx, socketPath, newReq.GetHeaders()); - - const auto& [sessionInfo, error] = Executor->WaitFor(getSessionFuture); - - if (HasError(error)) { - return error; + const auto& mountResponse = Executor->WaitFor(future); + if (HasError(mountResponse)) { + return mountResponse.GetError(); } auto listenerIt = EndpointListeners.find(oldReq.GetIpcType()); @@ -911,26 +915,23 @@ NProto::TError TEndpointManager::AlterEndpoint( auto alterFuture = listener->AlterEndpoint( oldReq, - sessionInfo.Volume, - sessionInfo.Session); - + mountResponse.GetVolume(), + endpoint.Session->GetSession()); return Executor->WaitFor(alterFuture); } NProto::TError TEndpointManager::RestartListenerEndpoint( + const TEndpoint& endpoint, TCallContextPtr ctx, - const NProto::TStartEndpointRequest& request) + const NProto::THeaders& headers) { + const auto& request = *endpoint.Request; STORAGE_INFO("Restart listener endpoint: " << request); - auto sessionFuture = SessionManager->GetSession( - ctx, - request.GetUnixSocketPath(), - request.GetHeaders()); - - auto [sessionInfo, error] = Executor->WaitFor(sessionFuture); - if (HasError(error)) { - return error; + auto describeFuture = endpoint.Session->Describe(ctx, headers); + auto mountResponse = Executor->WaitFor(describeFuture); + if (HasError(mountResponse)) { + return mountResponse.GetError(); } auto listenerIt = EndpointListeners.find(request.GetIpcType()); @@ -942,7 +943,7 @@ NProto::TError TEndpointManager::RestartListenerEndpoint( auto& listener = listenerIt->second; auto future = listener->StopEndpoint(request.GetUnixSocketPath()); - error = Executor->WaitFor(future); + auto error = Executor->WaitFor(future); if (HasError(error)) { STORAGE_ERROR("Failed to stop endpoint while restarting it: " << FormatError(error)); @@ -950,8 +951,8 @@ NProto::TError TEndpointManager::RestartListenerEndpoint( future = listener->StartEndpoint( request, - sessionInfo.Volume, - sessionInfo.Session); + mountResponse.GetVolume(), + endpoint.Session->GetSession()); error = Executor->WaitFor(future); if (HasError(error)) { STORAGE_ERROR("Failed to start endpoint while recreating it: " @@ -1008,7 +1009,7 @@ NProto::TStopEndpointResponse TEndpointManager::StopEndpointImpl( } ReleaseNbdDevice(endpoint.Request->GetNbdDeviceFile(), false); CloseAllEndpointSockets(*endpoint.Request); - RemoveSession(std::move(ctx), *request); + RemoveSession(*endpoint.Session, std::move(ctx), request->GetHeaders()); if (auto error = EndpointStorage->RemoveEndpoint(socketPath); HasError(error) && !HasProtoFlag(error.GetFlags(), NProto::EF_SILENT)) @@ -1122,15 +1123,17 @@ NProto::TDescribeEndpointResponse TEndpointManager::DoDescribeEndpoint( return TErrorResponse(E_REJECTED, "endpoint is restoring now"); } - NProto::TDescribeEndpointResponse response; - - auto [profile, err] = SessionManager->GetProfile(socketPath); - if (HasError(err)) { - response.MutableError()->CopyFrom(err); - } else { - response.MutablePerformanceProfile()->CopyFrom(profile); + auto it = Endpoints.find(socketPath); + if (it == Endpoints.end()) { + return TErrorResponse(S_FALSE, TStringBuilder() + << "endpoint " << socketPath.Quote() + << " hasn't been started yet"); } + auto profile = it->second.Session->GetProfile(); + + NProto::TDescribeEndpointResponse response; + response.MutablePerformanceProfile()->CopyFrom(profile); return response; } @@ -1167,8 +1170,9 @@ NProto::TRefreshEndpointResponse TEndpointManager::RefreshEndpointImpl( return TErrorResponse(S_FALSE, TStringBuilder() << "endpoint " << socketPath.Quote() << " not started"); } + auto endpoint = it->second; - auto ipcType = it->second.Request->GetIpcType(); + auto ipcType = endpoint.Request->GetIpcType(); auto listenerIt = EndpointListeners.find(ipcType); STORAGE_VERIFY( listenerIt != EndpointListeners.end(), @@ -1176,22 +1180,22 @@ NProto::TRefreshEndpointResponse TEndpointManager::RefreshEndpointImpl( socketPath); const auto& listener = listenerIt->second; - auto future = SessionManager->GetSession(std::move(ctx), socketPath, headers); - const auto& [sessionInfo, getSessionError] = Executor->WaitFor(future); - - if (HasError(getSessionError)) { - return TErrorResponse(getSessionError); + auto future = endpoint.Session->Describe(std::move(ctx), headers); + const auto& response = Executor->WaitFor(future); + if (HasError(response)) { + return TErrorResponse(response.GetError()); } - const auto refreshError = listener->RefreshEndpoint(socketPath, sessionInfo.Volume); + const auto refreshError = listener->RefreshEndpoint(socketPath, response.GetVolume()); return TErrorResponse(refreshError); } NProto::TError TEndpointManager::OpenAllEndpointSockets( const NProto::TStartEndpointRequest& request, - const TSessionInfo& sessionInfo) + const NProto::TVolume& volume, + const IEndpointSession& session) { - auto error = OpenEndpointSocket(request, sessionInfo); + auto error = OpenEndpointSocket(request, volume, session); if (HasError(error)) { return error; } @@ -1199,7 +1203,7 @@ NProto::TError TEndpointManager::OpenAllEndpointSockets( auto nbdRequest = CreateNbdStartEndpointRequest(request); if (nbdRequest) { STORAGE_INFO("Start additional endpoint: " << *nbdRequest); - auto error = OpenEndpointSocket(*nbdRequest, sessionInfo); + auto error = OpenEndpointSocket(*nbdRequest, volume, session); if (HasError(error)) { CloseEndpointSocket(request); @@ -1211,7 +1215,8 @@ NProto::TError TEndpointManager::OpenAllEndpointSockets( NProto::TError TEndpointManager::OpenEndpointSocket( const NProto::TStartEndpointRequest& request, - const TSessionInfo& sessionInfo) + const NProto::TVolume& volume, + const IEndpointSession& session) { auto ipcType = request.GetIpcType(); auto listenerIt = EndpointListeners.find(ipcType); @@ -1230,8 +1235,8 @@ NProto::TError TEndpointManager::OpenEndpointSocket( auto future = listener->StartEndpoint( request, - sessionInfo.Volume, - sessionInfo.Session); + volume, + session.GetSession()); return Executor->WaitFor(future); } @@ -1316,7 +1321,10 @@ NProto::TError TEndpointManager::DoSwitchEndpoint( return promise.ExtractValue(); } - auto response = SwitchEndpointImpl(std::move(ctx), std::move(request)); + auto response = SwitchEndpointImpl( + std::move(ctx), + std::move(request), + *it->second.Session); promise.SetValue(response); RemoveProcessingSocket(socketPath); @@ -1325,7 +1333,8 @@ NProto::TError TEndpointManager::DoSwitchEndpoint( NProto::TError TEndpointManager::SwitchEndpointImpl( TCallContextPtr ctx, - std::shared_ptr request) + std::shared_ptr request, + const IEndpointSession& session) { const auto& socketPath = request->GetUnixSocketPath(); @@ -1343,32 +1352,31 @@ NProto::TError TEndpointManager::SwitchEndpointImpl( socketPath); IEndpointListenerPtr listener = listenerIt->second; - auto getSessionFuture = SessionManager->GetSession( + auto future = session.Describe( std::move(ctx), - startRequest->GetUnixSocketPath(), startRequest->GetHeaders()); - const auto& [sessionInfo, getSessionError] = Executor->WaitFor(getSessionFuture); - if (HasError(getSessionError)) { - return getSessionError; + auto response = Executor->WaitFor(future); + if (HasError(response)) { + return response.GetError(); } STORAGE_INFO("Switching endpoint" << ", reason=" << request->GetReason() - << ", volume=" << sessionInfo.Volume.GetDiskId() - << ", IsFastPathEnabled=" << sessionInfo.Volume.GetIsFastPathEnabled() - << ", Migrations=" << sessionInfo.Volume.GetMigrations().size()); + << ", volume=" << response.GetVolume().GetDiskId() + << ", IsFastPathEnabled=" << response.GetVolume().GetIsFastPathEnabled() + << ", Migrations=" << response.GetVolume().GetMigrations().size()); auto switchFuture = listener->SwitchEndpoint( *startRequest, - sessionInfo.Volume, - sessionInfo.Session); + response.GetVolume(), + session.GetSession()); const auto& switchError = Executor->WaitFor(switchFuture); if (HasError(switchError)) { ReportEndpointSwitchFailure(TStringBuilder() << "Failed to switch endpoint for volume " - << sessionInfo.Volume.GetDiskId() + << response.GetVolume().GetDiskId() << ", " << switchError.GetMessage()); } @@ -1601,7 +1609,7 @@ IEndpointManagerPtr CreateEndpointManager( IServerStatsPtr serverStats, TExecutorPtr executor, IEndpointEventProxyPtr eventProxy, - ISessionManagerPtr sessionManager, + ISessionFactoryPtr sessionFactory, IEndpointStoragePtr endpointStorage, THashMap listeners, NBD::IDeviceFactoryPtr nbdDeviceFactory, @@ -1615,7 +1623,7 @@ IEndpointManagerPtr CreateEndpointManager( std::move(volumeStats), std::move(serverStats), std::move(executor), - std::move(sessionManager), + std::move(sessionFactory), std::move(endpointStorage), std::move(listeners), std::move(nbdDeviceFactory), diff --git a/cloud/blockstore/libs/endpoints/endpoint_manager.h b/cloud/blockstore/libs/endpoints/endpoint_manager.h index 1717f72223e..e7defe5812c 100644 --- a/cloud/blockstore/libs/endpoints/endpoint_manager.h +++ b/cloud/blockstore/libs/endpoints/endpoint_manager.h @@ -59,7 +59,7 @@ IEndpointManagerPtr CreateEndpointManager( IServerStatsPtr serverStats, TExecutorPtr executor, IEndpointEventProxyPtr eventProxy, - ISessionManagerPtr sessionManager, + ISessionFactoryPtr sessionFactory, IEndpointStoragePtr endpointStorage, THashMap listeners, NBD::IDeviceFactoryPtr nbdDeviceFactory, diff --git a/cloud/blockstore/libs/endpoints/endpoint_manager_ut.cpp b/cloud/blockstore/libs/endpoints/endpoint_manager_ut.cpp index 85c159a85a9..a16053c7917 100644 --- a/cloud/blockstore/libs/endpoints/endpoint_manager_ut.cpp +++ b/cloud/blockstore/libs/endpoints/endpoint_manager_ut.cpp @@ -56,79 +56,96 @@ static const TString TestClientId = "testClientId"; //////////////////////////////////////////////////////////////////////////////// -struct TTestSessionManager final - : public ISessionManager +struct TTestEndpointSession final + : public IEndpointSession { - ui32 CreateSessionCounter = 0; - NProto::TStartEndpointRequest LastCreateSesionRequest; - ui32 AlterSessionCounter = 0; - TString LastAlterSocketPath; NProto::EVolumeAccessMode LastAlterAccessMode; NProto::EVolumeMountMode LastAlterMountMode; ui64 LastAlterMountSeqNumber; - TFuture CreateSession( - TCallContextPtr ctx, - const NProto::TStartEndpointRequest& request) override - { - Y_UNUSED(ctx); - - ++CreateSessionCounter; - LastCreateSesionRequest = request; - return MakeFuture(TSessionInfo()); - } - - TFuture RemoveSession( + TFuture Remove( TCallContextPtr ctx, - const TString& socketPath, - const NProto::THeaders& headers) override + NProto::THeaders headers) override { Y_UNUSED(ctx); - Y_UNUSED(socketPath); Y_UNUSED(headers); return MakeFuture(NProto::TError()); } - TFuture AlterSession( + TFuture Alter( TCallContextPtr ctx, - const TString& socketPath, NProto::EVolumeAccessMode accessMode, NProto::EVolumeMountMode mountMode, ui64 mountSeqNumber, - const NProto::THeaders& headers) override + NProto::THeaders headers) override { Y_UNUSED(ctx); Y_UNUSED(headers); ++AlterSessionCounter; - LastAlterSocketPath = socketPath; LastAlterAccessMode = accessMode; LastAlterMountMode = mountMode; LastAlterMountSeqNumber = mountSeqNumber; - return MakeFuture(NProto::TError()); + return MakeFuture(NProto::TMountVolumeResponse()); } - TFuture GetSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) override + TFuture Describe( + TCallContextPtr ctx, + NProto::THeaders headers) const override { - Y_UNUSED(callContext); - Y_UNUSED(socketPath); + Y_UNUSED(ctx); Y_UNUSED(headers); - return MakeFuture(TSessionInfo()); + return MakeFuture(NProto::TMountVolumeResponse()); } - TResultOrError GetProfile( - const TString& socketPath) override + NClient::ISessionPtr GetSession() const override + { + return nullptr; + } + + NProto::TClientPerformanceProfile GetProfile() const override { - Y_UNUSED(socketPath); return NProto::TClientPerformanceProfile(); } }; //////////////////////////////////////////////////////////////////////////////// +struct TTestSessionFactory final + : public ISessionFactory +{ + ui32 CreateSessionCounter = 0; + NProto::TStartEndpointRequest LastCreateSesionRequest; + TVector> EndpointSessions; + + TFuture CreateSession( + TCallContextPtr ctx, + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) override + { + Y_UNUSED(ctx); + + ++CreateSessionCounter; + LastCreateSesionRequest = request; + + volume = {}; + auto endpointSession = std::make_shared(); + EndpointSessions.push_back(endpointSession); + return MakeFuture(TSessionOrError(endpointSession)); + } + + ui32 GetAlterSessionCounter() const + { + ui32 alterSessionCounter = 0; + for (auto& session: EndpointSessions) { + alterSessionCounter += session->AlterSessionCounter; + } + return alterSessionCounter; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + struct TTestDeviceFactory : public NBD::IDeviceFactory { @@ -268,7 +285,7 @@ struct TBootstrap IRequestStatsPtr RequestStats = CreateRequestStatsStub(); IVolumeStatsPtr VolumeStats = CreateVolumeStatsStub(); IServerStatsPtr ServerStats = CreateServerStatsStub(); - ISessionManagerPtr SessionManager; + ISessionFactoryPtr SessionFactory; IEndpointStoragePtr EndpointStorage = CreateFileEndpointStorage(DirPath); IMutableEndpointStoragePtr MutableStorage = CreateFileMutableEndpointStorage(DirPath); THashMap EndpointListeners; @@ -431,16 +448,16 @@ std::shared_ptr CreateTestService( IEndpointManagerPtr CreateEndpointManager(TBootstrap& bootstrap) { - if (!bootstrap.SessionManager) { - TSessionManagerOptions sessionManagerOptions; - sessionManagerOptions.DefaultClientConfig.SetRequestTimeout( + if (!bootstrap.SessionFactory) { + TSessionFactoryOptions sessionFactoryOptions; + sessionFactoryOptions.DefaultClientConfig.SetRequestTimeout( TestRequestTimeout.MilliSeconds()); auto encryptionClientFactory = CreateEncryptionClientFactory( bootstrap.Logging, CreateDefaultEncryptionKeyProvider()); - bootstrap.SessionManager = CreateSessionManager( + bootstrap.SessionFactory = CreateSessionFactory( bootstrap.Timer, bootstrap.Scheduler, bootstrap.Logging, @@ -452,7 +469,7 @@ IEndpointManagerPtr CreateEndpointManager(TBootstrap& bootstrap) CreateDefaultStorageProvider(bootstrap.Service), std::move(encryptionClientFactory), bootstrap.Executor, - std::move(sessionManagerOptions)); + std::move(sessionFactoryOptions)); } bootstrap.EndpointManager = NServer::CreateEndpointManager( @@ -464,7 +481,7 @@ IEndpointManagerPtr CreateEndpointManager(TBootstrap& bootstrap) bootstrap.ServerStats, bootstrap.Executor, bootstrap.EndpointEventHandler, - bootstrap.SessionManager, + bootstrap.SessionFactory, bootstrap.EndpointStorage, bootstrap.EndpointListeners, bootstrap.NbdDeviceFactory, @@ -1134,8 +1151,8 @@ Y_UNIT_TEST_SUITE(TEndpointManagerTest) TMap mountedVolumes; bootstrap.Service = CreateTestService(mountedVolumes); - auto sessionManager = std::make_shared(); - bootstrap.SessionManager = sessionManager; + auto sessionFactory = std::make_shared(); + bootstrap.SessionFactory = sessionFactory; auto listener = std::make_shared(); bootstrap.EndpointListeners = {{ NProto::IPC_GRPC, listener }}; @@ -1170,13 +1187,13 @@ Y_UNIT_TEST_SUITE(TEndpointManagerTest) S_OK == response.GetError().GetCode(), response.GetError()); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->CreateSessionCounter); - UNIT_ASSERT_VALUES_EQUAL(0, sessionManager->AlterSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->CreateSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(0, sessionFactory->GetAlterSessionCounter()); google::protobuf::util::MessageDifferencer comparator; UNIT_ASSERT(comparator.Equals( request, - sessionManager->LastCreateSesionRequest)); + sessionFactory->LastCreateSesionRequest)); } { @@ -1188,8 +1205,8 @@ Y_UNIT_TEST_SUITE(TEndpointManagerTest) S_ALREADY == response.GetError().GetCode(), response.GetError()); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->CreateSessionCounter); - UNIT_ASSERT_VALUES_EQUAL(0, sessionManager->AlterSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->CreateSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(0, sessionFactory->GetAlterSessionCounter()); } { @@ -1204,16 +1221,15 @@ Y_UNIT_TEST_SUITE(TEndpointManagerTest) S_OK == response.GetError().GetCode(), response.GetError()); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->CreateSessionCounter); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->AlterSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->CreateSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->GetAlterSessionCounter()); + auto endpointSession = sessionFactory->EndpointSessions[0]; - UNIT_ASSERT_VALUES_EQUAL( - request.GetUnixSocketPath(), sessionManager->LastAlterSocketPath); UNIT_ASSERT( - NProto::VOLUME_ACCESS_READ_WRITE == sessionManager->LastAlterAccessMode); + NProto::VOLUME_ACCESS_READ_WRITE == endpointSession->LastAlterAccessMode); UNIT_ASSERT( - NProto::VOLUME_MOUNT_LOCAL == sessionManager->LastAlterMountMode); - UNIT_ASSERT_VALUES_EQUAL(42, sessionManager->LastAlterMountSeqNumber); + NProto::VOLUME_MOUNT_LOCAL == endpointSession->LastAlterMountMode); + UNIT_ASSERT_VALUES_EQUAL(42, endpointSession->LastAlterMountSeqNumber); } { @@ -1223,8 +1239,8 @@ Y_UNIT_TEST_SUITE(TEndpointManagerTest) auto response = future.GetValue(TDuration::Seconds(5)); UNIT_ASSERT_C(HasError(response.GetError()), response.GetError()); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->CreateSessionCounter); - UNIT_ASSERT_VALUES_EQUAL(1, sessionManager->AlterSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->CreateSessionCounter); + UNIT_ASSERT_VALUES_EQUAL(1, sessionFactory->GetAlterSessionCounter()); } } diff --git a/cloud/blockstore/libs/endpoints/public.h b/cloud/blockstore/libs/endpoints/public.h index 4005f12a3be..9514edbdab4 100644 --- a/cloud/blockstore/libs/endpoints/public.h +++ b/cloud/blockstore/libs/endpoints/public.h @@ -13,8 +13,11 @@ constexpr size_t UnixSocketPathLengthLimit = 107; //////////////////////////////////////////////////////////////////////////////// -struct ISessionManager; -using ISessionManagerPtr = std::shared_ptr; +struct IEndpointSession; +using IEndpointSessionPtr = std::shared_ptr; + +struct ISessionFactory; +using ISessionFactoryPtr = std::shared_ptr; struct IEndpointManager; using IEndpointManagerPtr = std::shared_ptr; diff --git a/cloud/blockstore/libs/endpoints/service_endpoint_ut.cpp b/cloud/blockstore/libs/endpoints/service_endpoint_ut.cpp index 69f327ae6f8..d210479f39c 100644 --- a/cloud/blockstore/libs/endpoints/service_endpoint_ut.cpp +++ b/cloud/blockstore/libs/endpoints/service_endpoint_ut.cpp @@ -111,97 +111,81 @@ struct TTestEndpointListener final //////////////////////////////////////////////////////////////////////////////// -struct TTestSessionManager final - : public ISessionManager +struct TTestEndpointSession final + : public IEndpointSession { - using TCreateSessionHandler - = std::function()>; - - using TRemoveSessionHandler - = std::function()>; + using TRemoveHandler = std::function()>; - using TAlterSessionHandler - = std::function()>; - - using TGetSessionHandler - = std::function()>; - - using TGetProfileHandler - = std::function()>; - - TCreateSessionHandler CreateSessionHandler = [] () { - return MakeFuture(TSessionInfo()); - }; - - TRemoveSessionHandler RemoveSessionHandler = [] () { - return MakeFuture(NProto::TError()); - }; - - TAlterSessionHandler AlterSessionHandler = [] () { + TRemoveHandler RemoveHandler = [] () { return MakeFuture(NProto::TError()); }; - TGetSessionHandler GetSessionHandler = [] () { - return MakeFuture(TSessionInfo()); - }; - - TGetProfileHandler GetProfileHandler = [] () { - return NProto::TClientPerformanceProfile(); - }; - - TFuture CreateSession( - TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request) override - { - Y_UNUSED(callContext); - Y_UNUSED(request); - return CreateSessionHandler(); - } - - TFuture RemoveSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) override + TFuture Remove( + TCallContextPtr ctx, + NProto::THeaders headers) override { - Y_UNUSED(callContext); - Y_UNUSED(socketPath); + Y_UNUSED(ctx); Y_UNUSED(headers); - return RemoveSessionHandler(); + return RemoveHandler(); } - TFuture AlterSession( - TCallContextPtr callContext, - const TString& socketPath, + TFuture Alter( + TCallContextPtr ctx, NProto::EVolumeAccessMode accessMode, NProto::EVolumeMountMode mountMode, ui64 mountSeqNumber, - const NProto::THeaders& headers) override + NProto::THeaders headers) override { - Y_UNUSED(callContext); - Y_UNUSED(socketPath); + Y_UNUSED(ctx); + Y_UNUSED(headers); Y_UNUSED(accessMode); Y_UNUSED(mountMode); Y_UNUSED(mountSeqNumber); - Y_UNUSED(headers); - return AlterSessionHandler(); + return MakeFuture(NProto::TMountVolumeResponse()); } - TFuture GetSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) override + TFuture Describe( + TCallContextPtr ctx, + NProto::THeaders headers) const override { - Y_UNUSED(callContext); - Y_UNUSED(socketPath); + Y_UNUSED(ctx); Y_UNUSED(headers); - return GetSessionHandler(); + return MakeFuture(NProto::TMountVolumeResponse()); } - TResultOrError GetProfile( - const TString& socketPath) override + NClient::ISessionPtr GetSession() const override { - Y_UNUSED(socketPath); - return GetProfileHandler(); + return nullptr; + } + + NProto::TClientPerformanceProfile GetProfile() const override + { + return NProto::TClientPerformanceProfile(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TTestSessionFactory final + : public ISessionFactory +{ + using TCreateSessionHandler + = std::function()>; + + TCreateSessionHandler CreateSessionHandler = [] () { + IEndpointSessionPtr session = std::make_shared(); + return MakeFuture(session); + }; + + TFuture CreateSession( + TCallContextPtr callContext, + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) override + { + Y_UNUSED(callContext); + Y_UNUSED(request); + volume = {}; + return CreateSessionHandler(); } }; @@ -271,7 +255,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ ipcType, listener }}, nullptr, // nbdDeviceFactory @@ -340,14 +324,15 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) scheduler->Stop(); }; - auto startEndpointPromise = NewPromise(); + auto startEndpointPromise = NewPromise(); auto stopEndpointPromise = NewPromise(); - auto sessionManager = std::make_shared(); - sessionManager->CreateSessionHandler = [&] () { + auto sessionFactory = std::make_shared(); + sessionFactory->CreateSessionHandler = [&] () { return startEndpointPromise; }; - sessionManager->RemoveSessionHandler = [&] () { + auto session = std::make_shared(); + session->RemoveHandler = [&] () { return stopEndpointPromise; }; @@ -363,7 +348,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - sessionManager, + sessionFactory, endpointStorage, {{ NProto::IPC_GRPC, std::make_shared() }}, nullptr, // nbdDeviceFactory @@ -413,7 +398,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) response.GetError().GetCode(), response); - startEndpointPromise.SetValue(NProto::TError{}); + startEndpointPromise.SetValue(IEndpointSessionPtr(session)); auto future2 = endpointService->StopEndpoint( MakeIntrusive(), @@ -427,7 +412,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) } { - startEndpointPromise = NewPromise(); + startEndpointPromise = NewPromise(); NProto::TStartEndpointRequest startRequest; startRequest.SetDiskId(diskId); @@ -469,7 +454,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) response2.GetError().GetCode(), response2); - startEndpointPromise.SetValue(NProto::TError{}); + startEndpointPromise.SetValue(IEndpointSessionPtr(session)); request->MutableHeaders()->SetRequestTimeout(3000); auto future3 = endpointService->KickEndpoint( @@ -573,7 +558,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ NProto::IPC_GRPC, listener }}, nullptr, // nbdDeviceFactory @@ -614,7 +599,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {}, // listeners nullptr, // nbdDeviceFactory @@ -651,7 +636,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {}, // listeners nullptr, // nbdDeviceFactory @@ -718,7 +703,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ NProto::IPC_GRPC, listener }}, nullptr, // nbdDeviceFactory @@ -798,7 +783,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ NProto::IPC_GRPC, std::make_shared() }}, nullptr, // nbdDeviceFactory @@ -957,7 +942,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ NProto::IPC_GRPC, listener }}, nullptr, // nbdDeviceFactory @@ -1075,7 +1060,7 @@ Y_UNIT_TEST_SUITE(TServiceEndpointTest) CreateServerStatsStub(), executor, CreateEndpointEventProxy(), - std::make_shared(), + std::make_shared(), endpointStorage, {{ NProto::IPC_GRPC, listener }}, nullptr, // nbdDeviceFactory diff --git a/cloud/blockstore/libs/endpoints/session_manager.cpp b/cloud/blockstore/libs/endpoints/session_manager.cpp index b42c5cb329e..3c93fb99add 100644 --- a/cloud/blockstore/libs/endpoints/session_manager.cpp +++ b/cloud/blockstore/libs/endpoints/session_manager.cpp @@ -33,93 +33,104 @@ namespace { //////////////////////////////////////////////////////////////////////////////// -class TEndpoint +class TEndpointSession final + : public IEndpointSession + , public std::enable_shared_from_this { private: - TExecutor& Executor; - const ISessionPtr Session; - const IBlockStorePtr DataClient; + ISessionPtr Session; + IBlockStorePtr DataClient; const IThrottlerProviderPtr ThrottlerProvider; const TString ClientId; const TString DiskId; public: - TEndpoint( - TExecutor& executor, + TEndpointSession( ISessionPtr session, IBlockStorePtr dataClient, IThrottlerProviderPtr throttlerProvider, TString clientId, TString diskId) - : Executor(executor) - , Session(std::move(session)) + : Session(std::move(session)) , DataClient(std::move(dataClient)) , ThrottlerProvider(std::move(throttlerProvider)) , ClientId(std::move(clientId)) , DiskId(std::move(diskId)) {} - NProto::TError Start(TCallContextPtr callContext, NProto::THeaders headers) + TFuture Start( + TCallContextPtr callContext, + NProto::THeaders headers) { DataClient->Start(); headers.SetClientId(ClientId); auto future = Session->MountVolume(std::move(callContext), headers); - const auto& response = Executor.WaitFor(future); - - if (HasError(response)) { - DataClient->Stop(); - } - - return response.GetError(); + return future.Apply([weakPtr = weak_from_this()] (const auto& f) { + if (auto p = weakPtr.lock()) { + if (HasError(f.GetValue())) { + p->DataClient->Stop(); + } + } + return f.GetValue().GetError(); + }); } - NProto::TError Stop(TCallContextPtr callContext, NProto::THeaders headers) + TFuture Remove( + TCallContextPtr callContext, + NProto::THeaders headers) override { headers.SetClientId(ClientId); auto future = Session->UnmountVolume(std::move(callContext), headers); - const auto& response = Executor.WaitFor(future); - DataClient->Stop(); - return response.GetError(); + return future.Apply([weakPtr = weak_from_this()] (const auto& f) { + if (auto p = weakPtr.lock()) { + p->DataClient->Stop(); + + p->Session.reset(); + p->DataClient.reset(); + p->ThrottlerProvider->Clean(); + }; + + return f.GetValue().GetError(); + }); } - NProto::TError Alter( + TFuture Alter( TCallContextPtr callContext, NProto::EVolumeAccessMode accessMode, NProto::EVolumeMountMode mountMode, ui64 mountSeqNumber, - NProto::THeaders headers) + NProto::THeaders headers) override { headers.SetClientId(ClientId); - auto future = Session->MountVolume( + return Session->MountVolume( accessMode, mountMode, mountSeqNumber, std::move(callContext), headers); - const auto& response = Executor.WaitFor(future); - return response.GetError(); } - ISessionPtr GetSession() + TFuture Describe( + TCallContextPtr callContext, + NProto::THeaders headers) const override { - return Session; + headers.SetClientId(ClientId); + return Session->MountVolume(std::move(callContext), headers); } - TString GetDiskId() + ISessionPtr GetSession() const override { - return DiskId; + return Session; } - NProto::TClientPerformanceProfile GetPerformanceProfile() + NProto::TClientPerformanceProfile GetProfile() const override { return ThrottlerProvider->GetPerformanceProfile(ClientId); } }; -using TEndpointPtr = std::shared_ptr; - //////////////////////////////////////////////////////////////////////////////// class TClientBase @@ -269,9 +280,8 @@ class TStorageDataClient final //////////////////////////////////////////////////////////////////////////////// -class TSessionManager final - : public ISessionManager - , public std::enable_shared_from_this +class TSessionFactory final + : public ISessionFactory { private: const ITimerPtr Timer; @@ -286,15 +296,12 @@ class TSessionManager final const IThrottlerProviderPtr ThrottlerProvider; const IEncryptionClientFactoryPtr EncryptionClientFactory; const TExecutorPtr Executor; - const TSessionManagerOptions Options; + const TSessionFactoryOptions Options; TLog Log; - TMutex EndpointLock; - THashMap Endpoints; - public: - TSessionManager( + TSessionFactory( ITimerPtr timer, ISchedulerPtr scheduler, ILoggingServicePtr logging, @@ -307,7 +314,7 @@ class TSessionManager final IThrottlerProviderPtr throttlerProvider, IEncryptionClientFactoryPtr encryptionClientFactory, TExecutorPtr executor, - TSessionManagerOptions options) + TSessionFactoryOptions options) : Timer(std::move(timer)) , Scheduler(std::move(scheduler)) , Logging(std::move(logging)) @@ -327,60 +334,14 @@ class TSessionManager final TFuture CreateSession( TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request) override; - - TFuture RemoveSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) override; - - TFuture AlterSession( - TCallContextPtr callContext, - const TString& socketPath, - NProto::EVolumeAccessMode accessMode, - NProto::EVolumeMountMode mountMode, - ui64 mountSeqNumber, - const NProto::THeaders& headers) override; - - TFuture GetSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) override; - - TResultOrError GetProfile( - const TString& socketPath) override; + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) override; private: TSessionOrError CreateSessionImpl( TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request); - - NProto::TError RemoveSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers); - - NProto::TError AlterSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - NProto::EVolumeAccessMode accessMode, - NProto::EVolumeMountMode mountMode, - ui64 mountSeqNumber, - const NProto::THeaders& headers); - - TSessionOrError GetSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers); - - NProto::TDescribeVolumeResponse DescribeVolume( - TCallContextPtr callContext, - const TString& diskId, - const NProto::THeaders& headers); - - TResultOrError CreateEndpoint( const NProto::TStartEndpointRequest& request, - const NProto::TVolume& volume); + NProto::TVolume& volume); TClientAppConfigPtr CreateClientConfig( const NProto::TStartEndpointRequest& request); @@ -391,225 +352,35 @@ class TSessionManager final //////////////////////////////////////////////////////////////////////////////// -TFuture TSessionManager::CreateSession( +TFuture TSessionFactory::CreateSession( TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request) + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) { - return Executor->Execute([=] () mutable { - return CreateSessionImpl(std::move(callContext), request); + return Executor->Execute([=, &volume] () mutable { + return CreateSessionImpl(std::move(callContext), request, volume); }); } -TSessionManager::TSessionOrError TSessionManager::CreateSessionImpl( - TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request) -{ - auto describeResponse = DescribeVolume( - callContext, - request.GetDiskId(), - request.GetHeaders()); - if (HasError(describeResponse)) { - return TErrorResponse(describeResponse.GetError()); - } - const auto& volume = describeResponse.GetVolume(); - - auto result = CreateEndpoint(request, volume); - if (HasError(result)) { - return TErrorResponse(result.GetError()); - } - const auto& endpoint = result.GetResult(); - - auto error = endpoint->Start(std::move(callContext), request.GetHeaders()); - if (HasError(error)) { - return TErrorResponse(error); - } - - with_lock (EndpointLock) { - auto [it, inserted] = Endpoints.emplace( - request.GetUnixSocketPath(), - endpoint); - STORAGE_VERIFY( - inserted, - TWellKnownEntityTypes::ENDPOINT, - request.GetUnixSocketPath()); - } - - return TSessionInfo { - .Volume = volume, - .Session = endpoint->GetSession() - }; -} - -NProto::TDescribeVolumeResponse TSessionManager::DescribeVolume( +TSessionFactory::TSessionOrError TSessionFactory::CreateSessionImpl( TCallContextPtr callContext, - const TString& diskId, - const NProto::THeaders& headers) + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) { auto describeRequest = std::make_shared(); - describeRequest->MutableHeaders()->CopyFrom(headers); - describeRequest->SetDiskId(diskId); + describeRequest->MutableHeaders()->CopyFrom(request.GetHeaders()); + describeRequest->SetDiskId(request.GetDiskId()); - auto future = Service->DescribeVolume( - std::move(callContext), + auto describeFuture = Service->DescribeVolume( + callContext, std::move(describeRequest)); - return Executor->WaitFor(future); -} - -TFuture TSessionManager::RemoveSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) -{ - return Executor->Execute([=] () mutable { - return RemoveSessionImpl(std::move(callContext), socketPath, headers); - }); -} - -NProto::TError TSessionManager::RemoveSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) -{ - TEndpointPtr endpoint; - - with_lock (EndpointLock) { - auto it = Endpoints.find(socketPath); - STORAGE_VERIFY( - it != Endpoints.end(), - TWellKnownEntityTypes::ENDPOINT, - socketPath); - endpoint = std::move(it->second); - Endpoints.erase(it); - } - - auto error = endpoint->Stop(std::move(callContext), headers); - - endpoint.reset(); - ThrottlerProvider->Clean(); - - return error; -} - -TFuture TSessionManager::AlterSession( - TCallContextPtr callContext, - const TString& socketPath, - NProto::EVolumeAccessMode accessMode, - NProto::EVolumeMountMode mountMode, - ui64 mountSeqNumber, - const NProto::THeaders& headers) -{ - return Executor->Execute([=] () mutable { - return AlterSessionImpl( - std::move(callContext), - socketPath, - accessMode, - mountMode, - mountSeqNumber, - headers); - }); -} - -NProto::TError TSessionManager::AlterSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - NProto::EVolumeAccessMode accessMode, - NProto::EVolumeMountMode mountMode, - ui64 mountSeqNumber, - const NProto::THeaders& headers) -{ - TEndpointPtr endpoint; - - with_lock (EndpointLock) { - auto it = Endpoints.find(socketPath); - if (it == Endpoints.end()) { - return TErrorResponse( - E_INVALID_STATE, - TStringBuilder() - << "endpoint " << socketPath.Quote() - << " hasn't been started"); - } - endpoint = it->second; - } - - return endpoint->Alter( - std::move(callContext), - accessMode, - mountMode, - mountSeqNumber, - headers); -} - -TFuture TSessionManager::GetSession( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) -{ - return Executor->Execute([=] () mutable { - return GetSessionImpl( - std::move(callContext), - socketPath, - headers); - }); -} - -TSessionManager::TSessionOrError TSessionManager::GetSessionImpl( - TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) -{ - TEndpointPtr endpoint; - - with_lock (EndpointLock) { - auto it = Endpoints.find(socketPath); - if (it == Endpoints.end()) { - return TErrorResponse( - E_INVALID_STATE, - TStringBuilder() - << "endpoint " << socketPath.Quote() - << " hasn't been started"); - } - endpoint = it->second; - } - - auto describeResponse = DescribeVolume( - std::move(callContext), - endpoint->GetDiskId(), - headers); + auto describeResponse = Executor->WaitFor(describeFuture); if (HasError(describeResponse)) { return TErrorResponse(describeResponse.GetError()); } - return TSessionInfo { - .Volume = describeResponse.GetVolume(), - .Session = endpoint->GetSession() - }; -} - -TResultOrError TSessionManager::GetProfile( - const TString& socketPath) -{ - TEndpointPtr endpoint; - - with_lock (EndpointLock) { - auto it = Endpoints.find(socketPath); - if (it == Endpoints.end()) { - return TErrorResponse( - E_INVALID_STATE, - TStringBuilder() - << "endpoint " << socketPath.Quote() - << " hasn't been started"); - } - endpoint = it->second; - } - - return endpoint->GetPerformanceProfile(); -} - -TResultOrError TSessionManager::CreateEndpoint( - const NProto::TStartEndpointRequest& request, - const NProto::TVolume& volume) -{ + volume = describeResponse.GetVolume(); const auto clientId = request.GetClientId(); auto accessMode = request.GetVolumeAccessMode(); @@ -702,16 +473,25 @@ TResultOrError TSessionManager::CreateEndpoint( std::move(clientConfig), CreateSessionConfig(request)); - return std::make_shared( - *Executor, + auto endpoint = std::make_shared( std::move(session), std::move(client), ThrottlerProvider, clientId, volume.GetDiskId()); + + auto startFuture = endpoint->Start( + std::move(callContext), + request.GetHeaders()); + auto error = Executor->WaitFor(startFuture); + if (HasError(error)) { + return error; + } + + return TSessionOrError(endpoint); } -TClientAppConfigPtr TSessionManager::CreateClientConfig( +TClientAppConfigPtr TSessionFactory::CreateClientConfig( const NProto::TStartEndpointRequest& request) { NProto::TClientAppConfig clientAppConfig; @@ -734,7 +514,7 @@ TClientAppConfigPtr TSessionManager::CreateClientConfig( return std::make_shared(std::move(clientAppConfig)); } -TSessionConfig TSessionManager::CreateSessionConfig( +TSessionConfig TSessionFactory::CreateSessionConfig( const NProto::TStartEndpointRequest& request) { TSessionConfig config; @@ -753,7 +533,7 @@ TSessionConfig TSessionManager::CreateSessionConfig( //////////////////////////////////////////////////////////////////////////////// -ISessionManagerPtr CreateSessionManager( +ISessionFactoryPtr CreateSessionFactory( ITimerPtr timer, ISchedulerPtr scheduler, ILoggingServicePtr logging, @@ -765,7 +545,7 @@ ISessionManagerPtr CreateSessionManager( IStorageProviderPtr storageProvider, IEncryptionClientFactoryPtr encryptionClientFactory, TExecutorPtr executor, - TSessionManagerOptions options) + TSessionFactoryOptions options) { auto throttlerProvider = CreateThrottlerProvider( options.HostProfile, @@ -776,7 +556,7 @@ ISessionManagerPtr CreateSessionManager( requestStats, volumeStats); - return std::make_shared( + return std::make_shared( std::move(timer), std::move(scheduler), std::move(logging), diff --git a/cloud/blockstore/libs/endpoints/session_manager.h b/cloud/blockstore/libs/endpoints/session_manager.h index b7a2c8cb400..bcf8059076e 100644 --- a/cloud/blockstore/libs/endpoints/session_manager.h +++ b/cloud/blockstore/libs/endpoints/session_manager.h @@ -21,49 +21,46 @@ namespace NCloud::NBlockStore::NServer { //////////////////////////////////////////////////////////////////////////////// -struct TSessionInfo +struct IEndpointSession { - NProto::TVolume Volume; - NClient::ISessionPtr Session; -}; - -//////////////////////////////////////////////////////////////////////////////// - -struct ISessionManager -{ - virtual ~ISessionManager() = default; - - using TSessionOrError = TResultOrError; - - virtual NThreading::TFuture CreateSession( - TCallContextPtr callContext, - const NProto::TStartEndpointRequest& request) = 0; + virtual ~IEndpointSession() = default; - virtual NThreading::TFuture RemoveSession( + virtual NThreading::TFuture Remove( TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) = 0; + NProto::THeaders headers) = 0; - virtual NThreading::TFuture AlterSession( + virtual NThreading::TFuture Alter( TCallContextPtr callContext, - const TString& socketPath, NProto::EVolumeAccessMode accessMode, NProto::EVolumeMountMode mountMode, ui64 mountSeqNumber, - const NProto::THeaders& headers) = 0; + NProto::THeaders headers) = 0; - virtual NThreading::TFuture GetSession( + virtual NThreading::TFuture Describe( TCallContextPtr callContext, - const TString& socketPath, - const NProto::THeaders& headers) = 0; + NProto::THeaders headers) const = 0; - virtual TResultOrError GetProfile( - const TString& socketPath) = 0; + virtual NClient::ISessionPtr GetSession() const = 0; + virtual NProto::TClientPerformanceProfile GetProfile() const = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct ISessionFactory +{ + virtual ~ISessionFactory() = default; + + using TSessionOrError = TResultOrError; + + virtual NThreading::TFuture CreateSession( + TCallContextPtr callContext, + const NProto::TStartEndpointRequest& request, + NProto::TVolume& volume) = 0; }; //////////////////////////////////////////////////////////////////////////////// -struct TSessionManagerOptions +struct TSessionFactoryOptions { bool StrictContractValidation = false; bool TemporaryServer = false; @@ -75,7 +72,7 @@ struct TSessionManagerOptions //////////////////////////////////////////////////////////////////////////////// -ISessionManagerPtr CreateSessionManager( +ISessionFactoryPtr CreateSessionFactory( ITimerPtr timer, ISchedulerPtr scheduler, ILoggingServicePtr logging, @@ -87,6 +84,6 @@ ISessionManagerPtr CreateSessionManager( IStorageProviderPtr storageProvider, IEncryptionClientFactoryPtr encryptionClientFactory, TExecutorPtr executor, - TSessionManagerOptions options); + TSessionFactoryOptions options); } // namespace NCloud::NBlockStore::NServer diff --git a/cloud/blockstore/libs/endpoints/session_manager_ut.cpp b/cloud/blockstore/libs/endpoints/session_manager_ut.cpp index 747c787bc40..91babb565fb 100644 --- a/cloud/blockstore/libs/endpoints/session_manager_ut.cpp +++ b/cloud/blockstore/libs/endpoints/session_manager_ut.cpp @@ -66,7 +66,7 @@ struct TBootstrap //////////////////////////////////////////////////////////////////////////////// -Y_UNIT_TEST_SUITE(TSessionManagerTest) +Y_UNIT_TEST_SUITE(TSessionFactoryTest) { void ServerStatsShouldMountVolumeWhenEndpointIsStarted( NProto::EClientIpcType ipcType) @@ -124,7 +124,7 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) logging, CreateDefaultEncryptionKeyProvider()); - auto sessionManager = CreateSessionManager( + auto sessionFactory = CreateSessionFactory( CreateWallClockTimer(), scheduler, logging, @@ -136,7 +136,7 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) CreateDefaultStorageProvider(service), encryptionClientFactory, executor, - TSessionManagerOptions()); + TSessionFactoryOptions()); executor->Start(); Y_DEFER { @@ -149,13 +149,18 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) request.SetClientId("testClientId"); request.SetIpcType(ipcType); + IEndpointSessionPtr endpointSession; + { - auto future = sessionManager->CreateSession( + NProto::TVolume volume; + auto future = sessionFactory->CreateSession( MakeIntrusive(), - request); + request, + volume); auto sessionOrError = future.GetValue(TDuration::Seconds(3)); UNIT_ASSERT_C(!HasError(sessionOrError), sessionOrError.GetError()); + endpointSession = sessionOrError.GetResult(); } ui32 expectedCount = 1; @@ -168,9 +173,8 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) } { - auto future = sessionManager->RemoveSession( + auto future = endpointSession->Remove( MakeIntrusive(), - socketPath, request.GetHeaders()); auto error = future.GetValue(TDuration::Seconds(3)); UNIT_ASSERT_C(!HasError(error), error); @@ -231,7 +235,7 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) auto executor = TExecutor::Create("TestService"); auto logging = CreateLoggingService("console"); - TSessionManagerOptions options; + TSessionFactoryOptions options; options.TemporaryServer = temporaryServer; options.DisableDurableClient = true; @@ -239,7 +243,7 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) logging, CreateDefaultEncryptionKeyProvider()); - auto sessionManager = CreateSessionManager( + auto sessionFactory = CreateSessionFactory( CreateWallClockTimer(), CreateSchedulerStub(), logging, @@ -263,13 +267,15 @@ Y_UNIT_TEST_SUITE(TSessionManagerTest) request.SetDiskId(diskId); request.SetClientId("testClientId"); - auto future = sessionManager->CreateSession( + NProto::TVolume volume; + auto future = sessionFactory->CreateSession( MakeIntrusive(), - request); + request, + volume); auto sessionOrError = future.GetValue(TDuration::Seconds(3)); UNIT_ASSERT_C(!HasError(sessionOrError), sessionOrError.GetError()); - auto session = sessionOrError.GetResult().Session; + auto session = sessionOrError.GetResult()->GetSession(); { auto future = session->ReadBlocksLocal(