00001
00002
00003
00004
00005
00006
00007 #include "rgrl_data_manager.h"
00008 #include <rgrl/rgrl_matcher_k_nearest.h>
00009 #include <rgrl/rgrl_scale_est_all_weights.h>
00010 #include <rgrl/rgrl_scale_est_closest.h>
00011 #include <rgrl/rgrl_weighter_m_est.h>
00012 #include <rgrl/rgrl_convergence_on_median_error.h>
00013
00014 #include <rrel/rrel_lms_obj.h>
00015 #include <rrel/rrel_tukey_obj.h>
00016
00017 #include <vcl_cassert.h>
00018
00019
00020
00021
00022
00023 rgrl_data_manager::
00024 rgrl_data_manager( bool multi_stage )
00025 : multi_stage_( multi_stage ),
00026 multi_feature_( false )
00027 {
00028 }
00029
00030
00031 rgrl_data_manager::
00032 ~rgrl_data_manager()
00033 {
00034 }
00035
00036
00037 void
00038 rgrl_data_manager::
00039 add_data( unsigned stage,
00040 rgrl_feature_set_sptr from_set,
00041 rgrl_feature_set_sptr to_set,
00042 rgrl_matcher_sptr matcher,
00043 rgrl_weighter_sptr weighter,
00044 rgrl_scale_estimator_unwgted_sptr unwgted_scale_est,
00045 rgrl_scale_estimator_wgted_sptr wgted_scale_est,
00046 const vcl_string& label )
00047 {
00048
00049 assert( from_set && to_set );
00050
00051 if ( data_.has( stage ) ) multi_feature_ = true;
00052
00053
00054 generate_defaults( matcher, weighter, unwgted_scale_est );
00055
00056
00057 data_[stage].push_back( rgrl_data_manager_data_item( from_set, to_set,
00058 matcher, weighter,
00059 unwgted_scale_est,
00060 wgted_scale_est,
00061 label ) );
00062 data_.set_dimension_increase_for_next_stage(stage, 1);
00063 }
00064
00065
00066 void
00067 rgrl_data_manager::
00068 add_data( rgrl_feature_set_sptr from_set,
00069 rgrl_feature_set_sptr to_set,
00070 rgrl_matcher_sptr matcher,
00071 rgrl_weighter_sptr weighter,
00072 rgrl_scale_estimator_unwgted_sptr unwgted_scale_est,
00073 rgrl_scale_estimator_wgted_sptr wgted_scale_est,
00074 const vcl_string& label )
00075 {
00076 assert( !multi_stage_ );
00077
00078 unsigned stage = 0;
00079
00080 add_data( stage, from_set, to_set, matcher, weighter,
00081 unwgted_scale_est, wgted_scale_est, label );
00082 }
00083
00084 void
00085 rgrl_data_manager::
00086 add_estimator( unsigned stage,
00087 rgrl_estimator_sptr estimator)
00088 {
00089 data_.add_estimator(stage, estimator);
00090 }
00091
00092
00093 void
00094 rgrl_data_manager::
00095 add_estimator( rgrl_estimator_sptr estimator)
00096 {
00097 assert( !multi_stage_ );
00098
00099 unsigned stage = 0;
00100 data_.add_estimator(stage, estimator);
00101 }
00102
00103
00104 void
00105 rgrl_data_manager::
00106 set_dimension_increase_for_next_stage( unsigned stage, double rate)
00107 {
00108 data_.set_dimension_increase_for_next_stage(stage, rate);
00109 }
00110
00111 double
00112 rgrl_data_manager::
00113 dimension_increase_for_next_stage(unsigned stage) const
00114 {
00115 return data_.dimension_increase_for_next_stage( stage );
00116 }
00117
00118
00119 bool
00120 rgrl_data_manager::
00121 is_multi_feature() const
00122 {
00123 return multi_feature_;
00124 }
00125
00126
00127 void
00128 rgrl_data_manager::
00129 get_data_at_stage( unsigned stage,
00130 vcl_vector<rgrl_feature_set_sptr> & from_sets,
00131 vcl_vector<rgrl_feature_set_sptr> & to_sets,
00132 vcl_vector<rgrl_matcher_sptr> & matchers,
00133 vcl_vector<rgrl_weighter_sptr> & weighters,
00134 vcl_vector<rgrl_scale_estimator_unwgted_sptr> & unwgted_scale_ests,
00135 vcl_vector<rgrl_scale_estimator_wgted_sptr> & wgted_scale_ests,
00136 vcl_vector<rgrl_estimator_sptr> & estimators) const
00137 {
00138 from_sets.clear();
00139 to_sets.clear();
00140 matchers.clear();
00141 weighters.clear();
00142 unwgted_scale_ests.clear();
00143 wgted_scale_ests.clear();
00144 estimators.clear();
00145
00146 if ( data_.has( stage ) ) {
00147 typedef rgrl_data_manager_data_storage::data_vector::const_iterator iter_type;
00148 iter_type itr = data_[stage].begin();
00149 iter_type end = data_[stage].end();
00150 for ( ; itr != end; ++itr ) {
00151 from_sets.push_back( itr->from_set );
00152 to_sets.push_back( itr->to_set );
00153 matchers.push_back( itr->matcher );
00154 unwgted_scale_ests.push_back( itr->unwgted_scale_est );
00155 wgted_scale_ests.push_back( itr->wgted_scale_est );
00156 weighters.push_back( itr->weighter );
00157 }
00158
00159 if ( data_.has_estimator_hierarchy( stage ) )
00160 estimators = data_.estimator_hierarchy( stage );
00161 }
00162 }
00163
00164
00165 void
00166 rgrl_data_manager::
00167 get_data_at_stage( unsigned stage,
00168 rgrl_feature_set_sptr & from_set,
00169 rgrl_feature_set_sptr & to_set,
00170 rgrl_matcher_sptr & matcher,
00171 rgrl_weighter_sptr & weighter,
00172 rgrl_scale_estimator_unwgted_sptr & unwgted_scale_est,
00173 rgrl_scale_estimator_wgted_sptr & wgted_scale_est,
00174 vcl_vector<rgrl_estimator_sptr> & estimators ) const
00175 {
00176 assert( !multi_feature_ );
00177
00178 vcl_vector<rgrl_feature_set_sptr> from_sets;
00179 vcl_vector<rgrl_feature_set_sptr> to_sets;
00180 vcl_vector<rgrl_matcher_sptr> matchers;
00181 vcl_vector<rgrl_weighter_sptr> weighters;
00182 vcl_vector<rgrl_scale_estimator_unwgted_sptr> unwgted_scale_ests;
00183 vcl_vector<rgrl_scale_estimator_wgted_sptr> wgted_scale_ests;
00184
00185 rgrl_data_manager::get_data_at_stage( stage,
00186 from_sets,
00187 to_sets,
00188 matchers,
00189 weighters,
00190 unwgted_scale_ests,
00191 wgted_scale_ests,
00192 estimators );
00193
00194 from_set = from_sets[0];
00195 to_set = to_sets[0];
00196 matcher = matchers[0];
00197 unwgted_scale_est = unwgted_scale_ests[0];
00198 wgted_scale_est = wgted_scale_ests[0];
00199 weighter = weighters[0];
00200 }
00201
00202
00203 void
00204 rgrl_data_manager::
00205 get_data( vcl_vector<rgrl_feature_set_sptr> & from_sets,
00206 vcl_vector<rgrl_feature_set_sptr> & to_sets,
00207 vcl_vector<rgrl_matcher_sptr> & matchers,
00208 vcl_vector<rgrl_weighter_sptr> & weighters,
00209 vcl_vector<rgrl_scale_estimator_unwgted_sptr> & unwgted_scale_ests,
00210 vcl_vector<rgrl_scale_estimator_wgted_sptr> & wgted_scale_ests,
00211 vcl_vector<rgrl_estimator_sptr> & estimators) const
00212 {
00213 assert( !multi_stage_ );
00214
00215 unsigned stage = 0;
00216 get_data_at_stage( stage, from_sets, to_sets, matchers, weighters,
00217 unwgted_scale_ests, wgted_scale_ests, estimators );
00218 }
00219
00220
00221 void
00222 rgrl_data_manager::
00223 get_data( rgrl_feature_set_sptr & from_set,
00224 rgrl_feature_set_sptr & to_set,
00225 rgrl_matcher_sptr & matcher,
00226 rgrl_weighter_sptr & weighter,
00227 rgrl_scale_estimator_unwgted_sptr & unwgted_scale_est,
00228 rgrl_scale_estimator_wgted_sptr & wgted_scale_est,
00229 vcl_vector<rgrl_estimator_sptr> & estimators ) const
00230 {
00231 assert( !multi_stage_ );
00232
00233 unsigned stage = 0;
00234 get_data_at_stage( stage, from_set, to_set, matcher, weighter,
00235 unwgted_scale_est, wgted_scale_est, estimators );
00236 }
00237
00238
00239 bool
00240 rgrl_data_manager::
00241 has_stage(unsigned stage ) const
00242 {
00243 return data_.has( stage );
00244 }
00245
00246 void
00247 rgrl_data_manager::
00248 generate_defaults( rgrl_matcher_sptr &matcher,
00249 rgrl_weighter_sptr &weighter,
00250 rgrl_scale_estimator_unwgted_sptr &unwgted_scale_est )
00251 {
00252
00253 if ( !matcher ) {
00254 matcher = new rgrl_matcher_k_nearest( 1 );
00255 DebugMacro( 1, "Default matcher set to rgrl_matcher_k_nearest( 1 )\n");
00256 }
00257
00258
00259
00260
00261 if ( !weighter ) {
00262 vcl_auto_ptr<rrel_m_est_obj> m_est_obj( new rrel_tukey_obj(4) );
00263 weighter = new rgrl_weighter_m_est(m_est_obj, false, false);
00264 DebugMacro( 1, "Default weighter set to rgrl_weighter_m_est\n");
00265 }
00266
00267
00268 if ( !unwgted_scale_est ) {
00269 vcl_auto_ptr<rrel_objective> lms_obj( new rrel_lms_obj(1) );
00270 unwgted_scale_est = new rgrl_scale_est_closest( lms_obj );
00271 DebugMacro( 1, "Default unwgted scale estimator set to rgrl_scale_est_closest\n");
00272 }
00273 }
00274
00275
00276 void
00277 rgrl_data_manager::
00278 get_label( unsigned stage,
00279 vcl_vector<vcl_string>& labels) const
00280 {
00281 labels.clear();
00282 labels.reserve(10);
00283
00284 if ( data_.has( stage ) ) {
00285 typedef rgrl_data_manager_data_storage::data_vector::const_iterator iter_type;
00286 iter_type itr = data_[stage].begin();
00287 iter_type end = data_[stage].end();
00288 for ( ; itr != end; ++itr ) {
00289 labels.push_back( itr->label );
00290 }
00291 }
00292 }
00293
00294
00295 void
00296 rgrl_data_manager::
00297 get_label( vcl_vector<vcl_string>& labels) const
00298 {
00299 assert( !multi_stage_ );
00300
00301 const unsigned stage = 0;
00302 get_label( stage, labels );
00303 }
00304
00305