Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mmoe; support multilabel libsvm, multilabelauc #106

Open
wants to merge 1 commit into
base: branch-0.2.0
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions cpp/src/angel/pytorch/angel_torch.cc
Original file line number Diff line number Diff line change
@@ -268,6 +268,12 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
std::vector<torch::jit::IValue> inputs;
std::vector<std::pair<std::string, void *>> ptrs;

int multi_forward_out = 1;
if (angel::jni_map_contain(env, jparams, "multi_forward_out")) {
multi_forward_out =
angel::jni_map_get_int(env, jparams, "multi_forward_out");
}

int batch_size = angel::jni_map_get_int(env, jparams, "batch_size");
// data inputs
inputs.emplace_back(batch_size);
@@ -282,7 +288,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
}
auto output = ptr->serving_forward(inputs);
auto output_ptr = output.data_ptr();
DEFINE_JFLOATARRAY(output_ptr, batch_size);
DEFINE_JFLOATARRAY(output_ptr, batch_size * multi_forward_out);

// release java arrays
release_array(env, ptrs, jparams);
@@ -291,7 +297,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
add_inputs(env, &inputs, &ptrs, jparams, ptr->get_type());
auto output = ptr->forward(inputs).toTensor();
auto output_ptr = output.data_ptr();
DEFINE_JFLOATARRAY(output_ptr, batch_size);
DEFINE_JFLOATARRAY(output_ptr, batch_size * multi_forward_out);

// release java arrays
release_array(env, ptrs, jparams);
@@ -603,4 +609,4 @@ JNIEXPORT void JNICALL Java_com_tencent_angel_pytorch_Torch_gcnSave
ptr->save(path);
env->ReleaseStringUTFChars(jpath, path);
release_array(env, ptrs, jparams);
}
}
Original file line number Diff line number Diff line change
@@ -50,16 +50,17 @@ public static Tuple3<CooLongFloatMatrix, long[], String[]> parsePredict(String[]
private static Tuple2<CooLongFloatMatrix, float[]> parseLIBSVM(String[] lines) {
LongArrayList rows = new LongArrayList();
LongArrayList cols = new LongArrayList();
LongArrayList fields = null;
FloatArrayList vals = new FloatArrayList();
float[] targets = new float[lines.length];
FloatArrayList targets = new FloatArrayList();

int index = 0;
for (int i = 0; i < lines.length; i++) {
String[] parts = lines[i].split(" ");
float label = Float.parseFloat(parts[0]);
targets[i] = label;

String[] labels = parts[0].split("#");
for (int l = 0; l < labels.length; l += 1) {
float label = Float.parseFloat(labels[l]);
targets.add(label);
}
for (int j = 1; j < parts.length; j++) {
String[] kv = parts[j].split(":");
long key = Long.parseLong(kv[0]) - 1;
@@ -75,8 +76,7 @@ private static Tuple2<CooLongFloatMatrix, float[]> parseLIBSVM(String[] lines) {

CooLongFloatMatrix coo = MFactory.cooLongFloatMatrix(rows.toLongArray(),
cols.toLongArray(), vals.toFloatArray(), null);

return new Tuple2<CooLongFloatMatrix, float[]>(coo, targets);
return new Tuple2<CooLongFloatMatrix, float[]>(coo, targets.toFloatArray());
}

private static Tuple3<CooLongFloatMatrix, long[], float[]> parseLIBFFM(String[] lines) {
22 changes: 22 additions & 0 deletions java/src/main/java/com/tencent/angel/pytorch/torch/TorchModel.java
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ public class TorchModel implements Serializable {

// load library of torch and torch_angel
static {
System.loadLibrary("torch");
System.loadLibrary("torch_angel");
}

@@ -242,6 +243,27 @@ public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, fl
return Torch.forward(ptr, params, false);
}

/**
 * Runs the native forward pass for a model that consumes bias, weights,
 * embeddings and dense mats, and that may emit several outputs per sample.
 *
 * @param batchSize       number of samples in {@code batch}
 * @param batch           sparse feature matrix of the batch
 * @param bias            bias values
 * @param weights         first-order weights
 * @param embeddings      flattened embedding values
 * @param embeddingDim    embedding dimension
 * @param mats            flattened dense matrices
 * @param matSizes        sizes of the matrices packed in {@code mats}
 * @param multiForwardOut value passed to the native side under the key
 *                        "multi_forward_out" (presumably the number of outputs per sample)
 * @return the array returned by {@code Torch.forward}
 */
public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] embeddings, int embeddingDim, float[] mats, int[] matSizes, int multiForwardOut) {
  final Map<String, Object> nativeArgs = buildParams(batchSize, batch, bias, weights);

  // embedding inputs
  nativeArgs.put("embedding", embeddings);
  nativeArgs.put("embedding_dim", embeddingDim);

  // dense matrix inputs
  nativeArgs.put("mats", mats);
  nativeArgs.put("mats_sizes", matSizes);

  // multi-output hint for the native side
  nativeArgs.put("multi_forward_out", multiForwardOut);

  return Torch.forward(ptr, nativeArgs, false);
}

/**
 * Field-aware variant of the multi-output forward pass: same inputs as the
 * overload without {@code fields}, plus the per-feature field ids.
 *
 * @param batchSize       number of samples in {@code batch}
 * @param batch           sparse feature matrix of the batch
 * @param bias            bias values
 * @param weights         first-order weights
 * @param embeddings      flattened embedding values
 * @param embeddingDim    embedding dimension
 * @param mats            flattened dense matrices
 * @param matSizes        sizes of the matrices packed in {@code mats}
 * @param fields          field ids, passed to the native side under the key "fields"
 * @param multiForwardOut value passed to the native side under the key
 *                        "multi_forward_out" (presumably the number of outputs per sample)
 * @return the array returned by {@code Torch.forward}
 */
public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] embeddings, int embeddingDim, float[] mats, int[] matSizes, long[] fields, int multiForwardOut) {
  final Map<String, Object> nativeArgs = buildParams(batchSize, batch, bias, weights);

  // embedding inputs
  nativeArgs.put("embedding", embeddings);
  nativeArgs.put("embedding_dim", embeddingDim);

  // dense matrix inputs
  nativeArgs.put("mats", mats);
  nativeArgs.put("mats_sizes", matSizes);

  // field ids and multi-output hint for the native side
  nativeArgs.put("fields", fields);
  nativeArgs.put("multi_forward_out", multiForwardOut);

  return Torch.forward(ptr, nativeArgs, false);
}

public float backward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] targets) {
Map<String, Object> params = buildParams(batchSize, batch, bias, weights);
params.put("targets", targets);
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD

import scala.language.implicitConversions

// evaluation for multi-labels
/**
 * Base class of evaluation metrics for multi-label models.
 *
 * Unlike a single-label metric, the result is returned as a String so that an
 * implementation can report one value per label.
 */
private[pytorch]
abstract class EvaluationM extends Serializable {

  /**
   * Computes the metric.
   *
   * @param pairs pairs of doubles; the implementations in this package treat
   *              `_1` as the label and `_2` as the prediction
   * @return the metric rendered as a String
   */
  def calculate(pairs: RDD[(Double, Double)]): String
}

private[pytorch]
object EvaluationM {

  /**
   * Evaluates every requested metric on `pairs`.
   *
   * @param metrics   metric names (case-insensitive)
   * @param pairs     input pairs handed to each metric
   * @param numLabels number of labels, forwarded to each metric
   * @return map from lower-cased metric name to its rendered value
   * @throws IllegalArgumentException if a metric name is not supported
   */
  def eval(metrics: Array[String], pairs: RDD[(Double, Double)], numLabels: Int = 1): Map[String, String] = {
    metrics.map(name => (name.toLowerCase(), EvaluationM.apply(name, numLabels).calculate(pairs))).toMap
  }

  /**
   * Creates the metric registered under `name` (case-insensitive).
   *
   * @param name      "multi_auc" or "multi_auc_collect"
   * @param numLabels number of labels
   * @throws IllegalArgumentException if `name` is not a supported metric
   */
  def apply(name: String, numLabels: Int = 1): EvaluationM = {
    name.toLowerCase match {
      case "multi_auc" => new MultiLabelAUC(numLabels)
      case "multi_auc_collect" => new MultiLabelAUCCollect(numLabels)
      // Without this case an unknown name surfaces as an opaque scala.MatchError.
      case other =>
        throw new IllegalArgumentException(
          s"Unsupported multi-label evaluation metric '$other'; " +
            "supported metrics are: multi_auc, multi_auc_collect")
    }
  }

  /** Lets an RDD of any numeric pair type be passed where `RDD[(Double, Double)]` is expected. */
  implicit def pairNumericRDDToPairDoubleRDD[T](rdd: RDD[(T, T)])(implicit num: Numeric[T])
  : RDD[(Double, Double)] = {
    rdd.map(x => (num.toDouble(x._1), num.toDouble(x._2)))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * Multi-label AUC computed with distributed RDD operations.
 *
 * The input is a flat RDD of (label, prediction) pairs in which every
 * `numLabels` consecutive elements of a partition are regrouped into one row;
 * one AUC is then computed per label position.
 */
class MultiLabelAUC(numLabels: Int) extends EvaluationM {

  /**
   * Rank-based AUC for a single label.
   *
   * @param pairs (label, prediction) pairs of one label
   * @return the AUC; NaN when there are no positive or no negative samples
   */
  def calculate_(pairs: RDD[(Double, Double)]): Double = {
    // sort by predict
    val sorted = pairs.sortBy(f => f._2)
    sorted.cache()

    val numTotal = sorted.count()
    val numPositive = sorted.filter(f => f._1 > 0).count()
    val numNegative = numTotal - numPositive

    // calculate the summation of ranks for positive samples
    // NOTE(review): positives are counted with `> 0` above but ranked with `toInt == 1`
    // here; the two only agree for labels that are exactly 0 or 1 -- confirm inputs are binary.
    val positiveRanks = sorted.zipWithIndex().filter(f => f._1._1.toInt == 1).persist(StorageLevel.MEMORY_ONLY)
    // fold instead of reduce: reduce throws on an empty RDD, i.e. when a label has
    // no positive sample. With fold the result degrades to NaN instead.
    val sumRanks = positiveRanks.map(f => f._2 + 1).fold(0L)(_ + _)
    val auc = sumRanks * 1.0 / numPositive / numNegative - (numPositive + 1.0) / 2.0 / numNegative

    sorted.unpersist()
    positiveRanks.unpersist()
    auc
  }

  /**
   * Computes one AUC per label.
   *
   * @param pairs flat (label, prediction) pairs; each partition is cut into
   *              consecutive groups of `numLabels` elements
   * @return the per-label AUC values joined with ","
   */
  override
  def calculate(pairs: RDD[(Double, Double)]): String = {
    // Only cache the input if the caller has not chosen a storage level already:
    // persisting with a different level throws, and we must not evict a cache we do not own.
    val cachedHere = pairs.getStorageLevel == StorageLevel.NONE
    if (cachedHere) pairs.persist(StorageLevel.MEMORY_ONLY)

    val data = pairs.mapPartitions { part =>
      val p = part.toArray
      p.sliding(numLabels, numLabels).map(_.toArray)
    }.persist(StorageLevel.MEMORY_ONLY)

    try {
      val re = new Array[Double](numLabels)
      var i = 0
      while (i < numLabels) {
        // immutable copy so the Spark closure does not capture the mutable counter
        val label = i
        re(label) = calculate_(data.map(_(label)))
        i += 1
      }
      re.mkString(",")
    } finally {
      // release everything cached by this call
      data.unpersist()
      if (cachedHere) pairs.unpersist()
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * Multi-label AUC computed on the driver.
 *
 * The prediction results are collected to the driver before the AUC of each
 * label is calculated, so this is only suitable when the number of
 * train/predict samples is small enough to collect, e.g. fewer than 10,000,000.
 */
class MultiLabelAUCCollect(numLabels: Int) extends EvaluationM {

  /**
   * Rank-based AUC for a single label, computed locally.
   *
   * @param pairs (label, prediction) pairs of one label
   * @return the AUC; NaN when there are no positive or no negative samples
   */
  def calculate_(pairs: Array[(Double, Double)]): Double = {
    // sort by predict
    val sorted = pairs.sortBy(f => f._2)

    val numTotal = sorted.length
    val numPositive = sorted.count(f => f._1 > 0)
    val numNegative = numTotal - numPositive

    // calculate the summation of ranks for positive samples
    val positiveRanks = sorted.zipWithIndex.filter(f => f._1._1.toInt == 1)
    val sumRanks = positiveRanks.map(f => f._2.toLong + 1).sum
    val auc = sumRanks * 1.0 / numPositive / numNegative - (numPositive + 1.0) / 2.0 / numNegative
    auc
  }

  /**
   * Computes one AUC per label.
   *
   * @param pairs flat (label, prediction) pairs; each partition is cut into
   *              consecutive groups of `numLabels` elements
   * @return the per-label AUC values joined with ","
   */
  override
  def calculate(pairs: RDD[(Double, Double)]): String = {
    // `pairs` is consumed by a single action (collect), so it is deliberately not
    // persisted: caching would bring no reuse, would never be released, and would throw
    // if the caller had already assigned a different storage level.
    val data = pairs.mapPartitions { part =>
      val p = part.toArray
      p.sliding(numLabels, numLabels).map(_.toArray)
    }.collect()

    val re = new Array[Double](numLabels)
    var i = 0
    while (i < numLabels) {
      val label = i
      re(label) = calculate_(data.map(_(label)))
      i += 1
    }
    re.mkString(",")
  }
}
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ object RecommendationExample {
val torchOutputModelPath = params.getOrElse("torchOutputModelPath", "")
val rowType = params.getOrElse("rowType", "T_FLOAT_DENSE")
val evals = params.getOrElse("evals", "auc")
val numLabels = params.getOrElse("numLabels", "1").toInt
val level = params.getOrElse("storageLevel", "memory_only").toUpperCase()

val recommendation = new Recommendation(torchModelPath)
@@ -60,6 +61,7 @@ object RecommendationExample {
recommendation.setDecay(decay)
recommendation.setAsync(async)
recommendation.setEvaluations(evals)
recommendation.setNumLabels(numLabels)
recommendation.setStorageLevel(StorageLevel.fromString(level))

var numPartitions = start(mode)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.params

import org.apache.spark.ml.param.{IntParam, Params}

/**
 * Mix-in that adds an integer `numLabels` parameter (default 1) together with
 * its getter and setter.
 */
trait HasNumLabels extends Params {

  /** Param named "numLabels". */
  final val numLabels = new IntParam(this, "numLabels", "numLabels")

  setDefault(numLabels, 1)

  /** Returns the explicitly set value of `numLabels`, or its default. */
  final def getNumLabels: Int = getOrDefault(numLabels)

  /** Sets `numLabels` and returns this instance for chaining. */
  final def setNumLabels(value: Int): this.type = {
    set(numLabels, value)
  }
}
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ package com.tencent.angel.pytorch.recommendation
import com.tencent.angel.ml.math2.matrix.CooLongFloatMatrix
import com.tencent.angel.pytorch.data.SampleParser
import com.tencent.angel.pytorch.eval.Evaluation
import com.tencent.angel.pytorch.eval.EvaluationM
import com.tencent.angel.pytorch.model.TorchModelType
import com.tencent.angel.pytorch.optim.AsyncOptim
import com.tencent.angel.pytorch.params._
@@ -34,7 +35,7 @@ import org.apache.spark.sql.{DataFrame, Row}

class Recommendation(torchModelPath: String, val uid: String) extends Serializable
with HasOptimizer with HasAsync with HasNumEpoch with HasBatchSize with HasTestRatio
with HasEvaluation with HasStorageLevel {
with HasEvaluation with HasNumLabels with HasStorageLevel {

def this(torchModelPath: String) = this(torchModelPath, Identifiable.randomUID("Recommendation"))

@@ -98,10 +99,17 @@ class Recommendation(torchModelPath: String, val uid: String) extends Serializab
}
}

def evaluate(model: RecommendPSModel, data: RDD[String]): Map[String, Double] = {
def evaluate(model: RecommendPSModel, data: RDD[String]): Map[String, String] = {
val scores = predict(model, data).map(f => (f._1.toFloat, f._2))
import com.tencent.angel.pytorch.eval.Evaluation._
Evaluation.eval(getEvaluations, scores)
$(numLabels) match {
case 1 =>
import com.tencent.angel.pytorch.eval.Evaluation._
Evaluation.eval(getEvaluations, scores).map(f => (f._1, f._2.toString()))
case _ =>
import com.tencent.angel.pytorch.eval.EvaluationM._
EvaluationM.eval(getEvaluations, scores, $(numLabels))
}

}

def predict(model: RecommendPSModel, data: DataFrame): DataFrame = {
@@ -117,9 +125,9 @@ class Recommendation(torchModelPath: String, val uid: String) extends Serializab
data.sparkSession.createDataFrame(scores, schema)
}

def predict(model: RecommendPSModel, data: RDD[String]): RDD[(String, Float)] = {
def predict(model: RecommendPSModel, data: RDD[String]): RDD[(Float, Float)] = {

def predictPartition(it: Iterator[String]): Iterator[(Array[String], Array[Float])] = {
def predictPartition(it: Iterator[String]): Iterator[(Array[Float], Array[Float])] = {
it.sliding($(batchSize), $(batchSize))
.map(f => predict(f.toArray, model))
}
@@ -216,10 +224,10 @@ class Recommendation(torchModelPath: String, val uid: String) extends Serializab
loss * batchSize
}

def predict(batch: Array[String], model: RecommendPSModel): (Array[String], Array[Float]) = {
def predict(batch: Array[String], model: RecommendPSModel): (Array[Float], Array[Float]) = {
TorchModel.setPath(torchModelPath)
val torch = TorchModel.get()
val tuple3 = SampleParser.parsePredict(batch, torch.getType)
val tuple3 = SampleParser.parse(batch, torch.getType)
val (coo, fields, targets) = (tuple3._1, tuple3._2, tuple3._3)
val output = TorchModelType.withName(torch.getType) match {
case TorchModelType.BIAS_WEIGHT =>
@@ -268,14 +276,15 @@ class Recommendation(torchModelPath: String, val uid: String) extends Serializab
val weightInput = makeWeight(weight, batch)
val embeddingInput = makeEmbedding(embedding, batch.getColIndices, model.getEmbeddingDim)
val matsInput = makeMats(mats)
val multiForwardOut = $(numLabels)
if (fields.isEmpty)
torch.forward(batchSize, batch, biasInput, weightInput,
embeddingInput, model.getEmbeddingDim,
matsInput, torch.getMatsSize)
matsInput, torch.getMatsSize, multiForwardOut)
else
torch.forward(batchSize, batch, biasInput, weightInput,
embeddingInput, model.getEmbeddingDim,
matsInput, torch.getMatsSize, fields.get)
matsInput, torch.getMatsSize, fields.get, multiForwardOut)
}

}
214 changes: 214 additions & 0 deletions python/recommendation/mmoe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Tencent is pleased to support the open source community by making Angel available.
#
# Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/Apache-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
#
# !/usr/bin/env python

import argparse
from typing import List

import torch
from torch import Tensor


class MMoE(torch.nn.Module):
    """Multi-gate Mixture-of-Experts model with one output tower per task.

    Every expert and every tower is a two-layer network whose weights and biases
    are kept in ``self.mats`` in this order: for each expert
    ``[W1, b1, W2, b2]``, then one gate matrix per task, then for each tower
    ``[W1, b1, W2, b2]``. ``parse_mats`` relies on that layout.
    """

    def __init__(self, input_dim=-1, n_fields=-1, embedding_dim=-1, experts_hidden=-1, experts_out=-1, towers_hidden=-1, towers_out=1, num_experts=6, tasks=1):
        super(MMoE, self).__init__()
        # loss func
        self.loss_fn = torch.nn.BCELoss()
        # input params
        self.input_dim = input_dim
        self.n_fields = n_fields
        self.embedding_dim = embedding_dim

        self.experts_out = experts_out
        self.num_experts = num_experts
        self.tasks = tasks

        # Angel Params
        # bias
        self.bias = torch.nn.Parameter(torch.zeros(1, 1))
        # weights
        self.weights = torch.nn.Parameter(torch.zeros(1, 1))
        # embeddings
        self.embedding = torch.nn.Parameter(torch.zeros(embedding_dim))

        # mats
        self.mats = []
        # experts: [W1, b1, W2, b2] per expert
        for i in range(num_experts):
            self.mats.append(torch.nn.Parameter(torch.randn(input_dim, experts_hidden)))
            self.mats.append(torch.nn.Parameter(torch.randn(1, experts_hidden)))
            self.mats.append(torch.nn.Parameter(torch.randn(experts_hidden, experts_out)))
            self.mats.append(torch.nn.Parameter(torch.randn(1, experts_out)))
        # gates: one (input_dim, num_experts) matrix per task
        for i in range(tasks):
            self.mats.append(torch.nn.Parameter(torch.randn(input_dim, num_experts)))
        # towers: [W1, b1, W2, b2] per task
        for i in range(tasks):
            self.mats.append(torch.nn.Parameter(torch.randn(experts_out, towers_hidden)))
            self.mats.append(torch.nn.Parameter(torch.randn(1, towers_hidden)))
            self.mats.append(torch.nn.Parameter(torch.randn(towers_hidden, towers_out)))
            self.mats.append(torch.nn.Parameter(torch.randn(1, towers_out)))

        # init params
        for i in self.mats:
            torch.nn.init.xavier_uniform_(i)

    def parse_mats(self, mats: List[Tensor]):
        """Split the flat ``mats`` list into stacked expert, gate and tower tensors.

        Returns ``(experts_mats, gates_mats, towers_mats)``. ``experts_mats`` and
        ``towers_mats`` each hold four tensors (W1, b1, W2, b2) stacked over
        experts / tasks; ``gates_mats`` holds one tensor stacked over tasks.
        """
        experts_mats = [torch.stack([mats[4 * i + j] for i in range(self.num_experts)]) for j in range(4)]
        gates_mats = [torch.stack([mats[4 * self.num_experts + i] for i in range(self.tasks)])]
        towers_mats = [torch.stack([mats[(4 * self.num_experts + self.tasks) + 4 * i + j] for i in range(self.tasks)]) for j in range(4)]
        return experts_mats, gates_mats, towers_mats

    def experts_module(self, x, experts_mats: List[Tensor]):
        """Apply every expert to the same dense input ``x``."""
        x = torch.relu(torch.baddbmm(experts_mats[1], x.expand(self.num_experts, -1, -1), experts_mats[0]))
        # (num_experts, batch, experts_hidden)
        # F.dropout defaults to training=True; tie it to the module mode so that
        # dropout is disabled (and predictions are deterministic) in eval mode.
        x = torch.nn.functional.dropout(x, p=0.3, training=self.training)
        x = torch.baddbmm(experts_mats[3], x, experts_mats[2])
        # (num_experts, batch, experts_out)
        return x

    def gates_module(self, x, gates_mats: List[Tensor]):
        """Compute, for every task, softmax mixing weights over the experts."""
        x = torch.nn.functional.softmax(torch.bmm(x.expand(self.tasks, -1, -1), gates_mats[0]), dim=2)
        # (tasks, batch, num_experts)
        return x

    def towers_input(self, experts_out, gates_out):
        """Mix the expert outputs with the per-task gate weights."""
        e_o = experts_out.expand(self.tasks, -1, -1, -1)
        # (tasks, num_experts, batch, experts_out)
        g_o = gates_out.unsqueeze(3).expand(-1, -1, -1, self.experts_out)
        # (tasks, batch, num_experts, experts_out)
        g_o = g_o.permute(0, 2, 1, 3)
        # (tasks, num_experts, batch, experts_out)
        x = torch.sum(e_o * g_o, dim=1)
        return x

    def towers_module(self, towers_input, towers_mats: List[Tensor]):
        """Apply each task tower and squash its output with a sigmoid."""
        x = torch.relu(torch.baddbmm(towers_mats[1], towers_input, towers_mats[0]))
        # Same as in experts_module: only drop units while training.
        x = torch.nn.functional.dropout(x, p=0.4, training=self.training)
        x = torch.baddbmm(towers_mats[3], x, towers_mats[2])
        x = torch.sigmoid(x)
        return x

    def forward_(self, batch_size: int, index, feats, values, bias, weights, embeddings, mats: List[Tensor], fields=Tensor([])):
        """Forward pass with all parameters supplied explicitly.

        ``index``, ``feats`` and ``values`` are the COO row indices, column
        indices and values of the batch. ``bias``, ``weights``, ``embeddings``
        and ``fields`` are accepted but not used by this model. Returns a 1-D
        tensor in which the ``tasks`` predictions of each sample are adjacent.
        """
        # parse_mats
        experts_mats, gates_mats, towers_mats = self.parse_mats(mats)

        # sparse_coo_tensor
        indices = torch.stack((index, feats), dim=0)
        sparse_x = torch.sparse_coo_tensor(indices, values, size=torch.Size((batch_size, self.input_dim)))
        # to_dense
        dense_x = sparse_x.to_dense()

        # experts_module
        experts_out = self.experts_module(dense_x, experts_mats)
        # gates_module
        gates_out = self.gates_module(dense_x, gates_mats)
        # towers_input
        towers_input = self.towers_input(experts_out, gates_out)
        # towers_module
        pred = self.towers_module(towers_input, towers_mats)
        # (tasks, batch, 1) -> (batch, tasks) -> flat
        pred = pred.squeeze(2).permute(1, 0).reshape(-1)
        return pred

    def forward(self, batch_size: int, index, feats, values, fields=Tensor([])):
        """Forward pass using the parameters stored on the module."""
        return self.forward_(batch_size, index, feats, values, self.bias, self.weights, self.embedding, self.mats, fields)

    @torch.jit.export
    def loss(self, pred, gt):
        """Binary cross-entropy between ``pred`` and ``gt``, both viewed as (tasks, -1)."""
        pred = pred.view(self.tasks, -1)
        gt = gt.view(self.tasks, -1)
        return self.loss_fn(pred, gt)

    @torch.jit.export
    def get_type(self):
        """Return the Angel parameter layout identifier of this model."""
        return "BIAS_WEIGHT_EMBEDDING_MATS"

    @torch.jit.export
    def get_name(self):
        """Return the model name."""
        return "MMoE"


FLAGS = None


def main():
    """Build an MMoE from the parsed FLAGS, script it and save it to "MMoE.pt"."""
    model_kwargs = dict(
        input_dim=FLAGS.input_dim,
        n_fields=FLAGS.n_fields,
        embedding_dim=FLAGS.embedding_dim,
        experts_hidden=FLAGS.experts_hidden,
        experts_out=FLAGS.experts_out,
        towers_hidden=FLAGS.towers_hidden,
        towers_out=1,
        num_experts=FLAGS.num_experts,
        tasks=FLAGS.tasks,
    )
    scripted = torch.jit.script(MMoE(**model_kwargs))
    scripted.save("MMoE.pt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # (flag, default, help) for every integer command-line option
    int_options = (
        ("--input_dim", 148, "data input dim."),
        ("--n_fields", -1, "data num fields."),
        ("--embedding_dim", 1, "embedding dim."),
        ("--experts_hidden", 32, "experts hidden."),
        ("--experts_out", 16, "experts out."),
        ("--towers_hidden", 8, "towers hidden."),
        ("--num_experts", 6, "num experts."),
        ("--tasks", 2, "num tasks."),
    )
    for option, default_value, help_text in int_options:
        parser.add_argument(
            option,
            type=int,
            default=default_value,
            help=help_text,
        )

    # unknown arguments are tolerated and returned separately
    FLAGS, unparsed = parser.parse_known_args()
    main()