Skip to content

Commit a5673e5

Browse files
DEV: Implement imputation for VCF features
* Update python wrapper to include imputation strategy parameter * Update scala API to pass imputation strategy to VCFFeatureSource * Create functions to handle mode and zero imputation strategies * Added imputation strategy to test cases * Added imputation strategy to FeatureSource cli * Remove sparkPar from test cases due to changes in class signature * Updated DefVariantToFeatureConverterTest to use zeros imputation
1 parent 01efa9d commit a5673e5

File tree

10 files changed

+59
-23
lines changed

10 files changed

+59
-23
lines changed

python/varspark/core.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,26 @@ def __init__(self, ss, silent=False):
5656
" /_/ \n"
5757
)
5858

59-
@params(self=object, vcf_file_path=str, min_partitions=int)
60-
def import_vcf(self, vcf_file_path, min_partitions=0):
61-
"""Import features from a VCF file."""
59+
@params(self=object, vcf_file_path=str, imputation_strategy=Nullable(str))
60+
def import_vcf(self, vcf_file_path, imputation_strategy="none"):
61+
"""Import features from a VCF file.
62+
63+
:param vcf_file_path String: The file path for the vcf file to import
64+
:param imputation_strategy String:
65+
The imputation strategy to use. Options for imputation include:
66+
67+
- none: No imputation will be performed. Missing values will be replaced with -1 (not recommended unless there are no missing values)
68+
- mode: Missing values will be replaced with the most commonly occuring value among that feature. Recommended option
69+
- zeros: Missing values will be replaced with zeros. Faster than mode imputation
70+
"""
71+
if imputation_strategy == "none":
72+
print("WARNING: Imputation strategy is set to none - please ensure that there are no missing values in the data.")
6273
return FeatureSource(
6374
self._jvm,
6475
self._vs_api,
6576
self._jsql,
6677
self.sql,
67-
self._jvsc.importVCF(vcf_file_path, min_partitions),
78+
self._jvsc.importVCF(vcf_file_path, imputation_strategy),
6879
)
6980

7081
@params(
@@ -76,7 +87,7 @@ def import_vcf(self, vcf_file_path, min_partitions=0):
7687
def import_covariates(self, cov_file_path, cov_types=None, transposed=False):
7788
"""Import covariates from a CSV file.
7889
79-
:param cov_file_path: The file path for covariate csv file
90+
:param cov_file_path String: The file path for covariate csv file
8091
:param cov_types Dict[String]:
8192
A dictionary specifying types for each covariate, where the key is the variable name
8293
and the value is the type. The value can be one of the following:

src/main/scala/au/csiro/variantspark/api/VSContext.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ class VSContext(val spark: SparkSession) extends SqlContextHolder {
3838
* @param inputFile path to file or directory with VCF files to load
3939
* @return FeatureSource loaded from the VCF file
4040
*/
41-
def importVCF(inputFile: String, sparkPar: Int = 0): FeatureSource = {
41+
def importVCF(inputFile: String, imputationStrategy: String = "none"): FeatureSource = {
4242
val vcfSource =
4343
VCFSource(sc, inputFile)
4444
// VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism))
45-
VCFFeatureSource(vcfSource)
45+
VCFFeatureSource(vcfSource, imputationStrategy)
4646
}
4747

4848
/** Import features from a CSV file

src/main/scala/au/csiro/variantspark/cli/CochranArmanCmd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CochranArmanCmd extends ArgsApp with SparkApp with Echoable with Logging w
8989
VCFSource(sc.textFile(inputFile, if (sparkPar > 0) sparkPar else sc.defaultParallelism))
9090
verbose(s"VCF Version: ${vcfSource.version}")
9191
verbose(s"VCF Header: ${vcfSource.header}")
92-
VCFFeatureSource(vcfSource)
92+
VCFFeatureSource(vcfSource, imputationStrategy = "none")
9393
}
9494

9595
def loadCSV(): CsvFeatureSource = {

src/main/scala/au/csiro/variantspark/cli/FilterCmd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class FilterCmd extends ArgsApp with TestArgs with SparkApp {
3030
logDebug(s"Running with filesystem: ${fs}, home: ${fs.getHomeDirectory}")
3131

3232
val vcfSource = VCFSource(sc.textFile(inputFile))
33-
val source = VCFFeatureSource(vcfSource)
33+
val source = VCFFeatureSource(vcfSource, imputationStrategy = "none")
3434
val features = source.features.zipWithIndex().cache()
3535
val featureCount = features.count()
3636
println(s"No features: ${featureCount}")

src/main/scala/au/csiro/variantspark/cli/VcfToLabels.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class VcfToLabels extends ArgsApp with SparkApp {
2727
val version = vcfSource.version
2828
println(header)
2929
println(version)
30-
val source = VCFFeatureSource(vcfSource)
30+
val source = VCFFeatureSource(vcfSource, imputationStrategy = "none")
3131
val columns = source.features.take(limit)
3232
CSVUtils.withFile(new File(outputFile)) { writer =>
3333
writer.writeRow("" :: columns.map(_.label).toList)

src/main/scala/au/csiro/variantspark/cli/args/FeatureSourceArgs.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ object VCFFeatureSourceFactory {
2626
val DEF_SEPARATOR: String = "_"
2727
}
2828

29-
case class VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolean],
30-
separator: Option[String])
29+
case class VCFFeatureSourceFactory(inputFile: String, imputationStrategy: Option[String],
30+
isBiallelic: Option[Boolean], separator: Option[String])
3131
extends FeatureSourceFactory with Echoable {
3232
def createSource(sparkArgs: SparkArgs): FeatureSource = {
3333
echo(s"Loading header from VCF file: ${inputFile}")
@@ -36,8 +36,8 @@ case class VCFFeatureSourceFactory(inputFile: String, isBiallelic: Option[Boolea
3636
verbose(s"VCF Header: ${vcfSource.header}")
3737

3838
import VCFFeatureSourceFactory._
39-
VCFFeatureSource(vcfSource, isBiallelic.getOrElse(DEF_IS_BIALLELIC),
40-
separator.getOrElse(DEF_SEPARATOR))
39+
VCFFeatureSource(vcfSource, imputationStrategy.getOrElse("none"),
40+
isBiallelic.getOrElse(DEF_IS_BIALLELIC), separator.getOrElse(DEF_SEPARATOR))
4141
}
4242
}
4343

src/main/scala/au/csiro/variantspark/input/VCFFeatureSource.scala

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import au.csiro.variantspark.data.StdFeature
1010

1111
trait VariantToFeatureConverter {
1212
def convert(vc: VariantContext): Feature
13+
def convertModeImputed(vc: VariantContext): Feature
14+
def convertZeroImputed(vc: VariantContext): Feature
1315
}
1416

1517
case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: String = "_")
@@ -20,6 +22,18 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S
2022
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), gts)
2123
}
2224

25+
def convertModeImputed(vc: VariantContext): Feature = {
26+
val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray
27+
val modeImputedGts = ModeImputationStrategy(noLevels = 3).impute(gts)
28+
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), modeImputedGts)
29+
}
30+
31+
def convertZeroImputed(vc: VariantContext): Feature = {
32+
val gts = vc.getGenotypes.iterator().asScala.map(convertGenotype).toArray
33+
val zeroImputedGts = ZeroImputationStrategy.impute(gts)
34+
StdFeature.from(convertLabel(vc), BoundedOrdinalVariable(3), zeroImputedGts)
35+
}
36+
2337
def convertLabel(vc: VariantContext): String = {
2438

2539
if (biallelic && !vc.isBiallelic) {
@@ -44,23 +58,34 @@ case class DefVariantToFeatureConverter(biallelic: Boolean = false, separator: S
4458
}
4559

4660
def convertGenotype(gt: Genotype): Byte = {
47-
if (!gt.isCalled || gt.isHomRef) 0 else if (gt.isHomVar || gt.isHetNonRef) 2 else 1
61+
if (!gt.isCalled) Missing.BYTE_NA_VALUE
62+
else if (gt.isHomRef) 0
63+
else if (gt.isHomVar || gt.isHetNonRef) 2
64+
else 1
4865
}
4966
}
5067

51-
class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter)
68+
class VCFFeatureSource(vcfSource: VCFSource, converter: VariantToFeatureConverter,
69+
imputationStrategy: String)
5270
extends FeatureSource {
5371
override lazy val sampleNames: List[String] =
5472
vcfSource.header.getGenotypeSamples.asScala.toList
5573
override def features: RDD[Feature] = {
5674
val converterRef = converter
57-
vcfSource.genotypes().map(converterRef.convert)
75+
imputationStrategy match {
76+
case "none" => vcfSource.genotypes().map(converterRef.convert)
77+
case "mode" => vcfSource.genotypes().map(converterRef.convertModeImputed)
78+
case "zeros" => vcfSource.genotypes().map(converterRef.convertZeroImputed)
79+
case _ =>
80+
throw new IllegalArgumentException(s"Unknown imputation strategy: $imputationStrategy")
81+
}
5882
}
5983
}
6084

6185
object VCFFeatureSource {
62-
def apply(vcfSource: VCFSource, biallelic: Boolean = false,
86+
def apply(vcfSource: VCFSource, imputationStrategy: String, biallelic: Boolean = false,
6387
separator: String = "_"): VCFFeatureSource = {
64-
new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator))
88+
new VCFFeatureSource(vcfSource, DefVariantToFeatureConverter(biallelic, separator),
89+
imputationStrategy)
6590
}
6691
}

src/test/scala/au/csiro/variantspark/input/DefVariantToFeatureConverterTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ class DefVariantToFeatureConverterTest {
3737
@Test
3838
def testConvertsBialleicVariantCorrctly() {
3939
val converter = DefVariantToFeatureConverter(true, ":")
40-
val result = converter.convert(bialellicVC)
40+
val result = converter.convertZeroImputed(bialellicVC)
4141
assertEquals("chr1:10:T:A", result.label)
4242
assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray)
4343
}
4444

4545
@Test
4646
def testConvertsMultialleicVariantCorrctly() {
4747
val converter = DefVariantToFeatureConverter(false)
48-
val result = converter.convert(multialleciVC)
48+
val result = converter.convertZeroImputed(multialleciVC)
4949
assertEquals("chr1_10_T_A|G", result.label)
5050
assertArrayEquals(expectedEncodedGenotype, result.valueAsByteArray)
5151
}

src/test/scala/au/csiro/variantspark/misc/CovariateReproducibilityTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class CovariateReproducibilityTest extends SparkTest {
2828
def testCovariateReproducibleResults() {
2929
implicit val vsContext = VSContext(spark)
3030
implicit val sqlContext = spark.sqlContext
31-
val genotypes = vsContext.importVCF("data/chr22_1000.vcf", 3)
31+
val genotypes = vsContext.importVCF("data/chr22_1000.vcf")
3232
val optVariableTypes = new ArrayList[String](Arrays.asList("CONTINUOUS", "ORDINAL(2)",
3333
"CONTINUOUS", "CONTINUOUS", "CONTINUOUS", "CONTINUOUS"))
3434
val covariates =

src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ReproducibilityTest extends SparkTest {
2525
def testReproducibleResults() {
2626
implicit val vsContext = VSContext(spark)
2727
implicit val sqlContext = spark.sqlContext
28-
val features = vsContext.importVCF("data/chr22_1000.vcf", 3)
28+
val features = vsContext.importVCF("data/chr22_1000.vcf")
2929
val label = vsContext.loadLabel("data/chr22-labels.csv", "22_16051249")
3030
val params = RandomForestParams(seed = 13L)
3131
val rfModel1 = RFModelTrainer.trainModel(features, label, params, 40, 20)

0 commit comments

Comments
 (0)