Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Rust CI

on:
push:
pull_request:
branches: [ main ]

env:
CARGO_TERM_COLOR: always

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Build
run: cargo build --verbose
# - name: Run tests
# run: cargo test --verbose
24 changes: 12 additions & 12 deletions src/bayesian_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ use serde_json::{Result, Value};
/// this index comes from the order given in `states`
/// - columns are indexed by the parent's possible values
/// this index comes from the order given in the `parents` table
///
///
/// Example:
///
///
/// Suppose we have a Bayesian network (a) -> (c) <- (b)
/// - state: {"a" -> ["a1", "a2"], "b" -> ["b1", "b2", "b3"]}
/// - parents: {"a" -> [], "b" -> [], "c" -> ["a", "b"]}
/// - cpts: ```{"a" -> [[0.1], [0.9]], // says "a" has a prior probability of value "a1" with prob 0.1
/// "b" -> [[0.3], [0.2], [0.5]],
/// "c" -> [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
/// [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]]}```
/// "b" -> [[0.3], [0.2], [0.5]],
/// "c" -> [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
/// [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]]}```
/// says Pr(c=c1 | a=a1, b=b1) = 0.1
/// Pr(c=c1 | a=a2, b=b1) = 0.4
/// Pr(c=c1 | a=a1, b=b3) = 0.3
Expand All @@ -33,7 +33,7 @@ type parents = HashMap<String, Vec<String>>;
pub struct BayesianNetwork {
network: String,
variables: Vec<String>,
cpts: cpt,
cpts: cpt,
states: states,
parents: parents
}
Expand All @@ -44,12 +44,12 @@ impl BayesianNetwork {
}

fn get_state_index(&self, variable: &String, assignment: &String) -> usize {
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {variable}"));
cur_s.into_iter().position(|x| *x == *assignment).unwrap_or_else(|| panic!("could not find assignment {assignment} for variable {variable}"))
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {}", variable));
cur_s.into_iter().position(|x| *x == *assignment).unwrap_or_else(|| panic!("could not find assignment {} for variable {}", assignment, variable))
}

fn get_num_states(&self, variable: &String) -> usize {
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {variable}"));
let cur_s = self.states.get(variable).unwrap_or_else(|| panic!("could not find variable {}", variable));
cur_s.len()
}

Expand All @@ -67,7 +67,7 @@ impl BayesianNetwork {
let p = cur_parents.pop().unwrap();
let cur_values = self.get_all_assignments(&p);
let sub = self.parent_h(cur_parents);

// add each assignment onto
for v in cur_values {
for i in 0..(sub.len()) {
Expand Down Expand Up @@ -142,7 +142,7 @@ impl BayesianNetwork {
fn test_conditional() {
let sachs = include_str!("../bayesian_networks/sachs.json");
let network = BayesianNetwork::from_string(&sachs);
let parent_assgn = HashMap::from([ (String::from("Erk"), String::from("HIGH")),
let parent_assgn = HashMap::from([ (String::from("Erk"), String::from("HIGH")),
(String::from("PKA"), String::from("AVG")) ]);
assert_eq!(network.get_conditional_prob(&String::from("Akt"), &String::from("LOW"), &parent_assgn),0.177105936);
}
Expand All @@ -153,4 +153,4 @@ fn test_parent() {
let network = BayesianNetwork::from_string(&sachs);
println!("{:?}", network.parent_assignments(&String::from("Erk")));
assert!(false);
}
}