diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/DefaultSource.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/DefaultSource.scala
index 2db87e7..0b4ec31 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/DefaultSource.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/DefaultSource.scala
@@ -33,7 +33,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
     sqlContext: SQLContext,
     parameters: Map[String, String]): BaseRelation = {
 
-    new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build())(sqlContext)
+    new MongodbRelation(MongodbConfigBuilder(parameters).build())(sqlContext)
 
   }
 
@@ -42,7 +42,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
     parameters: Map[String, String],
     schema: StructType): BaseRelation = {
 
-    new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build(), Some(schema))(sqlContext)
+    new MongodbRelation(MongodbConfigBuilder(parameters).build(), Some(schema))(sqlContext)
 
   }
 
@@ -53,7 +53,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
     data: DataFrame): BaseRelation = {
 
     val mongodbRelation = new MongodbRelation(
-      MongodbConfigBuilder(parseParameters(parameters)).build(), Some(data.schema))(sqlContext)
+      MongodbConfigBuilder(parameters).build(), Some(data.schema))(sqlContext)
 
     mode match{
       case Append => mongodbRelation.insert(data, overwrite = false)
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbRelation.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbRelation.scala
index a1a951d..2b21743 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbRelation.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/MongodbRelation.scala
@@ -58,7 +58,7 @@ with PrunedFilteredScan with InsertableRelation {
 
   @transient private lazy val lazySchema = MongodbSchema(
     new MongodbRDD(sqlContext, config, rddPartitioner),
-    config.get[Any](MongodbConfig.SamplingRatio).fold(MongodbConfig.DefaultSamplingRatio)(_.toString.toDouble)).schema()
+    config.getOrElse(MongodbConfig.SamplingRatio, MongodbConfig.DefaultSamplingRatio)).schema()
 
   override val schema: StructType = schemaProvided.getOrElse(lazySchema)
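
Note: both files above lean on typed reads from `Config`: `MongodbConfigBuilder` is now handed the raw option map, and `getOrElse` returns the sampling ratio already as a `Double`. A minimal sketch of the access pattern this assumes (illustrative only, not the project's actual `Config` trait):

    // Values are normalized to their final types once, up front (see the
    // parseParameters rewrite below), so readers never re-parse strings.
    trait TypedConfig {
      def properties: Map[String, Any]

      // Typed lookup; relies on the stored value already having runtime type T.
      def get[T](key: String): Option[T] = properties.get(key).map(_.asInstanceOf[T])

      def getOrElse[T](key: String, default: => T): T = get[T](key).getOrElse(default)
    }
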
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala
index 531ff35..7eed1a8 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/client/MongodbClientActor.scala
@@ -79,7 +79,7 @@ class MongodbClientActor extends Actor {
         mongoClient.update(finalKey, connection.copy(
           timeOut = System.currentTimeMillis() +
-            extractValue[String](clientOptions, ConnectionsTime).map(_.toLong).getOrElse(DefaultConnectionsTime),
+            extractValue[Long](clientOptions, ConnectionsTime).getOrElse(DefaultConnectionsTime),
           status = ConnectionStatus.Busy))
         sender ! ClientResponse(finalKey, connection.client)
@@ -187,13 +187,10 @@ class MongodbClientActor extends Actor {
         case Some(preference) => parseReadPreference(preference)
         case None => DefaultReadPreference
       })
-      .connectTimeout(extractValue[String](clientOptions, ConnectTimeout).map(_.toInt)
-        .getOrElse(DefaultConnectTimeout))
-      .connectionsPerHost(extractValue[String](clientOptions, ConnectionsPerHost).map(_.toInt)
-        .getOrElse(DefaultConnectionsPerHost))
-      .maxWaitTime(extractValue[String](clientOptions, MaxWaitTime).map(_.toInt)
-        .getOrElse(DefaultMaxWaitTime))
-      .threadsAllowedToBlockForConnectionMultiplier(extractValue[String](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier).map(_.toInt)
+      .connectTimeout(extractValue[Int](clientOptions, ConnectTimeout).getOrElse(DefaultConnectTimeout))
+      .connectionsPerHost(extractValue[Int](clientOptions, ConnectionsPerHost).getOrElse(DefaultConnectionsPerHost))
+      .maxWaitTime(extractValue[Int](clientOptions, MaxWaitTime).getOrElse(DefaultMaxWaitTime))
+      .threadsAllowedToBlockForConnectionMultiplier(extractValue[Int](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier)
        .getOrElse(DefaultThreadsAllowedToBlockForConnectionMultiplier))
 
     if (sslBuilder(optionSSLOptions)) builder.socketFactory(SSLSocketFactory.getDefault())
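
Note: same pattern inside the client actor. Assuming `extractValue[T]` is a typed lookup over the options map, the change boils down to this before/after (a sketch restating the diff above, with its own names):

    // Before: options arrived as Strings and were re-parsed at every call site.
    val before = extractValue[String](clientOptions, ConnectTimeout).map(_.toInt).getOrElse(DefaultConnectTimeout)
    // After: parseParameters has already coerced the value, so the read is direct.
    val after = extractValue[Int](clientOptions, ConnectTimeout).getOrElse(DefaultConnectTimeout)
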
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfig.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfig.scala
index 8810585..74a76bc 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfig.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfig.scala
@@ -87,7 +87,7 @@ object MongodbConfig {
   val DefaultConnectionsTime = 120000L
   val DefaultCursorBatchSize = 101
   val DefaultBulkBatchSize = 1000
-  val DefaultIdAsObjectId = "true"
+  val DefaultIdAsObjectId = true
 
   /**
    * Parse Map of string parameters to Map with the correct objects used in MongoDb Datasource functions
@@ -95,41 +95,152 @@ object MongodbConfig {
    * @return List of parameters parsed to correct mongoDb configurations
    */
   // TODO Review when refactoring config
-  def parseParameters(parameters : Map[String,String]): Map[String, Any] = {
+  def parseParameters(parameters : Map[String, Any]): Map[String, Any] = {
 
-    // required properties
-    /** We will assume hosts are provided like 'host:port,host2:port2,...' */
-    val properties: Map[String, Any] = parameters.updated(Host, parameters.getOrElse(Host, notFound[String](Host)).split(",").toList)
-    if (!parameters.contains(Database)) notFound(Database)
-    if (!parameters.contains(Collection)) notFound(Collection)
+    // don't check required properties here, since it will be checked in the Config.build()
+
+    val properties: Map[String, Any] = parameters
 
     //optional parseable properties
-    val optionalProperties: List[String] = List(Credentials,SSLOptions, UpdateFields)
-
-    (properties /: optionalProperties){
-      /** We will assume credentials are provided like 'user,database,password;user,database,password;...' */
-      case (properties,Credentials) =>
-        parameters.get(Credentials).map{ credentialInput =>
-          val credentials = credentialInput.split(";").map(_.split(",")).toList
-            .map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
-          properties + (Credentials -> credentials)
+    val optionalProperties: List[String] = List(Host, Credentials, SSLOptions, UpdateFields)
+
+    val optionalParsedProperties = (properties /: optionalProperties){
+      /** We will assume hosts are provided like 'host:port,host2:port2,...' or like
+        * List('host1:port1','host2:port2',...) */
+      case (properties, Host) =>
+        parameters.get(Host).map{
+          case hostInput: String => properties + (Host -> hostInput.split(",").toList)
+          case hostInput @ List(_: String, _*) => properties
+          case _ => throw new IllegalArgumentException
         } getOrElse properties
 
-      /** We will assume ssloptions are provided like '/path/keystorefile,keystorepassword,/path/truststorefile,truststorepassword' */
-      case (properties,SSLOptions) =>
-        parameters.get(SSLOptions).map{ ssloptionsInput =>
+      /** We will assume credentials are provided like 'user,database,password;user,database,password;...' or like
+        * List('user,database,password', 'user,database,password', ...) */
+      case (properties, Credentials) =>
+        parameters.get(Credentials).map {
+          case credentialInput: String =>
+            val credentials = credentialInput.split(";").map(_.split(",")).toList
+              .map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
+            properties + (Credentials -> credentials)
+          case credentialInput: MongodbCredentials => properties + (Credentials -> List(credentialInput))
+          case credentialInput @ List(_: String, _*) =>
+            val credentials = credentialInput.map(_.toString.split(","))
+              .map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
+            properties + (Credentials -> credentials)
+          case credentialInput @ List(_: MongodbCredentials, _*) => properties
+          case _ => throw new IllegalArgumentException
+        } getOrElse properties
 
-          val ssloption = ssloptionsInput.split(",")
-          val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
-          properties + (SSLOptions -> ssloptions)
-        } getOrElse properties
+      /** We will assume ssloptions are provided like '/path/keystorefile,keystorepassword,/path/truststorefile,truststorepassword' */
+      case (properties, SSLOptions) =>
+        parameters.get(SSLOptions).map {
+          case ssloptionsInput: String =>
+            val ssloption = ssloptionsInput.toString.split(",")
+            val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
+            properties + (SSLOptions.toLowerCase -> ssloptions) - SSLOptions
+          case ssloptionsInput: MongodbSSLOptions => properties + (SSLOptions.toLowerCase -> ssloptionsInput) - SSLOptions
+        } getOrElse {
+          parameters.get(SSLOptions.toLowerCase).map {
+            case ssloptionsInput: String =>
+              val ssloption = ssloptionsInput.toString.split(",")
+              val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
+              properties + (SSLOptions.toLowerCase -> ssloptions)
+            case ssloptionsInput: MongodbSSLOptions => properties
+          } getOrElse properties
+        }
 
-      /** We will assume fields are provided like 'user,database,password...' */
+      /** We will assume fields are provided like 'fieldName1,fieldName2,...' or like
+        * List('fieldName1','fieldName2',...) */
       case (properties, UpdateFields) => {
-        parameters.get(UpdateFields).map{ updateInputs =>
-          val updateFields = updateInputs.split(",")
-          properties + (UpdateFields -> updateFields)
-        } getOrElse properties
+        parameters.get(UpdateFields).map {
+          case updateInputs: String =>
+            val updateFields = updateInputs.split(",")
+            properties + (UpdateFields.toLowerCase -> updateFields) - UpdateFields
+          case updateFields @ Array(_: String, _*) => properties + (UpdateFields.toLowerCase -> updateFields) - UpdateFields
+          case _ => throw new IllegalArgumentException
+        } getOrElse {
+          parameters.get(UpdateFields.toLowerCase).map {
+            case updateInputs: String =>
+              val updateFields = updateInputs.split(",")
+              properties + (UpdateFields.toLowerCase -> updateFields)
+            case updateFields @ Array(_: String, _*) => properties
+            case _ => throw new IllegalArgumentException
+          } getOrElse properties
+        }
       }
     }
+
+    val intProperties: List[String] = List(SplitSize, ConnectTimeout, ConnectionsPerHost, MaxWaitTime, SocketTimeout,
+      ThreadsAllowedToBlockForConnectionMultiplier, CursorBatchSize, BulkBatchSize)
+
+    val intParsedProperties = (optionalParsedProperties /: intProperties){
+      case (properties, intProperty) => {
+        parameters.get(intProperty).map{
+          case intValueInput: String => properties + (intProperty.toLowerCase -> intValueInput.toInt) - intProperty
+          case intValueInput: Int => properties + (intProperty.toLowerCase -> intValueInput) - intProperty
+          case _ => throw new IllegalArgumentException
+        } getOrElse {
+          parameters.get(intProperty.toLowerCase).map {
+            case intValueInput: String => properties + (intProperty.toLowerCase -> intValueInput.toInt)
+            case intValueInput: Int => properties
+            case _ => throw new IllegalArgumentException
+          } getOrElse properties
+        }
+      }
+    }
+
+    val longProperties: List[String] = List(ConnectionsTime)
+
+    val longParsedProperties = (intParsedProperties /: longProperties){
+      case (properties, longProperty) => {
+        parameters.get(longProperty).map {
+          case longValueInput: String => properties + (longProperty.toLowerCase -> longValueInput.toLong) - longProperty
+          case longValueInput: Long => properties + (longProperty.toLowerCase -> longValueInput) - longProperty
+          case _ => throw new IllegalArgumentException
+        } getOrElse {
+          parameters.get(longProperty.toLowerCase).map {
+            case longValueInput: String => properties + (longProperty.toLowerCase -> longValueInput.toLong)
+            case longValueInput: Long => properties
+            case _ => throw new IllegalArgumentException
+          } getOrElse properties
+        }
+      }
+    }
+
+    val doubleProperties: List[String] = List(SamplingRatio)
+
+    val doubleParsedProperties = (longParsedProperties /: doubleProperties){
+      case (properties, doubleProperty) => {
+        parameters.get(doubleProperty).map {
+          case doubleValueInput: String => properties + (doubleProperty.toLowerCase -> doubleValueInput.toDouble) - doubleProperty
+          case doubleValueInput: Double => properties + (doubleProperty.toLowerCase -> doubleValueInput) - doubleProperty
+          case _ => throw new IllegalArgumentException
+        } getOrElse {
+          parameters.get(doubleProperty.toLowerCase).map {
+            case doubleValueInput: String => properties + (doubleProperty.toLowerCase -> doubleValueInput.toDouble)
+            case doubleValueInput: Double => properties
+            case _ => throw new IllegalArgumentException
+          } getOrElse properties
+        }
+      }
+    }
+
+    val booleanProperties: List[String] = List(IdAsObjectId)
+
+    (doubleParsedProperties /: booleanProperties){
+      case (properties, booleanProperty) => {
+        parameters.get(booleanProperty).map {
+          case booleanValueInput: String =>
+            properties + (booleanProperty.toLowerCase -> booleanValueInput.toBoolean) - booleanProperty
+          case booleanValueInput: Boolean => properties + (booleanProperty.toLowerCase -> booleanValueInput) - booleanProperty
+          case _ => throw new IllegalArgumentException
+        } getOrElse {
+          parameters.get(booleanProperty.toLowerCase).map {
+            case booleanValueInput: String => properties + (booleanProperty.toLowerCase -> booleanValueInput.toBoolean)
+            case booleanValueInput: Boolean => properties
+            case _ => throw new IllegalArgumentException
+          } getOrElse properties
+        }
       }
     }
   }
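
Note: the net effect of the rewritten `parseParameters` is that it now accepts both the stringly-typed options Spark SQL passes in and natively typed Scala values, normalizing each under the lower-cased property name. A hedged usage sketch (key strings taken from the tests further down; the exact stored keys are an implementation detail):

    // Both maps should normalize to equivalent configurations.
    val fromSparkSql = MongodbConfig.parseParameters(Map(
      "host" -> "localhost:27017",       // String split into List("localhost:27017")
      "splitSize" -> "10",               // String coerced to Int
      "schema_samplingRatio" -> "1.0"))  // String coerced to Double

    val fromScalaApi = MongodbConfig.parseParameters(Map(
      "host" -> List("localhost:27017"), // already a List: kept as-is
      "splitSize" -> 10,                 // already an Int: only re-keyed
      "schema_samplingRatio" -> 1.0))
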
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfigBuilder.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfigBuilder.scala
index d470471..88f0fc9 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfigBuilder.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/config/MongodbConfigBuilder.scala
@@ -28,7 +28,7 @@ import Config._
 
 case class MongodbConfigBuilder(props: Map[Property, Any] = Map()) extends {
 
-  override val properties = Map() ++ props
+  override val properties = MongodbConfig.parseParameters(Map() ++ props)
 
 } with ConfigBuilder[MongodbConfigBuilder](properties) {
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala
index 9d8a977..d9c4cda 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/partitioner/MongodbPartitioner.scala
@@ -56,9 +56,9 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {
 
   private val collectionFullName: String = s"$databaseName.$collectionName"
 
-  private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
+  private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)
 
-  private val cursorBatchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
+  private val cursorBatchSize = config.getOrElse(MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
 
   override def computePartitions(): Array[MongodbPartition] = {
     val mongoClient = MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions)
@@ -180,8 +180,7 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {
       }
     else (MongoDBObject.empty, None, None)
 
-    val maxChunkSize = config.get[String](MongodbConfig.SplitSize).map(_.toInt)
-      .getOrElse(MongodbConfig.DefaultSplitSize)
+    val maxChunkSize = config.getOrElse(MongodbConfig.SplitSize, MongodbConfig.DefaultSplitSize)
 
     val cmd: MongoDBObject = MongoDBObject(
       "splitVector" -> collectionFullName,
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/query/FilterSection.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/query/FilterSection.scala
index ea589b7..6301622 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/query/FilterSection.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/query/FilterSection.scala
@@ -149,6 +149,6 @@ case class SourceFilters(
   }
 
   lazy val idAsObjectId: Boolean =
-    config.getOrElse[String](MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId).equalsIgnoreCase("true")
+    config.getOrElse(MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId)
 }
\ No newline at end of file
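
Note: the `config.get[Long](...)` and `Boolean` reads above are only safe because `MongodbConfigBuilder` now runs `parseParameters` before `build()`. Sketch of that ordering (the "connectionsTime" key name is assumed here, and `Config` is assumed to resolve keys case-insensitively, which the lower-casing in `parseParameters` suggests):

    val config = MongodbConfigBuilder(Map(
      "host" -> "localhost:27017",
      "database" -> "example",
      "collection" -> "test",
      "connectionsTime" -> "120000")  // String on the way in...
    ).build()                         // ...already a Long after build()

    val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime) // no .map(_.toLong) needed
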
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala
index 0496ba6..298014f 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/reader/MongodbReader.scala
@@ -45,9 +45,9 @@ class MongodbReader(config: Config,
 
   private var dbCursor: Option[MongoCursorBase] = None
 
-  private val batchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
+  private val batchSize = config.getOrElse(MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
 
-  private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
+  private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)
 
   def close(): Unit = {
@@ -90,7 +90,10 @@ class MongodbReader(config: Config,
           MongoCredential.createCredential(user, database, password)
       }
       val sslOptions = config.get[MongodbSSLOptions](MongodbConfig.SSLOptions)
-      val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
+      val clientOptions = {
+        val lowerCaseOptions = MongodbConfig.ListMongoClientOptions.map(_.toLowerCase).toSet
+        config.properties.filter { case (k, _) => lowerCaseOptions contains k }
+      }
 
       val mongoClientResponse = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions)
       mongoClient = Option(mongoClientResponse.clientConnection)
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbBatchWriter.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbBatchWriter.scala
index 2f8e032..849c9df 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbBatchWriter.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbBatchWriter.scala
@@ -29,7 +29,7 @@ class MongodbBatchWriter(config: Config) extends MongodbWriter(config) {
 
   private val IdKey = "_id"
 
-  private val bulkBatchSize = config.getOrElse[Int](MongodbConfig.BulkBatchSize, MongodbConfig.DefaultBulkBatchSize)
+  private val bulkBatchSize = config.getOrElse(MongodbConfig.BulkBatchSize, MongodbConfig.DefaultBulkBatchSize)
 
   private val pkConfig: Option[Array[String]] = config.get[Array[String]](MongodbConfig.UpdateFields)
diff --git a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala
index 057a204..22f8447 100644
--- a/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala
+++ b/spark-mongodb/src/main/scala/com/stratio/datasource/mongodb/writer/MongodbWriter.scala
@@ -50,11 +50,14 @@ abstract class MongodbWriter(config: Config) extends Serializable {
     case None => DefaultWriteConcern
   }
 
-  private val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
+  private val clientOptions = {
+    val lowerCaseOptions = MongodbConfig.ListMongoClientOptions.map(_.toLowerCase).toSet
+    config.properties.filter { case (k, _) => lowerCaseOptions contains k }
+  }
 
   private val languageConfig = config.get[String](MongodbConfig.Language)
 
-  private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
+  private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)
 
   protected val mongoClient = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions)
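
Note: the `clientOptions` change in the reader and writer appears to be a real fix, not just cleanup: `filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))` asked whether each key string contains the whole list as an element (a `Char`-level `contains` on `String`), which can never match, so no client options were forwarded. A self-contained demonstration of the corrected membership check (stand-in values, runnable as-is):

    val listMongoClientOptions = List("readPreference", "connectTimeout") // stand-in for MongodbConfig.ListMongoClientOptions
    val properties = Map("readpreference" -> "nearest", "connecttimeout" -> 50000, "host" -> "localhost")

    val lowerCaseOptions = listMongoClientOptions.map(_.toLowerCase).toSet
    val clientOptions = properties.filter { case (k, _) => lowerCaseOptions contains k }
    // clientOptions == Map("readpreference" -> "nearest", "connecttimeout" -> 50000)
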
diff --git a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala
index 5eeabb1..03fc389 100644
--- a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala
+++ b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/client/MongodbClientFactoryTest.scala
@@ -18,10 +18,10 @@ package com.stratio.datasource.mongodb.client
 
 import com.mongodb.casbah.MongoClient
 import com.mongodb.{MongoCredential, ServerAddress}
 import com.stratio.datasource.MongodbTestConstants
-import com.stratio.datasource.mongodb.config.MongodbSSLOptions
+import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbConfigBuilder, MongodbCredentials, MongodbSSLOptions}
 import org.junit.runner.RunWith
 import org.scalatest.junit.JUnitRunner
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfter, FlatSpec, Matchers}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FlatSpec, Matchers}
 
 @RunWith(classOf[JUnitRunner])
 class MongodbClientFactoryTest extends FlatSpec
@@ -36,18 +36,28 @@ with BeforeAndAfterAll {
 
   val hostPortCredentialsClient = MongodbClientFactory.getClient("127.0.0.1", 27017, "user", "database", "password").clientConnection
 
+  val config = MongodbConfigBuilder(Map(
+    "readPreference" -> "NEAREST",
+    "connectTimeout"-> "50000",
+    "socketTimeout"-> "50000",
+    "maxWaitTime"-> "50000",
+    "connectionsPerHost" -> "20",
+    "threadsAllowedToBlockForConnectionMultiplier" -> "5"
+  )).set("host", "127.0.0.1:27017")
+    .set("database", "database")
+    .set("collection", "collection")
+    .set(MongodbConfig.Credentials, MongodbCredentials("user","database","password".toCharArray))
+    .set(MongodbConfig.SSLOptions, MongodbSSLOptions(Some("/etc/ssl/mongodb.keystore"), Some("password"), "/etc/ssl/mongodb.keystore", Some("password")))
+    .build()
+
   val fullClient = MongodbClientFactory.getClient(
-    List(new ServerAddress("127.0.0.1:27017")),
-    List(MongoCredential.createCredential("user","database","password".toCharArray)),
-    Some(MongodbSSLOptions(Some("/etc/ssl/mongodb.keystore"), Some("password"), "/etc/ssl/mongodb.keystore", Some("password"))),
-    Map(
-      "readPreference" -> "nearest",
-      "connectTimeout"-> "50000",
-      "socketTimeout"-> "50000",
-      "maxWaitTime"-> "50000",
-      "connectionsPerHost" -> "20",
-      "threadsAllowedToBlockForConnectionMultiplier" -> "5"
-    )
+    config[List[String]](MongodbConfig.Host).map(add => new ServerAddress(add)),
+    config[List[MongodbCredentials]](MongodbConfig.Credentials).map {
+      case MongodbCredentials(user, database, password) =>
+        MongoCredential.createCredential(user, database, password)
+    },
+    config.get[MongodbSSLOptions](MongodbConfig.SSLOptions),
+    config.properties
   ).clientConnection
 
   val gracefully = true
diff --git a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/config/ConfigTest.scala b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/config/ConfigTest.scala
index 3ce801d..d93bf5c 100644
--- a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/config/ConfigTest.scala
+++ b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/config/ConfigTest.scala
@@ -70,6 +70,20 @@ with MongodbTestConstants{
 
   }
 
+  it should "get right value of user specified splitSize" + scalaBinaryVersion in {
+    val parameters = Map(
+      "host" -> "example.com",
+      "database" -> "example",
+      "collection" -> "test",
+      "schema_samplingRatio" -> "0.0001",
+      "splitSize" -> "20"
+    )
+
+    val config = MongodbConfigBuilder(parameters).build()
+    val splitSize = config.getOrElse(MongodbConfig.SplitSize, MongodbConfig.DefaultSplitSize)
+    splitSize should equal(20)
+  }
+
 }
 
 trait ConfigHelpers {
"get right value of user specified splitSize" + scalaBinaryVersion in { + val parameters = Map( + "host" -> "example.com", + "database" -> "example", + "collection" -> "test", + "schema_samplingRatio" -> "0.0001", + "splitSize" -> "20" + ) + + val config = MongodbConfigBuilder(parameters).build() + val splitSize = config.getOrElse(MongodbConfig.SplitSize, MongodbConfig.DefaultSplitSize) + splitSize should equal(20) + } + } trait ConfigHelpers { diff --git a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/writer/MongodbWriterIT.scala b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/writer/MongodbWriterIT.scala index c83da71..e178a41 100644 --- a/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/writer/MongodbWriterIT.scala +++ b/spark-mongodb/src/test/scala/com/stratio/datasource/mongodb/writer/MongodbWriterIT.scala @@ -36,7 +36,7 @@ with BeforeAndAfterAll { private val host: String = "localhost" private val collection: String = "testCol" - private val writeConcern = "NORMAL" + private val writeConcern = "SAFE" private val idField: String = "att2" private val updateField: Array[String] = Array("att3") private val wrongIdField: String = "non-existentColumn"