SZKの(仮)

2013年02月

今回はテンプレートマッチングを高速化するための方法を紹介したいと思います。

1、残差逐次検定法
→あるしきい値を超えたら加算を打ち切り、次の位置での計算に移る方法。

テンプレートマッチングその2の二乗誤差を計算する部分を再掲します。

for(int y = 0; y < srchH - tempH; y++){
   for(int x = 0; x < srchW - tempW; x++){
        int ssd = 0;     //二乗誤差の和(SSD)
        for(int yt = 0; yt < tempH; yt++){
             for(int xt = 0; xt < tempW; xt++){
                  dif = srch[y + yt][x + xt] - temp[yt][xt];
                   ssd += Math.pow(dif, 2);
              }
         }
    
         if(min_ssd > ssd){
         xpos = x;
         ypos = y;
         min_ssd = ssd;
    }
}

目的は、もっとも小さい二乗誤差(min_ssd)を計算することです。
よって、加算している途中で変数ssdの値がmin_ssdの値を超えたら、まずそこが一致する場所
とは考えられませんので、加算を打ち切って次の位置に移ることが適切でしょう。
これが残差逐次検定法です。

 for(int y = 0; y < srchH - tempH; y++){
        for(int x = 0; x < srchW - tempW; x++){
                int ssd = 0;                    //二乗誤差の和(SSD)
                flag : for(int yt = 0; yt < tempH; yt++){
                    for(int xt = 0; xt < tempW; xt++){
                        dif = srch[y + yt][x + xt] - temp[yt][xt];
                        ssd += dif * dif;
          //もし、min_ssdの値を超えたら、次の位置に移る。
                        if(ssd > min_ssd){
                            continue flag;
                        }

                    }
                }
               
                if(min_ssd > ssd){
                    xpos = x;
                    ypos = y;
                    min_ssd = ssd;
                }
               
       }
  }

青色のところなのですが、以前は
ssd += Math.pow(dif, 2);
となっていたと思います。
ただ、これだとテーラー展開されてしまうらしいので、青色の式に直しました。
以下にまとめたコードを載せます。

TemplateMatching.java
import java.awt.Color; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.image.BufferedImage; import java.io.File; import javax.imageio.ImageIO; import javax.swing.JFrame; public class TemplateMatching { static int xpos = 0; //テンプレートが一致したときのx座標 static int ypos = 0; //テンプレートが一致したときのy座標 public static void main(String[] args){ //テンプレート、被探索画像の読み込み BufferedImage tempImg = imgRead("./temp.jpg"); BufferedImage srchImg = imgRead("./srch.jpg"); //テンプレート、被探索画像を配列に変換 int[][] temp = imgToArray(tempImg); int[][] srch = imgToArray(srchImg); //テンプレート、被探索画像をグレースケールに変換 int[][] g_temp = trans_grayscale(temp); int[][] g_srImg = trans_grayscale(srch); //テンプレートマッチングを行う templateMatcing(g_temp, g_srImg); //テンプレートマッチングした結果を表示する DispFrame frame = new DispFrame(srchImg, tempImg, xpos, ypos); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); frame.setVisible(true); } //画像ファイルを読み込むメソッド public static BufferedImage imgRead(String file_path){ BufferedImage img = null; try{ img = ImageIO.read(new File(file_path)); return img; }catch(Exception e){ return null; } } //画像ファイルを2次元配列に変換するメソッド public static int[][] imgToArray(BufferedImage img){ int width = img.getWidth(); int height = img.getHeight(); int[][] imgA = new int[height][width]; for(int y = 0; y < height; y++){ for(int x = 0; x < width; x++){ imgA[y][x] = img.getRGB(x, y); //画像上の(x, y)におけるRGB値を取得 } } return imgA; } //RGB値からグレースケールに変換するメソッド public static int[][] trans_grayscale(int[][] img){ int width = img[0].length; int height = img.length; int[][] gray_img = new int[height][width]; for(int y = 0; y < img.length; y++){ for(int x = 0; x < img[0].length; x++){ int rgb = img[y][x] - 0xFF000000; //アルファ値を取り除く int b = (rgb & 0xFF); //青の成分を取得 int g = (rgb & 0xFF00) >> 8; //緑の成分を取得 int r = (rgb & 0xFF0000) >> 16; //赤の成分を取得 int gray = (b + g + r) / 3; //グレーの値に変換 gray_img[y][x] = gray; } } return gray_img; } //テンプレートマッチングするメソッド public static void templateMatcing(int[][] temp, int[][] srch){ int tempW = temp[0].length; int tempH = temp.length; int srchW = srch[0].length; int srchH = srch.length; int min_ssd = Integer.MAX_VALUE; //最小の二乗誤差の和 int dif = 0;       //非探索画像とテンプレートのピクセル単位での差 for(int y = 0; y < srchH - tempH; y++){ for(int x = 0; x < srchW - tempW; x++){ int ssd = 0; //二乗誤差の和(SSD) flag : for(int yt = 0; yt < tempH; yt++){ for(int xt = 0; xt < tempW; xt++){ dif = srch[y + yt][x + xt] - temp[yt][xt]; ssd += dif * dif; //もし、min_ssdの値を超えたら、次の位置に移る。 if(ssd > min_ssd){ continue flag; } } } if(min_ssd > ssd){ xpos = x; ypos = y; min_ssd = ssd; } } } System.out.println("Min_SSD =" + min_ssd); //SSDの最小値を表示 System.out.println("position=" + xpos + "," + ypos); //テンプレートの一致した座標を表示 } } class DispFrame extends JFrame { BufferedImage srch; //フレーム上に表示するための被探索画像 int xpos = 0; //テンプレートが一致したx座標 int ypos = 0; //テンプレートが一致したy座標 int temp_width; int temp_height; DispFrame(BufferedImage srch, BufferedImage temp, int xpos, int ypos){ this.xpos = xpos; this.ypos = ypos; this.srch = srch; temp_width = temp.getWidth(); temp_height = temp.getHeight(); setSize(srch.getWidth(), srch.getHeight()); setTitle("RESULT"); } //非探索画像上のテンプレートが一致したところに四角を囲むメソッド public void paint(Graphics g){ Graphics2D off = srch.createGraphics(); off.setColor(new Color(0,0,255)); //四角の色を青にする off.drawRect(xpos, ypos, temp_width, temp_height); g.drawImage(srch, 0, 0, this); } }

研究で使う必要が出た判別分析法をjavaで実装しました。
とりあえず動かしたい方は、画像のファイルパスを設定してください。
大きい画像(3000×3000くらいの)をいれるとヒープのサイズがオーバーする可能性があるので、
JavaVMのヒープのサイズを512MB,あるいはそれ以上に設定しておくことがよいと思われます。


F1
               図1:原画像

output
                                     図2:処理結果

LinearDiscriminant.java

import java.awt.Graphics; import java.awt.image.BufferedImage; import java.io.File; import java.math.BigDecimal; import javax.imageio.ImageIO; import javax.swing.JFrame; public class LinearDiscriminant { private final int allpix; private int[] hist; private int[] ac; private BufferedImage fimg; public static void main(String[] args){ LinearDiscriminant bin = new LinearDiscriminant(); DispImg frame = new DispImg(bin.getImg()); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); frame.setVisible(true); } public LinearDiscriminant(){ BufferedImage img = imgRead("picture.jpg"); allpix = img.getWidth() * img.getHeight(); double[][] gray = grayscale(img); double[][] smooth = gaussian_filter(gray); hist = histogram(smooth); ac = accumulation(hist); double t = threshold(); System.out.println("threshold = " + t); fimg = binarize(img, gray, t); } private BufferedImage imgRead(String file_path){ BufferedImage img = null; try{ img = ImageIO.read(new File(file_path)); return img; }catch(Exception e){ return null; } } private double[][] grayscale(BufferedImage img){ double[][] gray = new double[img.getHeight()][img.getWidth()]; for(int y = 0; y < img.getHeight(); y++){ for(int x = 0; x < img.getWidth(); x++){ int rgb = img.getRGB(x, y); rgb -= 0xFF000000; int r = (rgb & 0xFF0000) >> 16; int g = (rgb & 0xFF00) >> 8 ; int b = rgb & 0xFF; double l = (double)(b + g + r) / 3; gray[y][x] = l; } } return gray; } private double[][] gaussian_filter(double[][] gray){ double[] GAUSSIAN = {0.0625, 0.125, 0.25}; double[][] buf = new double[gray.length][gray[0].length]; for(int y = 1; y < buf.length - 1; y++){ for(int x = 1; x < buf[0].length - 1; x++){ double sum = 0; sum = (gray[y-1][x-1] + gray[y-1][x+1] + gray[y+1][x-1] + gray[y+1][x+1]) * GAUSSIAN[0] + (gray[y-1][x] + gray[y][x-1] + gray[y][x+1] + gray[y+1][x]) * GAUSSIAN[1] + gray[y][x] * GAUSSIAN[2]; buf[y][x] = sum; } } return buf; } private int[] histogram(double[][] gray){ int[] histogram = new int[256]; for(int y = 0; y < gray.length; y++){ for(int x = 0; x < gray[0].length; x++){ histogram[(int)(Math.round(gray[y][x]))]++; } } return histogram; } private int[] accumulation(int[] hist){ int ac[] = new int[hist.length]; int sum = 0; for(int i = 0; i < hist.length; i++){ sum += hist[i]; ac[i] = sum; } return ac; } private double threshold(){ double threshold = 0; BigDecimal max = BigDecimal.ZERO; BigDecimal separation_metrics; BigDecimal pbB, pwB, mul; for(double t = 0.5; t < 255; t++){ int pb = ac[(int)(t - 0.5)]; int pw = allpix - pb; double mb = classBlackMean(t); double mw = classWhiteMean(t); if(mb == -1 || mw == -1){ continue; } //System.out.println(t + " : " +mb); pbB = BigDecimal.valueOf(pb); pwB = BigDecimal.valueOf(pw); mul = pbB.multiply(pwB); separation_metrics = mul.multiply(BigDecimal.valueOf(Math.pow((mb - mw), 2))); //System.out.println(separation_metrics); if(separation_metrics.compareTo(max) > 0){ max = separation_metrics; //System.out.println("MAX = " + max); threshold = t; } } return threshold; } private double classBlackMean(double t){ double sum = 0; double mean; for(int i = 0; i < t; i++){ sum += hist[i] * i; } if(ac[(int)t] == 0){ return -1; }else{ mean = sum / ac[(int)t]; return mean; } } private double classWhiteMean(double t){ double sum = 0; double mean; for(int i = (int)(t + 0.5); i <= 255; i++){ sum += hist[i] * i; } if(allpix - ac[(int)t] == 0){ return -1; }else{ mean = sum / (allpix - ac[(int)t]); return mean; } } private BufferedImage binarize(BufferedImage img, double[][] gray, double t){ for(int y = 0; y < img.getHeight(); y++){ for(int x = 0; x < img.getWidth(); x++){ if(gray[y][x] > t){ img.setRGB(x, y, 16777215); //white }else{ img.setRGB(x, y, 0); //black } } } return img; } public BufferedImage getImg(){ return fimg; } } class DispImg extends JFrame { BufferedImage img; final int param = 1; DispImg(BufferedImage img){ this.img = img; setSize(img.getWidth() / param + 4, img.getHeight() / param + 38); } public void paint(Graphics g){ g.drawImage(img, 8, 30, img.getWidth() / param, img.getHeight() / param, this); } }

このページのトップヘ