Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {

new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build())(sqlContext)
new MongodbRelation(MongodbConfigBuilder(parameters).build())(sqlContext)

}

Expand All @@ -42,7 +42,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
parameters: Map[String, String],
schema: StructType): BaseRelation = {

new MongodbRelation(MongodbConfigBuilder(parseParameters(parameters)).build(), Some(schema))(sqlContext)
new MongodbRelation(MongodbConfigBuilder(parameters).build(), Some(schema))(sqlContext)

}

Expand All @@ -53,7 +53,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
data: DataFrame): BaseRelation = {

val mongodbRelation = new MongodbRelation(
MongodbConfigBuilder(parseParameters(parameters)).build(), Some(data.schema))(sqlContext)
MongodbConfigBuilder(parameters).build(), Some(data.schema))(sqlContext)

mode match{
case Append => mongodbRelation.insert(data, overwrite = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ with PrunedFilteredScan with InsertableRelation {
@transient private lazy val lazySchema =
MongodbSchema(
new MongodbRDD(sqlContext, config, rddPartitioner),
config.get[Any](MongodbConfig.SamplingRatio).fold(MongodbConfig.DefaultSamplingRatio)(_.toString.toDouble)).schema()
config.getOrElse(MongodbConfig.SamplingRatio, MongodbConfig.DefaultSamplingRatio)).schema()

override val schema: StructType = schemaProvided.getOrElse(lazySchema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class MongodbClientActor extends Actor {

mongoClient.update(finalKey, connection.copy(
timeOut = System.currentTimeMillis() +
extractValue[String](clientOptions, ConnectionsTime).map(_.toLong).getOrElse(DefaultConnectionsTime),
extractValue[Long](clientOptions, ConnectionsTime).getOrElse(DefaultConnectionsTime),
status = ConnectionStatus.Busy))

sender ! ClientResponse(finalKey, connection.client)
Expand Down Expand Up @@ -187,13 +187,10 @@ class MongodbClientActor extends Actor {
case Some(preference) => parseReadPreference(preference)
case None => DefaultReadPreference
})
.connectTimeout(extractValue[String](clientOptions, ConnectTimeout).map(_.toInt)
.getOrElse(DefaultConnectTimeout))
.connectionsPerHost(extractValue[String](clientOptions, ConnectionsPerHost).map(_.toInt)
.getOrElse(DefaultConnectionsPerHost))
.maxWaitTime(extractValue[String](clientOptions, MaxWaitTime).map(_.toInt)
.getOrElse(DefaultMaxWaitTime))
.threadsAllowedToBlockForConnectionMultiplier(extractValue[String](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier).map(_.toInt)
.connectTimeout(extractValue[Int](clientOptions, ConnectTimeout).getOrElse(DefaultConnectTimeout))
.connectionsPerHost(extractValue[Int](clientOptions, ConnectionsPerHost).getOrElse(DefaultConnectionsPerHost))
.maxWaitTime(extractValue[Int](clientOptions, MaxWaitTime).getOrElse(DefaultMaxWaitTime))
.threadsAllowedToBlockForConnectionMultiplier(extractValue[Int](clientOptions, ThreadsAllowedToBlockForConnectionMultiplier)
.getOrElse(DefaultThreadsAllowedToBlockForConnectionMultiplier))

if (sslBuilder(optionSSLOptions)) builder.socketFactory(SSLSocketFactory.getDefault())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,49 +87,160 @@ object MongodbConfig {
val DefaultConnectionsTime = 120000L
val DefaultCursorBatchSize = 101
val DefaultBulkBatchSize = 1000
val DefaultIdAsObjectId = "true"
val DefaultIdAsObjectId = true

/**
* Parses a Map of raw parameters into a Map holding the correctly typed objects used by the MongoDB Datasource functions
* @param parameters Map of parameters
* @return Map of parameters parsed to the correct MongoDB configuration objects
*/
// TODO Review when refactoring config
def parseParameters(parameters : Map[String,String]): Map[String, Any] = {
def parseParameters(parameters : Map[String, Any]): Map[String, Any] = {

// required properties
/** We will assume hosts are provided like 'host:port,host2:port2,...' */
val properties: Map[String, Any] = parameters.updated(Host, parameters.getOrElse(Host, notFound[String](Host)).split(",").toList)
if (!parameters.contains(Database)) notFound(Database)
if (!parameters.contains(Collection)) notFound(Collection)
// don't check required properties here, since they will be checked in Config.build()

val properties: Map[String, Any] = parameters

//optional parseable properties
val optionalProperties: List[String] = List(Credentials,SSLOptions, UpdateFields)

(properties /: optionalProperties){
/** We will assume credentials are provided like 'user,database,password;user,database,password;...' */
case (properties,Credentials) =>
parameters.get(Credentials).map{ credentialInput =>
val credentials = credentialInput.split(";").map(_.split(",")).toList
.map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
properties + (Credentials -> credentials)
val optionalProperties: List[String] = List(Host, Credentials, SSLOptions, UpdateFields)

val optionalParsedProperties = (properties /: optionalProperties){
/** We will assume hosts are provided like 'host:port,host2:port2,...' or like
* List('host1:port1', 'host2:port2', ...) */
case (properties, Host) =>
parameters.get(Host).map{
case hostInput: String => properties + (Host -> hostInput.split(",").toList)
case hostInput @ List(_: String, _*) => properties
case _ => throw new IllegalArgumentException
} getOrElse properties

/** We will assume ssloptions are provided like '/path/keystorefile,keystorepassword,/path/truststorefile,truststorepassword' */
case (properties,SSLOptions) =>
parameters.get(SSLOptions).map{ ssloptionsInput =>
/** We will assume credentials are provided like 'user,database,password;user,database,password;...' or like
* List('user,database,password', 'user,database,password', ...) */
case (properties, Credentials) =>
parameters.get(Credentials).map {
case credentialInput: String =>
val credentials = credentialInput.split(";").map(_.split(",")).toList
.map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
properties + (Credentials -> credentials)
case credentialInput: MongodbCredentials => properties + (Credentials -> List(credentialInput))
case credentialInput @ List(_: String, _*) =>
val credentials = credentialInput.map(_.toString.split(","))
.map(credentials => MongodbCredentials(credentials(0), credentials(1), credentials(2).toCharArray))
properties + (Credentials -> credentials)
case credentialInput @ List(_: MongodbCredentials, _*) => properties
case _ => throw new IllegalArgumentException
} getOrElse properties

val ssloption = ssloptionsInput.split(",")
val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
properties + (SSLOptions -> ssloptions)
} getOrElse properties
/** We will assume ssloptions are provided like '/path/keystorefile,keystorepassword,/path/truststorefile,truststorepassword' */
case (properties, SSLOptions) =>
parameters.get(SSLOptions).map {
case ssloptionsInput: String =>
val ssloption = ssloptionsInput.toString.split(",")
val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
properties + (SSLOptions.toLowerCase -> ssloptions) - SSLOptions
case ssloptionsInput: MongodbSSLOptions => properties + (SSLOptions.toLowerCase -> ssloptionsInput) - SSLOptions
} getOrElse {
parameters.get(SSLOptions.toLowerCase).map {
case ssloptionsInput: String =>
val ssloption = ssloptionsInput.toString.split(",")
val ssloptions = MongodbSSLOptions(Some(ssloption(0)), Some(ssloption(1)), ssloption(2), Some(ssloption(3)))
properties + (SSLOptions.toLowerCase -> ssloptions)
case ssloptionsInput: MongodbSSLOptions => properties
} getOrElse properties
}

/** We will assume fields are provided like 'user,database,password...' */
/** We will assume fields are provided like 'fieldName1,fieldName2,...' or like
* List('fieldName1', 'fieldName2', ...) */
case (properties, UpdateFields) => {
parameters.get(UpdateFields).map{ updateInputs =>
val updateFields = updateInputs.split(",")
properties + (UpdateFields -> updateFields)
} getOrElse properties
parameters.get(UpdateFields).map {
case updateInputs: String =>
val updateFields = updateInputs.split(",")
properties + (UpdateFields.toLowerCase -> updateFields) - UpdateFields
case updateFields @ Array(_: String, _*) => properties + (UpdateFields.toLowerCase -> updateFields) - UpdateFields
case _ => throw new IllegalArgumentException
} getOrElse {
parameters.get(UpdateFields.toLowerCase).map {
case updateInputs: String =>
val updateFields = updateInputs.split(",")
properties + (UpdateFields.toLowerCase -> updateFields)
case updateFields @ Array(_: String, _*) => properties
case _ => throw new IllegalArgumentException
} getOrElse properties
}
}
}

val intProperties: List[String] = List(SplitSize, ConnectTimeout, ConnectionsPerHost, MaxWaitTime, SocketTimeout,
ThreadsAllowedToBlockForConnectionMultiplier, CursorBatchSize, BulkBatchSize)

val intParsedProperties = (optionalParsedProperties /: intProperties){
case (properties, intProperty) => {
parameters.get(intProperty).map{
case intValueInput: String => properties + (intProperty.toLowerCase -> intValueInput.toInt) - intProperty
case intValueInput: Int => properties + (intProperty.toLowerCase -> intValueInput) - intProperty
case _ => throw new IllegalArgumentException
} getOrElse {
parameters.get(intProperty.toLowerCase).map {
case intValueInput: String => properties + (intProperty.toLowerCase -> intValueInput.toInt)
case intValueInput: Int => properties
case _ => throw new IllegalArgumentException
} getOrElse properties
}
}
}

val longProperties: List[String] = List(ConnectionsTime)

val longParsedProperties = (intParsedProperties /: longProperties){
case (properties, longProperty) => {
parameters.get(longProperty).map {
case longValueInput: String => properties + (longProperty.toLowerCase -> longValueInput.toLong) - longProperty
case longValueInput: Long => properties + (longProperty.toLowerCase -> longValueInput) - longProperty
case _ => throw new IllegalArgumentException
} getOrElse {
parameters.get(longProperty.toLowerCase).map {
case longValueInput: String => properties + (longProperty.toLowerCase -> longValueInput.toLong)
case longValueInput: Long => properties
case _ => throw new IllegalArgumentException
} getOrElse properties
}
}
}

val doubleProperties: List[String] = List(SamplingRatio)

val doubleParsedProperties = (longParsedProperties /: doubleProperties){
case (properties, doubleProperty) => {
parameters.get(doubleProperty).map {
case doubleValueInput: String => properties + (doubleProperty.toLowerCase -> doubleValueInput.toDouble) - doubleProperty
case doubleValueInput: Double => properties + (doubleProperty.toLowerCase -> doubleValueInput) - doubleProperty
case _ => throw new IllegalArgumentException
} getOrElse {
parameters.get(doubleProperty.toLowerCase).map {
case doubleValueInput: String => properties + (doubleProperty.toLowerCase -> doubleValueInput.toDouble)
case doubleValueInput: Double => properties
case _ => throw new IllegalArgumentException
} getOrElse properties
}
}
}

val booleanProperties: List[String] = List(IdAsObjectId)

(doubleParsedProperties /: booleanProperties){
case (properties, booleanProperty) => {
parameters.get(booleanProperty).map {
case booleanValueInput: String =>
properties + (booleanProperty.toLowerCase -> booleanValueInput.toBoolean) - booleanProperty
case booleanValueInput: Boolean => properties + (booleanProperty.toLowerCase -> booleanValueInput) - booleanProperty
case _ => throw new IllegalArgumentException
} getOrElse {
parameters.get(booleanProperty.toLowerCase).map {
case booleanValueInput: String => properties + (booleanProperty.toLowerCase -> booleanValueInput.toBoolean)
case booleanValueInput: Boolean => properties
case _ => throw new IllegalArgumentException
} getOrElse properties
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import Config._

case class MongodbConfigBuilder(props: Map[Property, Any] = Map()) extends {

override val properties = Map() ++ props
override val properties = MongodbConfig.parseParameters(Map() ++ props)

} with ConfigBuilder[MongodbConfigBuilder](properties) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {

private val collectionFullName: String = s"$databaseName.$collectionName"

private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)

private val cursorBatchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
private val cursorBatchSize = config.getOrElse(MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)

override def computePartitions(): Array[MongodbPartition] = {
val mongoClient = MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions)
Expand Down Expand Up @@ -180,8 +180,7 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {
}
else (MongoDBObject.empty, None, None)

val maxChunkSize = config.get[String](MongodbConfig.SplitSize).map(_.toInt)
.getOrElse(MongodbConfig.DefaultSplitSize)
val maxChunkSize = config.getOrElse(MongodbConfig.SplitSize, MongodbConfig.DefaultSplitSize)

val cmd: MongoDBObject = MongoDBObject(
"splitVector" -> collectionFullName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,6 @@ case class SourceFilters(
}

lazy val idAsObjectId: Boolean =
config.getOrElse[String](MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId).equalsIgnoreCase("true")
config.getOrElse(MongodbConfig.IdAsObjectId, MongodbConfig.DefaultIdAsObjectId)

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class MongodbReader(config: Config,

private var dbCursor: Option[MongoCursorBase] = None

private val batchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)
private val batchSize = config.getOrElse(MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)

private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)


def close(): Unit = {
Expand Down Expand Up @@ -90,7 +90,10 @@ class MongodbReader(config: Config,
MongoCredential.createCredential(user, database, password)
}
val sslOptions = config.get[MongodbSSLOptions](MongodbConfig.SSLOptions)
val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
val clientOptions = {
val lowerCaseOptions = MongodbConfig.ListMongoClientOptions.map(_.toLowerCase).toSet
config.properties.filter { case (k, _) => lowerCaseOptions contains k }
}

val mongoClientResponse = MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions)
mongoClient = Option(mongoClientResponse.clientConnection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MongodbBatchWriter(config: Config) extends MongodbWriter(config) {

private val IdKey = "_id"

private val bulkBatchSize = config.getOrElse[Int](MongodbConfig.BulkBatchSize, MongodbConfig.DefaultBulkBatchSize)
private val bulkBatchSize = config.getOrElse(MongodbConfig.BulkBatchSize, MongodbConfig.DefaultBulkBatchSize)

private val pkConfig: Option[Array[String]] = config.get[Array[String]](MongodbConfig.UpdateFields)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ abstract class MongodbWriter(config: Config) extends Serializable {
case None => DefaultWriteConcern
}

private val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
private val clientOptions = {
val lowerCaseOptions = MongodbConfig.ListMongoClientOptions.map(_.toLowerCase).toSet
config.properties.filter { case (k, _) => lowerCaseOptions contains k }
}

private val languageConfig = config.get[String](MongodbConfig.Language)

private val connectionsTime = config.get[String](MongodbConfig.ConnectionsTime).map(_.toLong)
private val connectionsTime = config.get[Long](MongodbConfig.ConnectionsTime)

protected val mongoClient =
MongodbClientFactory.getClient(hosts, credentials, sslOptions, clientOptions)
Expand Down
Loading