# 行列計算でクロス集計をする (Rust)
[hibi-myzk/rust-crosstab](https://github.com/hibi-myzk/rust-crosstab)
[[行列計算でクロス集計をする]] の [[Rust]] 実装。
Cargo.toml
```toml
[dependencies.ndarray]
version = "0.15"
features = ["serde-1"]
```
main.rs
```rust
extern crate ndarray;
use ndarray::{arr2, Array2};
fn main() {
/* 生データ */
// 3 人分の属性データ(性別、年代)
// API から取得
let attr: Array2<i32> = arr2(&[
// [男性, 女性, 20 代, 30 代, 40 代]
[1, 0, 1, 0, 0], // 男性, 20 代
[0, 1, 1, 0, 0], // 女性, 20 代
[0, 1, 0, 1, 0], // 女性, 30 代
]);
// 3 人分の回答データ(10 問)
// API から取得
let data: Array2<i32> = arr2(&[
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代
[3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代
]);
/* クロス集計 */
println!("# Crosstab\n");
// 二つを比較したい
let cond: Array2<i32> = arr2(&[
[0, 0, 1, 0, 0], // 20 代(性別は問わない)
[0, 1, 1, 0, 0], // 20 代女性
]);
// 指定した属性数
// 1 .. 年代のみ, 2 .. 年代と性別
let cond_count = vec![1, 2];
let crosstab_result = calc_crosstab(&attr, &data, &cond, &cond_count);
if let Ok(result) = crosstab_result {
println!("Crosstab Result Matrix:\n{}", result);
}
/*
Crosstab Result Matrix:
[[6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2]] // 20 代女性
*/
/* 全属性ごとの合計スコア */
println!("\n# All Score\n");
// コードで作成
// 2 x 3 = 6 行(性別: 2, 年代: 3)
let all_attr = arr2(&[
[1, 1, 1, 1, 1], // 全体
[1, 0, 0, 0, 0], // 男性
[0, 1, 0, 0, 0], // 女性
[0, 0, 1, 0, 0], // 20 代
[0, 0, 0, 1, 0], // 30 代
[0, 0, 0, 0, 1], // 40 代
]);
// 指定した属性数
// 1 .. 年代のみ or 性別のみ, 2 .. 年代と性別
let all_cond_count = vec![2, 1, 1, 1, 1, 1];
let all_result = calc_crosstab(&attr, &data, &all_attr, &all_cond_count);
if let Ok(result) = all_result {
println!("All Result Matrix:\n{}", result);
}
/*
All Result Matrix:
[[9, 12, 9, 12, 9, 12, 9, 12, 9, 12], // 全体
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性
[8, 10, 6, 8, 4, 6, 8, 10, 6, 8], // 女性
[6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代
[3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 30 代
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] // 40 代
*/
/* 平均スコア */
println!("\n# Average Score\n");
let data_for_average: Array2<i32> = arr2(&[
[1, 2, 3, 4, 5, 6, 1, 2, 3, 0], // 男性, 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 0, 0], // 女性, 20 代
[3, 4, 5, 6, 1, 2, 3, 0, 0, 0], // 女性, 30 代
]);
let average_result = calc_crosstab_average(&attr, &data_for_average, &all_attr, &all_cond_count);
if let Ok(result) = average_result {
println!("All Average Result Matrix:\n{}", result);
}
/*
All Average Result Matrix:
[[3, 4, 3, 4, 3, 4, 3, 4, 3, NaN],
[1, 2, 3, 4, 5, 6, 1, 2, 3, NaN],
[4, 5, 3, 4, 2, 3, 4, 6, NaN, NaN],
[3, 4, 2, 3, 4, 5, 3, 4, 3, NaN],
[3, 4, 5, 6, 1, 2, 3, NaN, NaN, NaN],
[NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN]]
*/
}
```
```rust
// 符号関数
fn sign(x: i32) -> i32 {
match x {
x if x > 0 => 1,
x if x < 0 => -1,
_ => 0,
}
}
// 活性化関数
fn relu(x: i32) -> i32 {
x.max(0)
}
```
```rust
/// クロス集計を行う関数
///
/// # Arguments
/// * `attr` - 属性データ
/// * `data` - 回答データ
/// * `cond` - 集計条件
/// * `cond_count` - 集計条件に含まれる属性数
///
/// # Returns
/// * クロス集計結果の行列とマスク行列
fn calc_crosstab(
attr: &Array2<i32>,
data: &Array2<i32>,
cond: &Array2<i32>,
cond_count: &Vec<i32>
) -> Result<(Array2<i32>, Array2<i32>), String> {
// 入力値の検証
if attr.shape()[1] != cond.shape()[1] {
println!("attr.shape()[1]: {}", attr.shape()[1]);
println!("cond.shape()[1]: {}", cond.shape()[1]);
return Err("Invalid attribute matrix shape".to_string());
}
if attr.shape()[0] != data.shape()[0] {
println!("attr.shape()[0]: {}", attr.shape()[0]);
println!("data.shape()[0]: {}", data.shape()[0]);
return Err("Invalid data matrix shape".to_string());
}
let n = data.shape()[0]; // 回答者数
let ones = Array2::<i32>::ones((1, n));
println!("Ones Matrix:\n{}", ones);
/*
Ones Matrix:
[[1, 1, 1]]
*/
// cond と attr の転置行列の積
let mask_base = cond.dot(&attr.t());
println!("Mask Base Matrix:\n{}", mask_base);
/*
Mask Base Matrix:
[[1, 1, 0],
[1, 2, 1]]
*/
let cond_count_matrix = Array2::<i32>::from_shape_vec((cond_count.len(), 1), cond_count.clone()).unwrap();
let cond_ones = Array2::<i32>::ones((1, attr.shape()[0]));
let cond_diff = cond_count_matrix.dot(&cond_ones) - cond_ones;
println!("Cond Diff Matrix:\n{}", cond_diff);
/*
Cond Diff Matrix:
[[0, 0, 0],
[1, 1, 1]]
*/
let mask = (mask_base - cond_diff).map(|x| relu(*x));
println!("Mask Matrix:\n{}", mask);
/*
Mask Matrix:
[[1, 1, 0],
[0, 1, 0]]
*/
println!("Count:\n{}", ones.dot(&mask.t()));
/*
Count:
[[2, 1]]
20 代: 2 人
20 代女性: 1 人
*/
// 該当する回答者のスコアの合計
let result = mask.dot(data);
println!("Result Matrix:\n{}", result);
/*
Result Matrix:
[[6, 8, 4, 6, 8, 10, 6, 8, 4, 6], // 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2]] // 20 代女性
*/
Ok((result, mask))
}
```
```rust
fn calc_crosstab_average(
attr: &Array2<i32>,
data: &Array2<i32>,
cond: &Array2<i32>,
cond_count: &Vec<i32>
) -> Result<Array2<f32>, String> {
let (sum, mask) = calc_crosstab(attr, data, cond, cond_count)?;
println!("Sum:\n{}", sum);
println!("Mask:\n{}", mask);
let signed_data = data.map(|x| sign(*x));
let counts = mask.dot(&signed_data);
println!("Counts:\n{}", counts);
// 有効回答数で割って平均を計算
let average = sum.mapv(|x| x as f32) / counts.mapv(|x| x as f32);
println!("Average:\n{}", average);
Ok(average)
}
```
```rust
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_calc_crosstab() {
// テストケースを追加
let attr = arr2(&[
[1, 0, 1, 0, 0], // 男性, 20 代
[0, 1, 1, 0, 0], // 女性, 20 代
[0, 1, 0, 1, 0], // 女性, 30 代
]);
let data = arr2(&[
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代
[3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代
]);
let cond = arr2(&[
[0, 0, 1, 0, 0], // 20 代(性別は問わない)
[0, 1, 1, 0, 0], // 20 代女性
]);
let cond_count = vec![1, 2];
let (_, mask) = calc_crosstab(&attr, &data, &cond, &cond_count).unwrap();
let ones = Array2::<i32>::ones((1, data.shape()[0]));
let count = ones.dot(&mask.t());
assert_eq!(count, arr2(&[[2, 1]]));
}
#[test]
fn test_calc_crosstab_with_more_cond() {
// テストケースを追加
let attr = arr2(&[
[1, 0, 1, 0, 0], // 男性, 20 代
[0, 1, 1, 0, 0], // 女性, 20 代
[0, 1, 0, 1, 0], // 女性, 30 代
]);
let data = arr2(&[
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代
[3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代
]);
let cond = arr2(&[
[1, 0, 1, 0, 0], // 20 代男性
[0, 1, 0, 0, 1], // 40 代女性
]);
let cond_count = vec![2, 2];
let (_, mask) = calc_crosstab(&attr, &data, &cond, &cond_count).unwrap();
let ones = Array2::<i32>::ones((1, data.shape()[0]));
let count = ones.dot(&mask.t());
assert_eq!(count, arr2(&[[1, 0]]));
}
#[test]
fn test_invalid_input() {
// エラーケースのテスト
let attr = arr2(&[
// 行が少ない
[1, 0, 1, 0, 0, 0], // 列が多い
]);
let data = arr2(&[
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4], // 男性, 20 代
[5, 6, 1, 2, 3, 4, 5, 6, 1, 2], // 女性, 20 代
[3, 4, 5, 6, 1, 2, 3, 4, 5, 6], // 女性, 30 代
]);
let cond = arr2(&[
[0, 0, 1, 0, 0], // 20 代(性別は問わない)
[0, 1, 1, 0, 0], // 20 代女性
]);
let cond_count = vec![1, 2];
let result = calc_crosstab(&attr, &data, &cond, &cond_count);
assert!(result.is_err());
}
}
```