forked from jan/opencl_fdfd
		
	cleanup and improved reporting for csr
This commit is contained in:
		
							parent
							
								
									03f7f6d3c4
								
							
						
					
					
						commit
						bb7b90e938
					
				@ -1,4 +1,5 @@
 | 
				
			|||||||
import numpy
 | 
					import numpy
 | 
				
			||||||
 | 
					from numpy.linalg import norm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pyopencl
 | 
					import pyopencl
 | 
				
			||||||
import pyopencl.array
 | 
					import pyopencl.array
 | 
				
			||||||
@ -52,17 +53,15 @@ def create_ops(context):
 | 
				
			|||||||
    v_out[i] = dot;
 | 
					    v_out[i] = dot;
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    v_out_args = ctype + ' *v_out, int v_len_half'
 | 
					    v_out_args = ctype + ' *v_out'
 | 
				
			||||||
    m_args = 'int m_nnz, int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
 | 
					    m_args = 'int *m_row_ptr, int *m_col_ind, ' + ctype + ' *m_data'
 | 
				
			||||||
    v_in_args = ctype + ' *v_in'
 | 
					    v_in_args = ctype + ' *v_in'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble,
 | 
					    spmv_kernel = ElementwiseKernel(context, operation=spmv_source, preamble=preamble,
 | 
				
			||||||
                                    arguments=', '.join((v_out_args, m_args, v_in_args)))
 | 
					                                    arguments=', '.join((v_out_args, m_args, v_in_args)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def spmv(v_out, m, v_in, e):
 | 
					    def spmv(v_out, m, v_in, e):
 | 
				
			||||||
        return spmv_kernel(v_out, (v_out.size - 1)//2,
 | 
					        return spmv_kernel(v_out, m.row_ptr, m.col_ind, m.data, v_in, wait_for=e)
 | 
				
			||||||
                           m.data.size, m.row_ptr, m.col_ind, m.data,
 | 
					 | 
				
			||||||
                           v_in, wait_for=e)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # -------------------------------------
 | 
					    # -------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -93,7 +92,7 @@ def create_ops(context):
 | 
				
			|||||||
                                dtype_out=ri_dtype,
 | 
					                                dtype_out=ri_dtype,
 | 
				
			||||||
                                neutral='(double3)(0.0, 0.0, 0.0)',
 | 
					                                neutral='(double3)(0.0, 0.0, 0.0)',
 | 
				
			||||||
                                map_expr=update_ri_source,
 | 
					                                map_expr=update_ri_source,
 | 
				
			||||||
                                reduce_expr='a+b',
 | 
					                                reduce_expr='a + b',
 | 
				
			||||||
                                arguments=ctype + ' *r')
 | 
					                                arguments=ctype + ' *r')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def ri_update(r, e):
 | 
					    def ri_update(r, e):
 | 
				
			||||||
@ -149,7 +148,7 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
 | 
				
			|||||||
    ops = create_ops(context)
 | 
					    ops = create_ops(context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape)
 | 
					    x = pyopencl.array.zeros(queue, dtype=numpy.complex128, shape=b.shape)
 | 
				
			||||||
    v = pyopencl.array.empty_like(x)
 | 
					    v = pyopencl.array.zeros_like(x)
 | 
				
			||||||
    p = pyopencl.array.zeros_like(x)
 | 
					    p = pyopencl.array.zeros_like(x)
 | 
				
			||||||
    r = pyopencl.array.to_device(queue, b)
 | 
					    r = pyopencl.array.to_device(queue, b)
 | 
				
			||||||
    alpha = 1.0 + 0j
 | 
					    alpha = 1.0 + 0j
 | 
				
			||||||
@ -158,20 +157,19 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    m = CSRMatrix(queue, a)
 | 
					    m = CSRMatrix(queue, a)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    e = ops['spmv'](v, m, x, [])
 | 
					    _, err2 = ops['ri_update'](r, [])
 | 
				
			||||||
    e = ops['xr_update'](x, p, r, v, 0.0, [e])
 | 
					 | 
				
			||||||
    _, err2 = ops['ri_update'](r, [e])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    b_norm = numpy.sqrt(err2)
 | 
					    b_norm = numpy.sqrt(err2)
 | 
				
			||||||
    print('b_norm check: ', b_norm)
 | 
					    print('b_norm check: ', b_norm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    start_time2 = time.perf_counter()
 | 
					    start_time2 = time.perf_counter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    success = False
 | 
				
			||||||
    for k in range(max_iters):
 | 
					    for k in range(max_iters):
 | 
				
			||||||
        if verbose:
 | 
					        if verbose:
 | 
				
			||||||
            print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
 | 
					            print('[{:06d}] rho {:.4} alpha {:4.4}'.format(k, rho, alpha), end=' ')
 | 
				
			||||||
        rho_prev = rho
 | 
					        rho_prev = rho
 | 
				
			||||||
        e = ops['xr_update'](x, p, r, v, alpha, [e])
 | 
					        e = ops['xr_update'](x, p, r, v, alpha, [])
 | 
				
			||||||
        rho, err2 = ops['ri_update'](r, [e])
 | 
					        rho, err2 = ops['ri_update'](r, [e])
 | 
				
			||||||
        errs += [numpy.sqrt(err2) / b_norm]
 | 
					        errs += [numpy.sqrt(err2) / b_norm]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -179,17 +177,35 @@ def cg(a, b, max_iters=10000, err_thresh=1e-6, context=None, queue=None, verbose
 | 
				
			|||||||
            print('err', errs[-1])
 | 
					            print('err', errs[-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if errs[-1] < err_thresh:
 | 
					        if errs[-1] < err_thresh:
 | 
				
			||||||
            time_elapsed = time.perf_counter() - start_time
 | 
					            success = True
 | 
				
			||||||
            print('Success, {} iterations in {} sec: {} iterations/sec'.format(k,
 | 
					            break
 | 
				
			||||||
                   time_elapsed, k/time_elapsed))
 | 
					 | 
				
			||||||
            print('overhead', start_time2-start_time)
 | 
					 | 
				
			||||||
            return x.get(), errs, True
 | 
					 | 
				
			||||||
        e = ops['p_update'](p, r, rho/rho_prev, [])
 | 
					        e = ops['p_update'](p, r, rho/rho_prev, [])
 | 
				
			||||||
        e.wait()
 | 
					 | 
				
			||||||
        ops['spmv'](v, m, p, [e]).wait()
 | 
					        ops['spmv'](v, m, p, [e]).wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#        v2 = a @ p.get()
 | 
				
			||||||
 | 
					#        print('norm', norm(v2 - v.get()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        alpha = rho / pyopencl.array.dot(p, v).get()
 | 
					        alpha = rho / pyopencl.array.dot(p, v).get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if k % 1000 == 0:
 | 
					        if k % 1000 == 0:
 | 
				
			||||||
            print(k)
 | 
					            print(k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return x.get(), errs, False
 | 
					    '''
 | 
				
			||||||
 | 
					    Done solving
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    time_elapsed = time.perf_counter() - start_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    x = x.get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if success:
 | 
				
			||||||
 | 
					        print('Success', end='')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        print('Failure', end=', ')
 | 
				
			||||||
 | 
					    print(', {} iterations in {} sec: {} iterations/sec \
 | 
				
			||||||
 | 
					                  '.format(k, time_elapsed, k / time_elapsed))
 | 
				
			||||||
 | 
					    print('final error', errs[-1])
 | 
				
			||||||
 | 
					    print('overhead {} sec'.format(start_time2 - start_time))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print('Post-everything residual:', norm(a @ x - b) / norm(b))
 | 
				
			||||||
 | 
					    return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user