# 1277. Count Square Submatrices with All Ones

Given an `m * n` matrix of ones and zeros, return how many square submatrices have all ones.

Example 1:

Input:

```
matrix =
[
[0,1,1,1],
[1,1,1,1],
[0,1,1,1]
]
```

Output:

```
15
```

Explanation:

```
There are 10 squares of side 1.
There are 4 squares of side 2.
There is one square of side 3.
Total number of squares = 10 + 4 + 1 = 15

```

Example 2:

Input:

```
matrix =
[
[1,0,1],
[1,1,0],
[1,1,0]
]
```

Output:

```
7
```

Explanation:

```
There are 6 squares of side 1.
There is 1 square of side 2.
Total number of squares = 6 + 1 = 7
```

Constraints:

• `1 <= arr.length <= 300`
• `1 <= arr[0].length <= 300`
• `0 <= arr[i][j] <= 1`

Example:

[0,1,1,1],
[1,1,1,1],
[0,1,1,1]

1. Count the single ones (1×1 squares).
2. Count the all-ones 2×2 squares.
3. Count the all-ones 3×3 squares.
4. ……
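Taken literally, the steps above amount to a brute-force count per side length. The following is only a naive baseline sketch, not the approach developed in the rest of this post; the class `BruteForceSquares` and its method names are hypothetical and introduced here for illustration. It can be handy as a reference to check faster versions against.

```java
class BruteForceSquares {
    // Hypothetical helper: true if the k x k square with top-left corner (top, left) is all ones.
    static boolean allOnes(int[][] matrix, int top, int left, int k) {
        for (int i = top; i < top + k; i++) {
            for (int j = left; j < left + k; j++) {
                if (matrix[i][j] == 0) {
                    return false;
                }
            }
        }
        return true;
    }

    // Hypothetical helper: number of all-ones squares with side exactly k.
    static int countOfSide(int[][] matrix, int k) {
        int m = matrix.length;
        int n = matrix[0].length;
        int count = 0;
        for (int i = 0; i + k <= m; i++) {
            for (int j = 0; j + k <= n; j++) {
                if (allOnes(matrix, i, j, k)) {
                    count++;
                }
            }
        }
        return count;
    }

    // Sum the counts for side 1, 2, 3, ... up to the shorter matrix dimension.
    static int countSquares(int[][] matrix) {
        int total = 0;
        int maxSide = Math.min(matrix.length, matrix[0].length);
        for (int k = 1; k <= maxSide; k++) {
            total += countOfSide(matrix, k);
        }
        return total;
    }
}
```

Every square is re-checked cell by cell here, which is the repeated work the reduction below avoids.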

Goal: reduce the n×n problem to (n-1)×(n-1), and so on down to 2×2, so that every round only needs a 2×2 scan.

Now scan with a 2×2 block: for each position (i,j), also look at (i+1,j), (i,j+1) and (i+1,j+1).
[0,1,1,1],
[1,1,1,1],
[0,1,1,1]

If the 2×2 scanning block contains a 0, set (i,j) to 0:
[0,1,1,1],
[0,1,1,1],
[0,1,1,1]

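Every 1 outside the last row and column now marks the top-left corner of an all-ones 2×2 square, so we can shrink the matrix by removing (ignoring) the last row and column: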

[0,1,1],
[0,1,1]

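Then scan the reduced matrix with the 2×2 block again. A 1 that survives this round marks an all-ones 3×3 square in the original matrix.

The counter can be updated either during the scan, whenever the whole 2×2 block is 1, or in the next round, by counting the ones that remain.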
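A single round of this scan can be sketched as a helper that works on the top-left `m` x `n` window of the matrix. The class `ScanPass` and method `scan` are hypothetical names used only for this illustration. Updating in place should be safe because the loops run in row-major order: the three other cells read for (i,j) are visited later, so they still hold the previous round's values.

```java
class ScanPass {
    // Hypothetical helper: one 2x2 scan over the top-left m x n window.
    // Zeroes (i,j) when its 2x2 block contains a 0, and returns the
    // number of blocks that were entirely ones.
    static int scan(int[][] matrix, int m, int n) {
        int found = 0;
        for (int i = 0; i < m - 1; i++) {
            for (int j = 0; j < n - 1; j++) {
                if (matrix[i][j] == 0
                        || matrix[i + 1][j] == 0
                        || matrix[i][j + 1] == 0
                        || matrix[i + 1][j + 1] == 0) {
                    matrix[i][j] = 0;
                } else {
                    found++;
                }
            }
        }
        return found;
    }
}
```

On the example matrix, `scan(matrix, 3, 4)` finds the 4 squares of side 2, and a second call `scan(matrix, 2, 3)` on the shrunken window finds the single square of side 3.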

Version A: update the counter during the 2×2 scan.

```
class Solution {
    public int countSquares(int[][] matrix) {
        int c = 0;

        int m = matrix.length;
        int n = matrix[0].length;

        // every 1 is a square of side 1
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (matrix[i][j] == 1) {
                    c++;
                }
            }
        }

        // each pass counts squares one side longer than the previous pass
        while (m > 1 && n > 1) {
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n - 1; j++) {
                    if (matrix[i][j] == 0
                            || matrix[i + 1][j] == 0
                            || matrix[i][j + 1] == 0
                            || matrix[i + 1][j + 1] == 0) {
                        matrix[i][j] = 0;
                    } else {
                        c++;
                    }
                }
            }
            // lower m and n (ignoring last row and column)
            m--;
            n--;
        }
        return c;
    }
}
```

Version B: a more compact version that updates the counter in the next round by counting the ones that remain. It is slightly slower because each round makes an extra pass over the matrix.

```
class Solution {
    public int countSquares(int[][] matrix) {
        int c = 0;

        int m = matrix.length;
        int n = matrix[0].length;

        while (m > 0 && n > 0) {
            // count the ones left in the current window
            for (int i = 0; i < m; i++) {
                for (int j = 0; j < n; j++) {
                    if (matrix[i][j] == 1) {
                        c++;
                    }
                }
            }

            // 2×2 scan: zero (i,j) if its block contains a 0
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n - 1; j++) {
                    if (matrix[i][j] == 0
                            || matrix[i + 1][j] == 0
                            || matrix[i][j + 1] == 0
                            || matrix[i + 1][j + 1] == 0) {
                        matrix[i][j] = 0;
                    }
                }
            }
            // lower m and n (ignoring last row and column)
            m--;
            n--;
        }
        return c;
    }
}
```

Runtimes:

• A: 86-110 ms
• B: 160-170 ms