# 行列計算でクロス集計をする (Rust) [hibi-myzk/rust-crosstab](https://github.com/hibi-myzk/rust-crosstab) [[行列計算でクロス集計をする]] の [[Rust]] 実装。 Cargo.toml ```toml [dependencies.ndarray] version = "0.15" features = ["serde-1"] ``` main.rs ```rust extern crate ndarray; use ndarray::{arr2, Array2}; fn main() { /* 生データ */ // 3 人分の属性データ(性別、年代) // API から取得 let attr: Array2<i32> = arr2(&[ // [男性, 女性, 20 代, 30 代, 40 代] [1, 0, 1, 0, 0], // 男性, 20 代 [0, 1, 1, 0, 0], // 女性, 20 代 [0, 1, 0, 1, 0], // 女性, 30 代 ]); // 3 人分の回答データ(10 問) // API から取得 let data: Array2<i32> = arr2(&[ [1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代 [3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代 ]); /* クロス集計 */ println!("# Crosstab\n"); // 二つを比較したい let cond: Array2<i32> = arr2(&[ [0, 0, 1, 0, 0], // 20 代(性別は問わない) [0, 1, 1, 0, 0], // 20 代女性 ]); // 指定した属性数 // 1 .. 年代のみ, 2 .. 年代と性別 let cond_count = vec![1, 2]; let crosstab_result = calc_crosstab(&attr, &data, &cond, &cond_count); if let Ok(result) = crosstab_result { println!("Crosstab Result Matrix:\n{}", result); } /* Crosstab Result Matrix: [[6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2]] // 20 代女性 */ /* 全属性ごとの合計スコア */ println!("\n# All Score\n"); // コードで作成 // 2 x 3 = 6 行(性別: 2, 年代: 3) let all_attr = arr2(&[ [1, 1, 1, 1, 1], // 全体 [1, 0, 0, 0, 0], // 男性 [0, 1, 0, 0, 0], // 女性 [0, 0, 1, 0, 0], // 20 代 [0, 0, 0, 1, 0], // 30 代 [0, 0, 0, 0, 1], // 40 代 ]); // 指定した属性数 // 1 .. 年代のみ or 性別のみ, 2 .. 年代と性別 let all_cond_count = vec![2, 1, 1, 1, 1, 1]; let all_result = calc_crosstab(&attr, &data, &all_attr, &all_cond_count); if let Ok(result) = all_result { println!("All Result Matrix:\n{}", result); } /* All Result Matrix: [[9, 12, 9, 12, 9, 12, 9, 12, 9, 12], // 全体 [1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性 [8, 10, 6, 8, 4, 6, 8, 10, 6, 8], // 女性 [6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代 [3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 30 代 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] // 40 代 */ /* 平均スコア */ println!("\n# Average Score\n"); let data_for_average: Array2<i32> = arr2(&[ [1, 2, 3, 4, 5, 6, 1, 2, 3, 0], // 男性, 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 0, 0], // 女性, 20 代 [3, 4, 5, 6, 1, 2, 3, 0, 0, 0], // 女性, 30 代 ]); let average_result = calc_crosstab_average(&attr, &data_for_average, &all_attr, &all_cond_count); if let Ok(result) = average_result { println!("All Average Result Matrix:\n{}", result); } /* All Average Result Matrix: [[3, 4, 3, 4, 3, 4, 3, 4, 3, NaN], [1, 2, 3, 4, 5, 6, 1, 2, 3, NaN], [4, 5, 3, 4, 2, 3, 4, 6, NaN, NaN], [3, 4, 2, 3, 4, 5, 3, 4, 3, NaN], [3, 4, 5, 6, 1, 2, 3, NaN, NaN, NaN], [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]] */ } ``` ```rust // 符号関数 fn sign(x: i32) -> i32 { match x { x if x > 0 => 1, x if x < 0 => -1, _ => 0, } } // 活性化関数 fn relu(x: i32) -> i32 { x.max(0) } ``` ```rust /// クロス集計を行う関数 /// /// # Arguments /// * `attr` - 属性データ /// * `data` - 回答データ /// * `cond` - 集計条件 /// * `cond_count` - 集計条件に含まれる属性数 /// /// # Returns /// * クロス集計結果の行列とマスク行列 fn calc_crosstab( attr: &Array2<i32>, data: &Array2<i32>, cond: &Array2<i32>, cond_count: &Vec<i32> ) -> Result<(Array2<i32>, Array2<i32>), String> { // 入力値の検証 if attr.shape()[1] != cond.shape()[1] { println!("attr.shape()[1]: {}", attr.shape()[1]); println!("cond.shape()[1]: {}", cond.shape()[1]); return Err("Invalid attribute matrix shape".to_string()); } if attr.shape()[0] != data.shape()[0] { println!("attr.shape()[0]: {}", attr.shape()[0]); println!("data.shape()[0]: {}", data.shape()[0]); return Err("Invalid data matrix shape".to_string()); } let n = data.shape()[0]; // 回答者数 let ones = Array2::<i32>::ones((1, n)); println!("Ones Matrix:\n{}", ones); /* Ones Matrix: [[1, 1, 1]] */ // cond と attr の転置行列の積 let mask_base = cond.dot(&attr.t()); println!("Mask Base Matrix:\n{}", mask_base); /* Mask Base Matrix: [[1, 1, 0], [1, 2, 1]] */ let cond_count_matrix = Array2::<i32>::from_shape_vec((cond_count.len(), 1), cond_count.clone()).unwrap(); let cond_ones = Array2::<i32>::ones((1, attr.shape()[0])); let cond_diff = cond_count_matrix.dot(&cond_ones) - cond_ones; println!("Cond Diff Matrix:\n{}", cond_diff); /* Cond Diff Matrix: [[0, 0, 0], [1, 1, 1]] */ let mask = (mask_base - cond_diff).map(|x| relu(*x)); println!("Mask Matrix:\n{}", mask); /* Mask Matrix: [[1, 1, 0], [0, 1, 0]] */ println!("Count:\n{}", ones.dot(&mask.t())); /* Count: [[2, 1]] 20 代: 2 人 20 代女性: 1 人 */ // 該当する回答者のスコアの合計 let result = mask.dot(data); println!("Result Matrix:\n{}", result); /* Result Matrix: [[6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2]] // 20 代女性 */ Ok((result, mask)) } ``` ```rust fn calc_crosstab_average( attr: &Array2<i32>, data: &Array2<i32>, cond: &Array2<i32>, cond_count: &Vec<i32> ) -> Result<Array2<f32>, String> { let (sum, mask) = calc_crosstab(attr, data, cond, cond_count)?; println!("Sum:\n{}", sum); println!("Mask:\n{}", mask); let signed_data = data.map(|x| sign(*x)); let counts = mask.dot(&signed_data); println!("Counts:\n{}", counts); // 有効回答数で割って平均を計算 let average = sum.mapv(|x| x as f32) / counts.mapv(|x| x as f32); println!("Average:\n{}", average); Ok(average) } ``` ```rust #[cfg(test)] mod tests { use super::*; #[test] fn test_calc_crosstab() { // テストケースを追加 let attr = arr2(&[ [1, 0, 1, 0, 0], // 男性, 20 代 [0, 1, 1, 0, 0], // 女性, 20 代 [0, 1, 0, 1, 0], // 女性, 30 代 ]); let data = arr2(&[ [1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代 [3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代 ]); let cond = arr2(&[ [0, 0, 1, 0, 0], // 20 代(性別は問わない) [0, 1, 1, 0, 0], // 20 代女性 ]); let cond_count = vec![1, 2]; let (_, mask) = calc_crosstab(&attr, &data, &cond, &cond_count).unwrap(); let ones = Array2::<i32>::ones((1, data.shape()[0])); let count = ones.dot(&mask.t()); assert_eq!(count, arr2(&[[2, 1]])); } #[test] fn test_calc_crosstab_with_more_cond() { // テストケースを追加 let attr = arr2(&[ [1, 0, 1, 0, 0], // 男性, 20 代 [0, 1, 1, 0, 0], // 女性, 20 代 [0, 1, 0, 1, 0], // 女性, 30 代 ]); let data = arr2(&[ [1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代 [3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代 ]); let cond = arr2(&[ [1, 0, 1, 0, 0], // 20 代男性 [0, 1, 0, 0, 1], // 40 代女性 ]); let cond_count = vec![2, 2]; let (_, mask) = calc_crosstab(&attr, &data, &cond, &cond_count).unwrap(); let ones = Array2::<i32>::ones((1, data.shape()[0])); let count = ones.dot(&mask.t()); assert_eq!(count, arr2(&[[1, 0]])); } #[test] fn test_invalid_input() { // エラーケースのテスト let attr = arr2(&[ // 行が少ない [1, 0, 1, 0, 0, 0], // 列が多い ]); let data = arr2(&[ [1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代 [5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代 [3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代 ]); let cond = arr2(&[ [0, 0, 1, 0, 0], // 20 代(性別は問わない) [0, 1, 1, 0, 0], // 20 代女性 ]); let cond_count = vec![1, 2]; let result = calc_crosstab(&attr, &data, &cond, &cond_count); assert!(result.is_err()); } } ```