import java.awt.BorderLayout;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.Pair;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
/**
 * Window that fits the NIST Rat43 (Ratkowsky 3) nonlinear least-squares problem with the
 * Apache Commons Math Levenberg-Marquardt optimizer and plots the observed points together
 * with the fitted curve.
 *
 * <p>Model: {@code y = b1 / (1 + exp(b2 - b3 * x))^(1 / b4)}.
 */
public class NistRat43Fitter extends JFrame {
    private static final long serialVersionUID = 1L;

    /** Rat43 predictor values x. */
    private static final double[] X_DATA = {
        1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, 9.00, 10.00,
        11.00, 12.00, 13.00, 14.00, 15.00
    };

    /** Rat43 observed responses y (same order as {@link #X_DATA}). */
    private static final double[] Y_OBSERVED = {
        16.08, 33.83, 65.80, 97.20, 191.55, 326.20, 386.87, 520.53, 590.03, 651.92,
        724.93, 699.56, 689.96, 637.56, 717.41
    };

    /** NIST-provided starting values ("Start 1") for b1, b2, b3, b4. */
    private static final double[] START_POINT = {100.0, 10.0, 1.0, 1.0};

    /** Number of samples used to draw the fitted curve. */
    private static final int CURVE_POINTS = 100;

    /** Creates and shows the window on the Swing Event Dispatch Thread. */
    public static void main(String[] args) {
        // Swing components must be created and manipulated on the EDT, not the main thread.
        SwingUtilities.invokeLater(() -> {
            NistRat43Fitter frame = new NistRat43Fitter();
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            frame.setBounds(10, 10, 640, 480);
            frame.setTitle("NIST Rat43 problem");
            frame.setVisible(true);
        });
    }

    /** Runs the fit, prints the report to stdout and builds the chart panel. */
    public NistRat43Fitter() {
        JFreeChart chart = ChartFactory.createXYLineChart(
                "Levenberg-Marquardt Optimizer",
                "x",
                "y",
                createData(),
                PlotOrientation.VERTICAL,
                true,   // legend
                false,  // tooltips
                false); // URLs
        XYPlot plot = chart.getXYPlot();
        NumberAxis yNumAxis = (NumberAxis) plot.getRangeAxis();
        yNumAxis.setRange(0.0, 800.0);
        XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
        // Series 0 (observations): markers only; series 1 (fitted curve): line only.
        renderer.setSeriesLinesVisible(0, false);
        renderer.setSeriesShapesVisible(1, false);
        plot.setRenderer(renderer);
        getContentPane().add(new ChartPanel(chart), BorderLayout.CENTER);
    }

    /**
     * Fits the model, prints the results, and returns a dataset whose series 0 holds the
     * observations and series 1 holds the fitted curve.
     */
    private XYSeriesCollection createData() {
        LeastSquaresOptimizer.Optimum optimum = fit();
        printReport(optimum);
        XYSeriesCollection data = new XYSeriesCollection();
        data.addSeries(observedSeries());
        data.addSeries(fittedSeries(optimum.getPoint()));
        return data;
    }

    /** Builds the Rat43 least-squares problem and solves it with Levenberg-Marquardt. */
    private static LeastSquaresOptimizer.Optimum fit() {
        MultivariateJacobianFunction modelFunction = NistRat43Fitter::evaluateModel;
        LeastSquaresProblem problem = new LeastSquaresBuilder()
                .start(START_POINT)       // initial parameter guess
                .model(modelFunction)     // model values and Jacobian
                .target(Y_OBSERVED)       // observed data
                .lazyEvaluation(false)
                .maxEvaluations(1000)
                .maxIterations(1000)
                .build();
        return new LevenbergMarquardtOptimizer().optimize(problem);
    }

    /**
     * Evaluates the model and its analytic Jacobian at every data point.
     *
     * <p>With {@code z = b2 - b3 * x}, {@code L = log(1 + exp(z))} and
     * {@code s = exp(z) / (1 + exp(z))}, the model is {@code f = b1 * exp(-L / b4)}, and
     * its partial derivatives are:
     * <ul>
     *   <li>df/db1 = exp(-L / b4)</li>
     *   <li>df/db2 = -f * s / b4</li>
     *   <li>df/db3 = f * x * s / b4</li>
     *   <li>df/db4 = f * L / b4^2</li>
     * </ul>
     * This is algebraically identical to the direct {@code Math.pow(1 + exp(z), ...)} form.
     * It avoids the overflow of {@code 1 + exp(z)} for large {@code z}, which would otherwise
     * produce {@code Inf / Inf = NaN} Jacobian entries.
     *
     * @param params parameter vector (b1, b2, b3, b4)
     * @return pair of (model values, n-by-4 Jacobian)
     */
    private static Pair<RealVector, RealMatrix> evaluateModel(RealVector params) {
        double b1 = params.getEntry(0);
        double b2 = params.getEntry(1);
        double b3 = params.getEntry(2);
        double b4 = params.getEntry(3);
        int n = X_DATA.length;
        RealVector values = new ArrayRealVector(n);
        RealMatrix jacobian = new Array2DRowRealMatrix(n, 4);
        for (int i = 0; i < n; i++) {
            double xi = X_DATA[i];
            double z = b2 - b3 * xi;
            double logU = softplus(z);        // log(1 + exp(z))
            double s = logistic(z);           // d(logU)/dz
            double g = Math.exp(-logU / b4);  // (1 + exp(z))^(-1/b4)
            double f = b1 * g;
            values.setEntry(i, f);
            jacobian.setEntry(i, 0, g);                      // df/db1
            jacobian.setEntry(i, 1, -f * s / b4);            // df/db2
            jacobian.setEntry(i, 2, f * xi * s / b4);        // df/db3
            jacobian.setEntry(i, 3, f * logU / (b4 * b4));   // df/db4
        }
        return new Pair<>(values, jacobian);
    }

    /** Computes {@code log(1 + exp(z))} without overflow for large positive {@code z}. */
    private static double softplus(double z) {
        return z > 0.0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
    }

    /** Computes {@code exp(z) / (1 + exp(z))} without overflow for large |z|. */
    private static double logistic(double z) {
        if (z >= 0.0) {
            return 1.0 / (1.0 + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0 + e);
    }

    /** Prints iteration counts, fitted parameters and the residual sum of squares. */
    private static void printReport(LeastSquaresOptimizer.Optimum optimum) {
        System.out.println("NIST Rat43 (Ratkowsky3) 非線形最小二乗法 (Levenberg-Marquardt)");
        System.out.println("-----------------------------------------------------");
        System.out.println("反復回数: " + optimum.getIterations());
        System.out.println("評価回数: " + optimum.getEvaluations());
        System.out.println();
        RealVector optimalParams = optimum.getPoint();
        System.out.printf("b1 (certified: ~699.64151270): %.8f%n", optimalParams.getEntry(0));
        System.out.printf("b2 (certified: ~5.2771253025): %.8f%n", optimalParams.getEntry(1));
        System.out.printf("b3 (certified: ~0.75962938329): %.8f%n", optimalParams.getEntry(2));
        System.out.printf("b4 (certified: ~1.2792483859): %.8f%n", optimalParams.getEntry(3));
        System.out.println();
        // Residual sum of squares = ||residuals||^2.
        RealVector residuals = optimum.getResiduals();
        double rss = residuals.dotProduct(residuals);
        System.out.printf("残差平方和 (certified: ~8786.4049080): %.8f%n", rss);
        System.out.println("-----------------------------------------------------");
    }

    /** Returns the observed (x, y) points as a series. */
    private static XYSeries observedSeries() {
        XYSeries series = new XYSeries("Original Points");
        for (int i = 0; i < X_DATA.length; i++) {
            series.add(X_DATA[i], Y_OBSERVED[i]);
        }
        return series;
    }

    /**
     * Samples the fitted model uniformly over the x-range of the data.
     *
     * @param params fitted parameters (b1, b2, b3, b4)
     */
    private XYSeries fittedSeries(RealVector params) {
        double b1 = params.getEntry(0);
        double b2 = params.getEntry(1);
        double b3 = params.getEntry(2);
        double b4 = params.getEntry(3);
        // Derive the plot range from the data instead of hard-coding it.
        double xMin = Double.POSITIVE_INFINITY;
        double xMax = Double.NEGATIVE_INFINITY;
        for (double x : X_DATA) {
            xMin = Math.min(xMin, x);
            xMax = Math.max(xMax, x);
        }
        XYSeries series = new XYSeries("Curve Fitting");
        for (int i = 0; i < CURVE_POINTS; i++) {
            double x = xMin + (xMax - xMin) * i / (CURVE_POINTS - 1);
            series.add(x, fittedFunction(x, b1, b2, b3, b4));
        }
        return series;
    }

    /**
     * Evaluates the model {@code b1 / (1 + exp(b2 - b3 * xi))^(1 / b4)} at {@code xi}.
     *
     * <p>The computation uses the same overflow-safe formulation as the optimizer's model,
     * so the plotted curve matches the function that was fitted.
     */
    double fittedFunction(double xi, double b1, double b2, double b3, double b4) {
        return b1 * Math.exp(-softplus(b2 - b3 * xi) / b4);
    }
}
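// 最近のコメント ("Recent comments": stray text copied from the source web page, commented out so the file compiles)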