お手伝いの yuki です。前回に引き続き、小ネタを紹介します。
みなさんは Rust でのバリデーションチェック時、どうしていますか?
まず考えられるのは自前実装です。手でバリデーションチェックをひたすら書く方法がまず考えられると思います。あるいは、私のブログで以前紹介したことがあるのですが、バリデーションチェック情報を型に落として実装し、トレイトなどを駆使して若干程度半自動化するといった方法をとることもできるでしょう。
あるいは多くの方は validator クレートを利用しているかと思います。私も業務でこのクレートを利用していますが、バリデーションチェックの情報をアノテーションに残しておくことができるため、どのフィールドにどのバリデーションチェックがかかる予定かがわかりやすく気に入っています。
その他の選択肢として最近出てきたのですが、Nutype というクレートがあります。このクレートについて今日は紹介したいと思います。
New Type Pattern
その前に前提知識としてひとつ Rust の実装パターンについて知っておく必要があります。というのもこのクレートはこの New Type という実装パターンを利用することを前提としているためです。Nutype (ニュータイプ)という音がまるっきり同じなことから想像がついたかたもいらっしゃるかもしれません。
Rust の New Type パターンはいわゆるオブジェクト指向における軽めのカプセル化です。型の内部の実装詳細を隠匿できるため、メンテナンス性の向上にも一役買うでしょう。あるいはたとえば関数の引数の型を厳しめに設定できるようになるため、引数の型の入れ違えをコンパイル時に未然に発見し防ぐことができるようになります。
そしてこの New Type Pattern を用いた型付けに、追加でマクロを使ってバリデーションチェックをかけられるようにするのが、今回紹介する Nutype というクレートです。名前の音がまったく同じですね。
Nutype
Nutype はアトリビュート部分に特定の記述をすることにより、その型の受け付けられる値の条件をセットすることできます。たとえば整数値には最大値と最小値をセットすることができます。浮動小数点数では、NaN
や無限の値を弾けます。また、文字列は最大文字数やそもそも空文字を受け付けなくすることができます。正規表現をセットすることも可能です。
加えて、バリデーションチェック後に返されるエラーの型も、簡易的ではありますが自動生成されます。ユーザーはとくに自分自身でエラーの型を定義する必要がない、ということです。Nutype はそもそも型付けを強くしましょう、そのために New Type の実装パターンを活用していきましょうという思想があるようです。エラー型が詳しく生成されるのもこの事情によっているのではないかと思います。
サンプルコード
それでは簡単にですがサンプルを実装してみましょう。今回は株式の情報を集めるアプリを題材として適当なデータモデルを設計しておき、各フィールドに対して必要に応じて New Type パターンを用いながら、強く型付けしていくことを目指します。
use std::marker::PhantomData; use currency::Currency; use nutype::nutype; #[derive(Debug, Clone)] pub struct Stock<C> where C: Currency, { stock_symbol: TickerSymbol, name: String, last_sale: Price<C>, change_rate: Percent, market_cap: u128, } #[nutype( sanitize(trim, uppercase) validate(not_empty, max_len = 5) )] #[derive(*)] pub struct TickerSymbol(String); mod currency { pub struct Usd; impl Currency for Usd {} pub struct Jpy; impl Currency for Jpy {} pub trait Currency {} } #[derive(Debug, Clone)] pub struct Price<C: Currency>(f32, PhantomData<C>); #[nutype(validate(finite, min = -100.0, max = 100.0))] #[derive(*)] pub struct Percent(f32); fn main() {} #[test] fn ticker_symbol() { let ticker_symbol = "GOOGL"; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); let ticker_symbol = "GOOG"; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); // change the string to upper case let ticker_symbol = "goog"; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); } #[test] fn sanitize_ticker_symbol() { let ticker_symbol = " GOOG "; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); let ticker_symbol = "goog"; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); } #[test] fn passed_empty_string_to_symbol() { let empty_symbol = "".to_string(); let symbol = TickerSymbol::new(empty_symbol); claim::assert_err!(symbol.clone()); claim::assert_matches!(symbol, Err(TickerSymbolError::Empty)); } #[test] fn passed_invalid_symbol() { let not_exist_such_symbol = "NONEXIST".to_string(); let symbol = TickerSymbol::new(not_exist_such_symbol); claim::assert_err!(symbol.clone()); claim::assert_matches!(symbol, Err(TickerSymbolError::TooLong)); } #[test] fn correct_percent() { let num = 0.589; let percent = Percent::new(num); claim::assert_ok!(percent); let num = -1.187; let percent = Percent::new(num); claim::assert_ok!(percent); } #[test] fn passed_invalid_percent() { let over_hundred = 101.0; let percent = Percent::new(over_hundred); claim::assert_err!(percent.clone()); claim::assert_matches!(percent, Err(PercentError::TooBig)); let nan = f32::NAN; let percent = Percent::new(nan); claim::assert_err!(percent.clone()); claim::assert_matches!(percent, Err(PercentError::NotFinite)); }
まず株式銘柄を示す Stock
という構造体です。この構造体にはシンボル、銘柄名、終値、前日と比較した騰落率、最後に時価総額のフィールドを持っています。
#[derive(Debug, Clone)] pub struct Stock<C> where C: Currency, { stock_symbol: TickerSymbol, name: String, last_sale: Price<C>, change_rate: Percent, market_cap: u128, }
サニタイズとバリデーションチェック
TickerSymbol
からみていきましょう。今回はここに Nutype クレートのアトリビュートをいくつかセットしました。Nutype ではこのように、#[nutype]
というアトリビュートに値のチェック情報を付与します。
#[nutype( sanitize(trim, uppercase) validate(not_empty, max_len = 5) )] #[derive(*)] pub struct TickerSymbol(String);
Nutype の値チェックはふたつにわかれます。サニタイズ(sanitize)とバリデーション(validate)です。
サニタイズはバリデーションチェックの前段階で、エラーとして判定はしないもののかけておいた方がいいフィルタリング処理をかけておくものです。サニタイズは複数設定できます。
たとえば、sanitize(trim, uppercase)
のように指定しておくと、" a "
という文字列が来た際に空白部分をトリミングしたり、goog
という文字列が来た際に GOOG
のような大文字化したりすることができます。
下記は先ほどのコードから該当箇所のテストコードを抜粋したものです。
#[test] fn sanitize_ticker_symbol() { let ticker_symbol = " GOOG "; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); let ticker_symbol = "goog"; let symbol = TickerSymbol::new(ticker_symbol); claim::assert_ok!(symbol); }
new
という関数はとくに定義していないのですが、これも Nutype が裏で生成します。このコンストラクタの中でバリデーションチェックが行われます。new
関数の返りの型は Result
型になります。
バリデーションはイメージする通りのバリデーションで、ここを通らないとエラーとして判定される類のものです。バリデーションチェックはサニタイズと同様複数設定できます。
たとえば、validate(not_empty, max_len = 5)
のように指定しておくと、空白文字が来た場合や文字数が6文字以上のものがきた場合はエラーとして判定できるようになります。
また先ほども説明した通り、エラーの型が裏で自動生成されます。下記は先ほどのコードから該当箇所のテストコードを抜粋したものです。
#[test] fn passed_empty_string_to_symbol() { let empty_symbol = "".to_string(); let symbol = TickerSymbol::new(empty_symbol); claim::assert_err!(symbol.clone()); claim::assert_matches!(symbol, Err(TickerSymbolError::Empty)); } #[test] fn passed_invalid_symbol() { let not_exist_such_symbol = "NONEXIST".to_string(); let symbol = TickerSymbol::new(not_exist_such_symbol); claim::assert_err!(symbol.clone()); claim::assert_matches!(symbol, Err(TickerSymbolError::TooLong)); }
浮動小数点数の方はちょっとおもしろいです。Percent
という構造体を見てみます。
#[nutype(validate(finite, min = -100.0, max = 100.0))] #[derive(*)] pub struct Percent(f32);
浮動小数点数は NaN
や NegInifinity
の扱いが少し悩ましいことが多いです。NaN
については、たとえば2つの値が等しいかをチェックしたとしても直感に反する結果になることが多いです。f32
型には NaN
を表現するビットパターンが222パターンほど存在します。これのせいで PartialEq
を付与できないなどの問題がありました。が、Nutype ではそれらを受け付けしないようにできます。finite
という設定がそれです。これをつけると、NaN
がこないことを実質的に保証できるようになります。
下記はこのことを示してみたテストコードです。
#[test] fn passed_invalid_percent() { let over_hundred = 101.0; let percent = Percent::new(over_hundred); claim::assert_err!(percent.clone()); claim::assert_matches!(percent, Err(PercentError::TooBig)); let nan = f32::NAN; let percent = Percent::new(nan); claim::assert_err!(percent.clone()); claim::assert_matches!(percent, Err(PercentError::NotFinite)); }
finite
のおかげで、PartialEq
や PartialOrd
などが付与できるようになります。
// たとえば、下記がコンパイルエラーにならない #[nutype(validate(finite, min = -100.0, max = 100.0))] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Percent(f32);
#[derive(*)]
最後にもうひとつ便利なものを紹介します。#[derive(*)]
というアトリビュートがたくさんでてきたのが気になった読者の方がいるかもしれません。
これは Debug
や Clone
などよく利用するトレイト実装を *
と指定するだけでまとめて生やしてくれる便利なアトリビュートです(余計なものまで生やすので意図しない挙動をしないか注意する必要がありますが)。生やされるトレイトはその型がラップする型に応じて変わりますが、たとえば String を内包する TickerSymbol の場合は、
- Eq
- PartialEq
- Ord
- PartialOrd
- Hash
- Clone
- Debug
あたりを自動でつけてくれます。f32 などのコピーセマンティクスな型を内包する Percent などにはさらに Copy
を自動で付与しています。
仕組み
最後にこのクレートがどのような実装によって実現されているかですが、非常に単純でマクロを使ってひたすら必要な情報を生やしているだけです。cargo expand
などの補助コマンドでマクロを展開するとよくわかります。
下記は展開してみたコードを貼ったものです。__nutype_private_XXX__
というモジュール内に必要なトレイトの実装と、生成されるエラー用の型があることがわかります。
#![feature(prelude_import)] #[prelude_import] use std::prelude::rust_2021::*; #[macro_use] extern crate std; use std::marker::PhantomData; use currency::Currency; use nutype::nutype; pub struct Stock<C> where C: Currency, { stock_symbol: TickerSymbol, name: String, last_sale: Price<C>, change_rate: Percent, market_cap: u128, } #[automatically_derived] impl<C: ::core::fmt::Debug> ::core::fmt::Debug for Stock<C> where C: Currency, { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::debug_struct_field5_finish( f, "Stock", "stock_symbol", &self.stock_symbol, "name", &self.name, "last_sale", &self.last_sale, "change_rate", &self.change_rate, "market_cap", &&self.market_cap, ) } } #[automatically_derived] impl<C: ::core::clone::Clone> ::core::clone::Clone for Stock<C> where C: Currency, { #[inline] fn clone(&self) -> Stock<C> { Stock { stock_symbol: ::core::clone::Clone::clone(&self.stock_symbol), name: ::core::clone::Clone::clone(&self.name), last_sale: ::core::clone::Clone::clone(&self.last_sale), change_rate: ::core::clone::Clone::clone(&self.change_rate), market_cap: ::core::clone::Clone::clone(&self.market_cap), } } } #[doc(hidden)] mod __nutype_private_TickerSymbol__ { use super::*; pub struct TickerSymbol(String); #[automatically_derived] impl ::core::marker::StructuralEq for TickerSymbol {} #[automatically_derived] impl ::core::cmp::Eq for TickerSymbol { #[inline] #[doc(hidden)] #[no_coverage] fn assert_receiver_is_total_eq(&self) -> () { let _: ::core::cmp::AssertParamIsEq<String>; } } #[automatically_derived] impl ::core::fmt::Debug for TickerSymbol { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::debug_tuple_field1_finish( f, "TickerSymbol", &&self.0, ) } } #[automatically_derived] impl ::core::cmp::Ord for TickerSymbol { #[inline] fn cmp(&self, other: &TickerSymbol) -> ::core::cmp::Ordering { ::core::cmp::Ord::cmp(&self.0, &other.0) } } #[automatically_derived] impl ::core::cmp::PartialOrd for TickerSymbol { #[inline] fn partial_cmp( &self, other: &TickerSymbol, ) -> ::core::option::Option<::core::cmp::Ordering> { ::core::cmp::PartialOrd::partial_cmp(&self.0, &other.0) } } #[automatically_derived] impl ::core::hash::Hash for TickerSymbol { #[inline] fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () { ::core::hash::Hash::hash(&self.0, state) } } #[automatically_derived] impl ::core::clone::Clone for TickerSymbol { #[inline] fn clone(&self) -> TickerSymbol { TickerSymbol(::core::clone::Clone::clone(&self.0)) } } #[automatically_derived] impl ::core::marker::StructuralPartialEq for TickerSymbol {} #[automatically_derived] impl ::core::cmp::PartialEq for TickerSymbol { #[inline] fn eq(&self, other: &TickerSymbol) -> bool { self.0 == other.0 } } pub enum TickerSymbolError { Empty, TooLong, } #[automatically_derived] impl ::core::fmt::Debug for TickerSymbolError { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::write_str( f, match self { TickerSymbolError::Empty => "Empty", TickerSymbolError::TooLong => "TooLong", }, ) } } #[automatically_derived] impl ::core::clone::Clone for TickerSymbolError { #[inline] fn clone(&self) -> TickerSymbolError { match self { TickerSymbolError::Empty => TickerSymbolError::Empty, TickerSymbolError::TooLong => TickerSymbolError::TooLong, } } } #[automatically_derived] impl ::core::marker::StructuralPartialEq for TickerSymbolError {} #[automatically_derived] impl ::core::cmp::PartialEq for TickerSymbolError { #[inline] fn eq(&self, other: &TickerSymbolError) -> bool { let __self_tag = ::core::intrinsics::discriminant_value(self); let __arg1_tag = ::core::intrinsics::discriminant_value(other); __self_tag == __arg1_tag } } #[automatically_derived] impl ::core::marker::StructuralEq for TickerSymbolError {} #[automatically_derived] impl ::core::cmp::Eq for TickerSymbolError { #[inline] #[doc(hidden)] #[no_coverage] fn assert_receiver_is_total_eq(&self) -> () {} } impl ::core::fmt::Display for TickerSymbolError { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match self { TickerSymbolError::Empty => f.write_fmt(format_args!("empty")), TickerSymbolError::TooLong => f.write_fmt(format_args!("too long")), } } } impl ::std::error::Error for TickerSymbolError { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { None } } impl TickerSymbol { pub fn new( raw_value: impl Into<String>, ) -> ::core::result::Result<Self, TickerSymbolError> { fn sanitize(value: String) -> String { let value: String = value.trim().to_string(); let value: String = value.to_uppercase(); value } fn validate(val: &str) -> ::core::result::Result<(), TickerSymbolError> { let chars_count = val.chars().count(); if val.is_empty() { return Err(TickerSymbolError::Empty); } if chars_count > 5usize { return Err(TickerSymbolError::TooLong); } Ok(()) } let sanitized_value = sanitize(raw_value.into()); validate(&sanitized_value)?; Ok(TickerSymbol(sanitized_value)) } } impl TickerSymbol { pub fn into_inner(self) -> String { self.0 } } impl core::str::FromStr for TickerSymbol { type Err = TickerSymbolError; fn from_str(raw_string: &str) -> ::core::result::Result<Self, Self::Err> { TickerSymbol::new(raw_string) } } impl ::core::convert::TryFrom<String> for TickerSymbol { type Error = TickerSymbolError; fn try_from(raw_value: String) -> Result<TickerSymbol, Self::Error> { Self::new(raw_value) } } impl ::core::convert::TryFrom<&str> for TickerSymbol { type Error = TickerSymbolError; fn try_from(raw_value: &str) -> Result<TickerSymbol, Self::Error> { Self::new(raw_value) } } impl ::core::convert::AsRef<str> for TickerSymbol { fn as_ref(&self) -> &str { &self.0 } } } pub use __nutype_private_TickerSymbol__::TickerSymbol; pub use __nutype_private_TickerSymbol__::TickerSymbolError; mod currency { pub struct Usd; impl Currency for Usd {} pub struct Jpy; impl Currency for Jpy {} pub trait Currency {} } pub struct Price<C: Currency>(f32, PhantomData<C>); #[automatically_derived] impl<C: ::core::fmt::Debug + Currency> ::core::fmt::Debug for Price<C> { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::debug_tuple_field2_finish(f, "Price", &self.0, &&self.1) } } #[automatically_derived] impl<C: ::core::clone::Clone + Currency> ::core::clone::Clone for Price<C> { #[inline] fn clone(&self) -> Price<C> { Price(::core::clone::Clone::clone(&self.0), ::core::clone::Clone::clone(&self.1)) } } #[doc(hidden)] mod __nutype_private_Percent__ { use super::*; pub struct Percent(f32); #[automatically_derived] impl ::core::cmp::PartialOrd for Percent { #[inline] fn partial_cmp( &self, other: &Percent, ) -> ::core::option::Option<::core::cmp::Ordering> { ::core::cmp::PartialOrd::partial_cmp(&self.0, &other.0) } } #[automatically_derived] impl ::core::clone::Clone for Percent { #[inline] fn clone(&self) -> Percent { let _: ::core::clone::AssertParamIsClone<f32>; *self } } #[automatically_derived] impl ::core::marker::StructuralPartialEq for Percent {} #[automatically_derived] impl ::core::cmp::PartialEq for Percent { #[inline] fn eq(&self, other: &Percent) -> bool { self.0 == other.0 } } #[automatically_derived] impl ::core::marker::Copy for Percent {} #[automatically_derived] impl ::core::fmt::Debug for Percent { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::debug_tuple_field1_finish(f, "Percent", &&self.0) } } pub enum PercentError { NotFinite, TooSmall, TooBig, } #[automatically_derived] impl ::core::fmt::Debug for PercentError { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::write_str( f, match self { PercentError::NotFinite => "NotFinite", PercentError::TooSmall => "TooSmall", PercentError::TooBig => "TooBig", }, ) } } #[automatically_derived] impl ::core::clone::Clone for PercentError { #[inline] fn clone(&self) -> PercentError { match self { PercentError::NotFinite => PercentError::NotFinite, PercentError::TooSmall => PercentError::TooSmall, PercentError::TooBig => PercentError::TooBig, } } } #[automatically_derived] impl ::core::marker::StructuralPartialEq for PercentError {} #[automatically_derived] impl ::core::cmp::PartialEq for PercentError { #[inline] fn eq(&self, other: &PercentError) -> bool { let __self_tag = ::core::intrinsics::discriminant_value(self); let __arg1_tag = ::core::intrinsics::discriminant_value(other); __self_tag == __arg1_tag } } #[automatically_derived] impl ::core::marker::StructuralEq for PercentError {} #[automatically_derived] impl ::core::cmp::Eq for PercentError { #[inline] #[doc(hidden)] #[no_coverage] fn assert_receiver_is_total_eq(&self) -> () {} } impl ::core::fmt::Display for PercentError { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match self { PercentError::NotFinite => f.write_fmt(format_args!("not finite")), PercentError::TooSmall => f.write_fmt(format_args!("too small")), PercentError::TooBig => f.write_fmt(format_args!("too big")), } } } impl ::std::error::Error for PercentError { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { None } } impl Percent { pub fn new(raw_value: f32) -> ::core::result::Result<Self, PercentError> { fn sanitize(mut value: f32) -> f32 { value } fn validate(val: f32) -> core::result::Result<(), PercentError> { if !val.is_finite() { return Err(PercentError::NotFinite); } if val < -100f32 { return Err(PercentError::TooSmall); } if val > 100f32 { return Err(PercentError::TooBig); } Ok(()) } let sanitized_value = sanitize(raw_value); validate(sanitized_value)?; Ok(Percent(sanitized_value)) } } impl Percent { pub fn into_inner(self) -> f32 { self.0 } } pub enum PercentParseError { Parse(<f32 as ::core::str::FromStr>::Err), Validate(PercentError), } #[automatically_derived] impl ::core::fmt::Debug for PercentParseError { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { match self { PercentParseError::Parse(__self_0) => { ::core::fmt::Formatter::debug_tuple_field1_finish( f, "Parse", &__self_0, ) } PercentParseError::Validate(__self_0) => { ::core::fmt::Formatter::debug_tuple_field1_finish( f, "Validate", &__self_0, ) } } } } impl ::core::fmt::Display for PercentParseError { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match self { PercentParseError::Parse(err) => { f.write_fmt(format_args!("Failed to parse {0}: {1}", "Percent", err)) } PercentParseError::Validate(err) => { f.write_fmt(format_args!("Failed to parse {0}: {1}", "Percent", err)) } } } } impl ::std::error::Error for PercentParseError { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { None } } impl ::core::str::FromStr for Percent { type Err = PercentParseError; fn from_str(raw_string: &str) -> ::core::result::Result<Self, Self::Err> { let raw_value: f32 = raw_string.parse().map_err(PercentParseError::Parse)?; Self::new(raw_value).map_err(PercentParseError::Validate) } } impl ::core::convert::AsRef<f32> for Percent { fn as_ref(&self) -> &f32 { &self.0 } } #[allow(clippy::derive_ord_xor_partial_ord)] impl ::core::cmp::Ord for Percent { fn cmp(&self, other: &Self) -> ::core::cmp::Ordering { self.partial_cmp(other) .unwrap_or_else(|| { let tp = "Percent"; { ::core::panicking::panic_fmt( format_args!( "{0}::cmp() panicked, because partial_cmp() returned None. Could it be that you\'re using unsafe {0}::new_unchecked() ?", tp ), ); }; }) } } impl ::core::convert::TryFrom<f32> for Percent { type Error = PercentError; fn try_from(raw_value: f32) -> Result<Percent, Self::Error> { Self::new(raw_value) } } impl ::core::cmp::Eq for Percent {} } pub use __nutype_private_Percent__::Percent; pub use __nutype_private_Percent__::PercentError; pub use __nutype_private_Percent__::PercentParseError; fn main() {}
まとめ
NewType パターンを用いて強めに型付けをしたい際に利用できるバリデーションチェック用のクレートを紹介しました。これから機能が増えることが期待されるクレートです。PhantomData
をもつ型は対応できていないようですが、これを対応してもらえると少し利用できる幅が広がるかもしれないとは思いました(PhantomData をもつものは NewType とは呼ばないのかもしれませんが…)。